4 changes: 2 additions & 2 deletions cpp/tensorrt_llm/batch_manager/pauseRequests.cpp
@@ -50,8 +50,8 @@ void tensorrt_llm::batch_manager::PauseRequests::operator()(RequestVector& reque
     for (auto& llmReq : requestsToPause)
     {
         auto const reqId = llmReq->mRequestId;
-        inflightReqIds.erase(reqId);
-        TLLM_LOG_DEBUG("request with ID %lu removed from DECODER model inflight set", reqId);
+        auto const removed = inflightReqIds.erase(reqId);
+        TLLM_LOG_DEBUG("request with ID %lu removed from DECODER model inflight set: %d", reqId, removed);
 
         // If a request in this context had been flagged to be paused, pause it right away
         if (reqIdsToPause.find(reqId) != reqIdsToPause.end())
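Assuming inflightReqIds is a standard set-like container, erase(key) returns the number of elements removed, so the updated log line records whether the request was actually tracked as in-flight. A minimal Python sketch of the same bookkeeping, using a plain set as a stand-in for the C++ container (the helper and logger names are illustrative, not part of the codebase):

import logging

logger = logging.getLogger("pause_requests_sketch")

def pause_request(inflight_req_ids: set, req_id: int) -> None:
    # Mirror std::unordered_set::erase, which reports how many elements it removed (0 or 1).
    removed = 1 if req_id in inflight_req_ids else 0
    inflight_req_ids.discard(req_id)
    logger.debug(
        "request with ID %d removed from DECODER model inflight set: %d", req_id, removed)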
17 changes: 12 additions & 5 deletions cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
@@ -881,8 +881,6 @@ void TrtGptModelInflightBatching::forwardSync()
         }
     }
 
-    (*mPauseRequests)(currRequests.contextRequests, mInflightReqIds, mReqIdsToPause, true, *mSeqSlotManager,
-        mKvCacheManager, mCrossKvCacheManager, mPeftCacheManager);
     (*mPauseRequests)(currRequests.generationRequests, mInflightReqIds, mReqIdsToPause, true, *mSeqSlotManager,
         mKvCacheManager, mCrossKvCacheManager, mPeftCacheManager);
 
@@ -1051,14 +1049,23 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
    {
        NVTX3_SCOPED_RANGE(updateInflightReqIds);
        // Add requests to in-flight set, so they can be skipped in other micro batches
-       for (auto const& requests : {currRequests.contextRequests, currRequests.generationRequests})
+       for (auto const& llmReq : currRequests.contextRequests)
        {
-           for (auto const& llmReq : requests)
+           // Context requests that are chunking are not added to inflight set, so they are scheduled in the
+           // next micro batch.
+           if (llmReq->isLastContextChunk())
            {
-               TLLM_LOG_DEBUG("request with ID %lu added to DECODER model inflight set", llmReq->mRequestId);
+               TLLM_LOG_DEBUG(
+                   "Context request with ID %lu added to DECODER model inflight set", llmReq->mRequestId);
                mInflightReqIds.insert(llmReq->mRequestId);
            }
        }
+       for (auto const& llmReq : currRequests.generationRequests)
+       {
+           TLLM_LOG_DEBUG(
+               "Generation request with ID %lu added to DECODER model inflight set", llmReq->mRequestId);
+           mInflightReqIds.insert(llmReq->mRequestId);
+       }
    }
 
    (*mAssignReqSeqSlots)(*mSeqSlotManager, currRequests.contextRequests, currRequests.generationRequests);
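The forwardAsync change above gates in-flight tracking on chunked context: a context request only joins the in-flight set once its last context chunk is scheduled, so its remaining chunks stay visible to the next micro batch, while generation requests are always added. A rough Python sketch of that gating under the same assumptions (the request and container types here are simplified stand-ins, not the library's API):

from dataclasses import dataclass

@dataclass
class Request:
    request_id: int
    is_last_context_chunk: bool  # False while earlier context chunks remain to be processed

def add_inflight_ids(context_reqs, generation_reqs, inflight_ids: set) -> None:
    # Context requests that are still chunking stay out of the in-flight set,
    # so the next micro batch can schedule their remaining chunks.
    for req in context_reqs:
        if req.is_last_context_chunk:
            inflight_ids.add(req.request_id)
    # Generation requests are always marked in-flight.
    for req in generation_reqs:
        inflight_ids.add(req.request_id)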
6 changes: 3 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -674,8 +674,10 @@ def create_py_executor_instance(
 
     spec_config = model_engine.spec_config
 
+    max_num_sequences = max_batch_size * mapping.pp_size
+
     logger.info(
-        f"max_seq_len={max_seq_len}, max_num_requests={max_batch_size}, max_num_tokens={max_num_tokens}, max_batch_size={max_batch_size}"
+        f"max_seq_len={max_seq_len}, max_num_requests={max_num_sequences}, max_num_tokens={max_num_tokens}, max_batch_size={max_batch_size}"
     )
 
     for key, value in llm_args.extra_resource_managers.items():
@@ -760,8 +762,6 @@ def create_py_executor_instance(
             lora_config.trtllm_modules_to_hf_modules,
             lora_config.swap_gate_up_proj_lora_b_weight)
 
-    max_num_sequences = max_batch_size * mapping.pp_size
-
     resources[ResourceManagerType.SEQ_SLOT_MANAGER] = SeqSlotManager(
         max_num_sequences)
 
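Hoisting the max_num_sequences computation above the log line makes the reported max_num_requests reflect sequences across all pipeline-parallel micro batches rather than the per-micro-batch batch size. A small worked example of the relationship, with illustrative numbers:

# Assumption: each of the pp_size pipeline stages can hold one micro batch of up to max_batch_size sequences.
max_batch_size = 8   # sequences per micro batch
pp_size = 4          # pipeline-parallel stages

max_num_sequences = max_batch_size * pp_size
assert max_num_sequences == 32  # sequence slots the SeqSlotManager must provide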
39 changes: 34 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -106,6 +106,7 @@ class BatchState:
 class BatchStatePP(BatchState):
     microbatch_id: int = -1
     scheduled_ctx_reqs: list[LlmRequest] = None
+    finished_ctx_reqs: list[LlmRequest] = None
 
 
 class PyExecutor:
@@ -836,9 +837,13 @@ def _executor_loop_pp(self):
                 can_queue = self._can_queue(scheduled_batch)
 
                 if not can_queue:
+                    logger.debug(
+                        f"microbatch {microbatch_id} cannot be queued, skipping"
+                    )
                     self.micro_batches[microbatch_id] = None
                 else:
-                    self._add_inflight_ids(scheduled_batch)
+                    logger.debug(f"microbatch {microbatch_id} can be queued")
+                    finished_ctx_reqs = self._add_inflight_ids(scheduled_batch)
 
                 if self.kv_cache_transceiver:
                     # For generation requests which have completed KV cache transfer
@@ -898,6 +903,7 @@ def _executor_loop_pp(self):
                     iter_stats=iter_stats,
                     microbatch_id=microbatch_id,
                     scheduled_ctx_reqs=scheduled_batch.context_requests,
+                    finished_ctx_reqs=finished_ctx_reqs,
                 )
 
                 self.micro_batches[microbatch_id] = batch_state
@@ -952,6 +958,8 @@ def _executor_loop_pp(self):
                 finished_requests = []
                 if previous_batch is not None:
                     with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
+                        sample_state = previous_batch.sample_state
+                        sample_state.scheduled_requests.context_requests = previous_batch.finished_ctx_reqs
                         self._update_requests(previous_batch.sample_state)
 
                         if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
@@ -983,7 +991,8 @@ def _executor_loop_pp(self):
                         self.resource_manager.update_resources(
                             previous_scheduled_batch, attn_metadata,
                             kv_cache_dtype_byte_size)
-                        self._remove_inflight_ids(previous_scheduled_batch)
+
+                        self._remove_inflight_ids(previous_batch)
 
                     self.wait_on_pp_send_handles(prev_microbatch_id)
                     self.micro_batches[prev_microbatch_id] = None
@@ -2469,12 +2478,32 @@ def _pause_requests(self, requests_to_pause):
 
     def _add_inflight_ids(self, scheduled_requests):
         """Add reqids of current requests to self.inflight_req_ids."""
-        for req in scheduled_requests.all_requests():
+        finished_ctx_reqs = []
+        for req in scheduled_requests.context_requests:
+            if req.is_last_context_chunk:
+                logger.debug(
+                    f"Context request with ID {req.request_id} added to DECODER model inflight set"
+                )
+                self.inflight_req_ids.insert(req.request_id)
+                finished_ctx_reqs.append(req)
+        for req in scheduled_requests.generation_requests:
+            logger.debug(
+                f"Generation request with ID {req.request_id} added to DECODER model inflight set"
+            )
             self.inflight_req_ids.insert(req.request_id)
+        return finished_ctx_reqs
 
-    def _remove_inflight_ids(self, scheduled_requests):
+    def _remove_inflight_ids(self, batch_state: BatchStatePP):
         """Remove reqids of current requests from self.inflight_req_ids."""
-        for req in scheduled_requests.all_requests():
+        for req in batch_state.finished_ctx_reqs:
+            logger.debug(
+                f"Context request with ID {req.request_id} removed from DECODER model inflight set"
+            )
+            self.inflight_req_ids.erase(req.request_id)
+        for req in batch_state.sample_state.scheduled_requests.generation_requests:
+            logger.debug(
+                f"Generation request with ID {req.request_id} removed from DECODER model inflight set"
+            )
             self.inflight_req_ids.erase(req.request_id)
 
     def _handle_speculative_decoding(scheduled_batch, previous_tensors,
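Taken together, the py_executor.py changes make the in-flight bookkeeping symmetric around chunked context: _add_inflight_ids returns the context requests it actually inserted (those on their last chunk), BatchStatePP carries them as finished_ctx_reqs, and _remove_inflight_ids later erases exactly those plus the batch's generation requests. A condensed, self-contained sketch of that lifecycle (PyExecutor, its C++-backed id set, and the real request/batch types are replaced by plain Python stand-ins):

from dataclasses import dataclass, field

@dataclass
class Req:
    request_id: int
    is_last_context_chunk: bool = True

@dataclass
class MicroBatchState:
    finished_ctx_reqs: list = field(default_factory=list)
    generation_reqs: list = field(default_factory=list)

inflight_ids = set()

def add_inflight_ids(ctx_reqs, gen_reqs) -> MicroBatchState:
    # Only context requests on their last chunk become in-flight; remember them so
    # exactly the same requests are erased when this micro batch completes.
    finished = [r for r in ctx_reqs if r.is_last_context_chunk]
    for r in finished:
        inflight_ids.add(r.request_id)
    for r in gen_reqs:
        inflight_ids.add(r.request_id)
    return MicroBatchState(finished_ctx_reqs=finished, generation_reqs=list(gen_reqs))

def remove_inflight_ids(state: MicroBatchState) -> None:
    # Remove exactly what add_inflight_ids inserted for this micro batch.
    for r in state.finished_ctx_reqs + state.generation_reqs:
        inflight_ids.discard(r.request_id)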