Skip to content

Commit

Permalink
Validate subscribe IDs in most APIs
Browse files Browse the repository at this point in the history
Summary: Some sanity checks that make subsequent refactors a little smaller

Reviewed By: NEUDitao

Differential Revision: D67147834

fbshipit-source-id: 89177bc42212dc132526e9529a13a551956c0832
  • Loading branch information
afrind authored and facebook-github-bot committed Dec 12, 2024
1 parent 867cedb commit 874e587
Showing 1 changed file with 102 additions and 10 deletions.
112 changes: 102 additions & 10 deletions moxygen/MoQSession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,22 +409,30 @@ void MoQSession::onSubscribe(SubscribeRequest subscribeRequest) {
// MoQForwarder] Track Alias -> Track Name
// If ths session holds this state, it can check for duplicate
// subscriptions
pubTracks_[subscribeID].priority = subscribeRequest.priority;
auto it = pubTracks_.find(subscribeRequest.subscribeID);
if (it != pubTracks_.end()) {
XLOG(ERR) << "Duplicate subscribe ID=" << subscribeRequest.subscribeID
<< " sess=" << this;
subscribeError({subscribeRequest.subscribeID, 400, "dup sub ID"});
return;
}
pubTracks_[subscribeRequest.subscribeID].priority = subscribeRequest.priority;
controlMessages_.enqueue(std::move(subscribeRequest));
}

void MoQSession::onSubscribeUpdate(SubscribeUpdate subscribeUpdate) {
XLOG(DBG1) << __func__ << " sess=" << this;
const auto subscribeID = subscribeUpdate.subscribeID;
if (!pubTracks_.contains(subscribeID)) {
auto it = pubTracks_.find(subscribeID);
if (it == pubTracks_.end()) {
XLOG(ERR) << "No matching subscribe ID=" << subscribeID << " sess=" << this;
return;
}
if (closeSessionIfSubscribeIdInvalid(subscribeID)) {
return;
}

pubTracks_[subscribeID].priority = subscribeUpdate.priority;
it->second.priority = subscribeUpdate.priority;
// TODO: update priority of tracks in flight
controlMessages_.enqueue(std::move(subscribeUpdate));
}
Expand Down Expand Up @@ -465,7 +473,9 @@ void MoQSession::onSubscribeError(SubscribeError subErr) {
}

void MoQSession::onSubscribeDone(SubscribeDone subscribeDone) {
XLOG(DBG1) << __func__ << " sess=" << this;
XLOG(DBG1) << "SubscribeDone id=" << subscribeDone.subscribeID
<< " code=" << folly::to_underlying(subscribeDone.statusCode)
<< " reason=" << subscribeDone.reasonPhrase;
auto trackAliasIt = subIdToTrackAlias_.find(subscribeDone.subscribeID);
if (trackAliasIt == subIdToTrackAlias_.end()) {
// unknown
Expand All @@ -478,11 +488,18 @@ void MoQSession::onSubscribeDone(SubscribeDone subscribeDone) {
// TODO: there could still be objects in flight. Removing from maps now
// will prevent their delivery. I think the only way to handle this is with
// timeouts.
subTracks_[trackAliasIt->second]->fin();
subTracks_.erase(trackAliasIt->second);
auto trackHandleIt = subTracks_.find(trackAliasIt->second);
if (trackHandleIt != subTracks_.end()) {
auto trackHandle = trackHandleIt->second;
subTracks_.erase(trackHandleIt);
trackHandle->fin();
} else {
XLOG(DFATAL) << "trackAliasIt but no trackHandleIt for id="
<< subscribeDone.subscribeID << " sess=" << this;
}
subIdToTrackAlias_.erase(trackAliasIt);
checkForCloseOnDrain();
controlMessages_.enqueue(std::move(subscribeDone));
checkForCloseOnDrain();
}

void MoQSession::onMaxSubscribeId(MaxSubscribeId maxSubscribeId) {
Expand Down Expand Up @@ -516,6 +533,15 @@ void MoQSession::onFetch(Fetch fetch) {
"End must be after start"});
return;
}
auto it = pubTracks_.find(fetch.subscribeID);
if (it != pubTracks_.end()) {
XLOG(ERR) << "Duplicate subscribe ID=" << fetch.subscribeID
<< " sess=" << this;
fetchError({fetch.subscribeID, 400, "dup sub ID"});
return;
}
pubTracks_[fetch.subscribeID].priority = fetch.priority;
pubTracks_[fetch.subscribeID].groupOrder = fetch.groupOrder;
controlMessages_.enqueue(std::move(fetch));
}

Expand Down Expand Up @@ -824,7 +850,12 @@ MoQSession::subscribe(SubscribeRequest sub) {

void MoQSession::subscribeOk(SubscribeOk subOk) {
XLOG(DBG1) << __func__ << " sess=" << this;
pubTracks_[subOk.subscribeID].groupOrder = subOk.groupOrder;
auto it = pubTracks_.find(subOk.subscribeID);
if (it == pubTracks_.end()) {
XLOG(ERR) << "Invalid Subscribe OK, id=" << subOk.subscribeID;
return;
}
it->second.groupOrder = subOk.groupOrder;
auto res = writeSubscribeOk(controlWriteBuf_, subOk);
if (!res) {
XLOG(ERR) << "writeSubscribeOk failed sess=" << this;
Expand All @@ -835,7 +866,12 @@ void MoQSession::subscribeOk(SubscribeOk subOk) {

void MoQSession::subscribeError(SubscribeError subErr) {
XLOG(DBG1) << __func__ << " sess=" << this;
pubTracks_.erase(subErr.subscribeID);
auto it = pubTracks_.find(subErr.subscribeID);
if (it == pubTracks_.end()) {
XLOG(ERR) << "Invalid Subscribe OK, id=" << subErr.subscribeID;
return;
}
pubTracks_.erase(it);
auto res = writeSubscribeError(controlWriteBuf_, std::move(subErr));
retireSubscribeId(/*signal=*/false);
if (!res) {
Expand All @@ -847,6 +883,25 @@ void MoQSession::subscribeError(SubscribeError subErr) {

void MoQSession::unsubscribe(Unsubscribe unsubscribe) {
XLOG(DBG1) << __func__ << " sess=" << this;
auto trackAliasIt = subIdToTrackAlias_.find(unsubscribe.subscribeID);
if (trackAliasIt == subIdToTrackAlias_.end()) {
// unknown
XLOG(ERR) << "No matching subscribe ID=" << unsubscribe.subscribeID
<< " sess=" << this;
return;
}
auto trackIt = subTracks_.find(trackAliasIt->second);
if (trackIt == subTracks_.end()) {
// unknown
XLOG(ERR) << "No matching subscribe ID=" << unsubscribe.subscribeID
<< " sess=" << this;
return;
}
// no more callbacks after unsubscribe
XLOG(DBG1) << "unsubscribing from ftn=" << trackIt->second->fullTrackName()
<< " sess=" << this;
// if there are open streams for this subscription, we should STOP_SENDING
// them?
auto res = writeUnsubscribe(controlWriteBuf_, std::move(unsubscribe));
if (!res) {
XLOG(ERR) << "writeUnsubscribe failed sess=" << this;
Expand All @@ -859,7 +914,13 @@ void MoQSession::unsubscribe(Unsubscribe unsubscribe) {

void MoQSession::subscribeDone(SubscribeDone subDone) {
XLOG(DBG1) << __func__ << " sess=" << this;
pubTracks_.erase(subDone.subscribeID);
auto it = pubTracks_.find(subDone.subscribeID);
if (it == pubTracks_.end()) {
XLOG(ERR) << "subscribeDone for invalid id=" << subDone.subscribeID
<< " sess=" << this;
return;
}
pubTracks_.erase(it);
auto res = writeSubscribeDone(controlWriteBuf_, std::move(subDone));
if (!res) {
XLOG(ERR) << "writeSubscribeDone failed sess=" << this;
Expand Down Expand Up @@ -897,6 +958,20 @@ void MoQSession::sendMaxSubscribeID(bool signal) {

void MoQSession::subscribeUpdate(SubscribeUpdate subUpdate) {
XLOG(DBG1) << __func__ << " sess=" << this;
auto trackAliasIt = subIdToTrackAlias_.find(subUpdate.subscribeID);
if (trackAliasIt == subIdToTrackAlias_.end()) {
// unknown
XLOG(ERR) << "No matching subscribe ID=" << subUpdate.subscribeID
<< " sess=" << this;
return;
}
auto trackIt = subTracks_.find(trackAliasIt->second);
if (trackIt == subTracks_.end()) {
// unknown
XLOG(ERR) << "No matching subscribe ID=" << subUpdate.subscribeID
<< " sess=" << this;
return;
}
auto res = writeSubscribeUpdate(controlWriteBuf_, std::move(subUpdate));
if (!res) {
XLOG(ERR) << "writeSubscribeUpdate failed sess=" << this;
Expand Down Expand Up @@ -941,6 +1016,11 @@ MoQSession::fetch(Fetch fetch) {

void MoQSession::fetchOk(FetchOk fetchOk) {
XLOG(DBG1) << __func__ << " sess=" << this;
auto it = pubTracks_.find(fetchOk.subscribeID);
if (it == pubTracks_.end()) {
XLOG(ERR) << "Invalid Fetch OK, id=" << fetchOk.subscribeID;
return;
}
auto res = writeFetchOk(controlWriteBuf_, fetchOk);
if (!res) {
XLOG(ERR) << "writeFetchOk failed sess=" << this;
Expand All @@ -951,6 +1031,12 @@ void MoQSession::fetchOk(FetchOk fetchOk) {

void MoQSession::fetchError(FetchError fetchErr) {
XLOG(DBG1) << __func__ << " sess=" << this;
if (pubTracks_.erase(fetchErr.subscribeID) == 0) {
// fetchError is called sometimes before adding publisher state, so this
// is not an error
XLOG(DBG1) << "fetchErr for invalid id=" << fetchErr.subscribeID
<< " sess=" << this;
}
auto res = writeFetchError(controlWriteBuf_, std::move(fetchErr));
if (!res) {
XLOG(ERR) << "writeFetchError failed sess=" << this;
Expand All @@ -961,6 +1047,12 @@ void MoQSession::fetchError(FetchError fetchErr) {

void MoQSession::fetchCancel(FetchCancel fetchCan) {
XLOG(DBG1) << __func__ << " sess=" << this;
auto trackIt = fetches_.find(fetchCan.subscribeID);
if (trackIt == fetches_.end()) {
XLOG(ERR) << "unknown subscribe ID=" << fetchCan.subscribeID
<< " sess=" << this;
return;
}
auto res = writeFetchCancel(controlWriteBuf_, std::move(fetchCan));
if (!res) {
XLOG(ERR) << "writeFetchCancel failed sess=" << this;
Expand Down

0 comments on commit 874e587

Please sign in to comment.