diff --git a/cpp/tensorrt_llm/batch_manager/pauseRequests.cpp b/cpp/tensorrt_llm/batch_manager/pauseRequests.cpp index fc7e8b51fd2..1ce26ef497c 100644 --- a/cpp/tensorrt_llm/batch_manager/pauseRequests.cpp +++ b/cpp/tensorrt_llm/batch_manager/pauseRequests.cpp @@ -50,8 +50,8 @@ void tensorrt_llm::batch_manager::PauseRequests::operator()(RequestVector& reque for (auto& llmReq : requestsToPause) { auto const reqId = llmReq->mRequestId; - inflightReqIds.erase(reqId); - TLLM_LOG_DEBUG("request with ID %lu removed from DECODER model inflight set", reqId); + auto const removed = inflightReqIds.erase(reqId); + TLLM_LOG_DEBUG("request with ID %lu removed from DECODER model inflight set: %d", reqId, removed); // If a request in this context had been flagged to be paused, pause it right away if (reqIdsToPause.find(reqId) != reqIdsToPause.end()) diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index f128dbb2ee4..9210fe95874 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -881,8 +881,6 @@ void TrtGptModelInflightBatching::forwardSync() } } - (*mPauseRequests)(currRequests.contextRequests, mInflightReqIds, mReqIdsToPause, true, *mSeqSlotManager, - mKvCacheManager, mCrossKvCacheManager, mPeftCacheManager); (*mPauseRequests)(currRequests.generationRequests, mInflightReqIds, mReqIdsToPause, true, *mSeqSlotManager, mKvCacheManager, mCrossKvCacheManager, mPeftCacheManager); @@ -1051,14 +1049,23 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests { NVTX3_SCOPED_RANGE(updateInflightReqIds); // Add requests to in-flight set, so they can be skipped in other micro batches - for (auto const& requests : {currRequests.contextRequests, currRequests.generationRequests}) + for (auto const& llmReq : currRequests.contextRequests) { - for (auto const& llmReq : requests) + // Context requests that are chunking are not added to inflight set, so they are scheduled in the + // next micro batch. + if (llmReq->isLastContextChunk()) { - TLLM_LOG_DEBUG("request with ID %lu added to DECODER model inflight set", llmReq->mRequestId); + TLLM_LOG_DEBUG( + "Context request with ID %lu added to DECODER model inflight set", llmReq->mRequestId); mInflightReqIds.insert(llmReq->mRequestId); } } + for (auto const& llmReq : currRequests.generationRequests) + { + TLLM_LOG_DEBUG( + "Generation request with ID %lu added to DECODER model inflight set", llmReq->mRequestId); + mInflightReqIds.insert(llmReq->mRequestId); + } } (*mAssignReqSeqSlots)(*mSeqSlotManager, currRequests.contextRequests, currRequests.generationRequests); diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 8da982aba2b..d691093874b 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -674,8 +674,10 @@ def create_py_executor_instance( spec_config = model_engine.spec_config + max_num_sequences = max_batch_size * mapping.pp_size + logger.info( - f"max_seq_len={max_seq_len}, max_num_requests={max_batch_size}, max_num_tokens={max_num_tokens}, max_batch_size={max_batch_size}" + f"max_seq_len={max_seq_len}, max_num_requests={max_num_sequences}, max_num_tokens={max_num_tokens}, max_batch_size={max_batch_size}" ) for key, value in llm_args.extra_resource_managers.items(): @@ -760,8 +762,6 @@ def create_py_executor_instance( lora_config.trtllm_modules_to_hf_modules, lora_config.swap_gate_up_proj_lora_b_weight) - max_num_sequences = max_batch_size * mapping.pp_size - resources[ResourceManagerType.SEQ_SLOT_MANAGER] = SeqSlotManager( max_num_sequences) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 00df964ca63..5d5a2cefeed 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -106,6 +106,7 @@ class BatchState: class BatchStatePP(BatchState): microbatch_id: int = -1 scheduled_ctx_reqs: list[LlmRequest] = None + finished_ctx_reqs: list[LlmRequest] = None class PyExecutor: @@ -232,6 +233,8 @@ def __init__(self, | None] = [None] * self.num_micro_batches self.send_handles = [None] * self.num_micro_batches + # Set of request IDs that are currently in flight across all micro batches. + # The scheduler will avoid scheduling requests that are already in flight. self.inflight_req_ids = ReqIdsSet() # During warmup, we don't enable the profiler @@ -694,7 +697,7 @@ def get_queued_req_stats(request_id: int) -> RequestStats: return req_stats def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests, - scheduled_batch) -> IterationStats: + scheduled_batch, micro_batch_id) -> IterationStats: stats.iter_latency_ms = iter_latency_ms stats.num_queued_requests = self.executor_request_queue.get_request_queue_size( @@ -735,7 +738,7 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests, stats.inflight_batching_stats.num_paused_requests = len( scheduled_batch.paused_requests) stats.inflight_batching_stats.avg_num_decoded_tokens_per_iter = 0 - stats.inflight_batching_stats.micro_batch_id = 0 + stats.inflight_batching_stats.micro_batch_id = micro_batch_id if stats.specdec_stats is not None: stats.specdec_stats.draft_overhead = 0.0 if iter_latency_ms <= 0.0 else float( stats.specdec_stats.iter_latency_ms) / float(iter_latency_ms) @@ -748,9 +751,13 @@ def _append_iter_stats(self, with self.stats_lock: self.stats.append((stats, req_stats)) - def _process_iter_stats(self, finished_requests: list[LlmRequest], - active_requests: List[LlmRequest], - batch_state: BatchState): + def _process_iter_stats( + self, + finished_requests: list[LlmRequest], + active_requests: List[LlmRequest], + batch_state: BatchState, + micro_batch_id: int = 0, + ): iter_end_time = time.time() iter_latency_ms = (iter_end_time - batch_state.iter_start_time) * 1e3 if batch_state.iter_stats is None: @@ -763,9 +770,10 @@ def _process_iter_stats(self, finished_requests: list[LlmRequest], and self.enable_iter_perf_stats) else None self._append_iter_stats( - self._update_iter_stats( - batch_state.iter_stats, iter_latency_ms, len(finished_requests), - batch_state.sample_state.scheduled_requests), req_stats) + self._update_iter_stats(batch_state.iter_stats, iter_latency_ms, + len(finished_requests), + batch_state.sample_state.scheduled_requests, + micro_batch_id), req_stats) def _executor_loop_cleanup(self): @@ -825,6 +833,7 @@ def _executor_loop_pp(self): self.num_scheduled_requests = scheduled_batch.batch_size logger.debug( + f'iteration {self.iter_counter}, microbatch {microbatch_id}, ' f'has {len(self.active_requests)} active_requests, ' f'scheduled {len(scheduled_batch.context_requests)} context requests and ' f'{len(scheduled_batch.generation_requests)} generation requests' @@ -833,9 +842,13 @@ def _executor_loop_pp(self): can_queue = self._can_queue(scheduled_batch) if not can_queue: + logger.debug( + f"microbatch {microbatch_id} cannot be queued, skipping" + ) self.micro_batches[microbatch_id] = None else: - self._add_inflight_ids(scheduled_batch) + logger.debug(f"microbatch {microbatch_id} can be queued") + finished_ctx_reqs = self._add_inflight_ids(scheduled_batch) if self.kv_cache_transceiver: # For generation requests which have completed KV cache transfer @@ -895,6 +908,7 @@ def _executor_loop_pp(self): iter_stats=iter_stats, microbatch_id=microbatch_id, scheduled_ctx_reqs=scheduled_batch.context_requests, + finished_ctx_reqs=finished_ctx_reqs, ) self.micro_batches[microbatch_id] = batch_state @@ -949,6 +963,8 @@ def _executor_loop_pp(self): finished_requests = [] if previous_batch is not None: with torch.cuda.nvtx.range("_handle_previous_batch_pp"): + sample_state = previous_batch.sample_state + sample_state.scheduled_requests.context_requests = previous_batch.finished_ctx_reqs self._update_requests(previous_batch.sample_state) if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: @@ -980,7 +996,8 @@ def _executor_loop_pp(self): self.resource_manager.update_resources( previous_scheduled_batch, attn_metadata, kv_cache_dtype_byte_size) - self._remove_inflight_ids(previous_scheduled_batch) + + self._remove_inflight_ids(previous_batch) self.wait_on_pp_send_handles(prev_microbatch_id) self.micro_batches[prev_microbatch_id] = None @@ -997,9 +1014,11 @@ def _executor_loop_pp(self): microbatch_id = (microbatch_id + 1) % self.num_micro_batches if self.enable_iter_perf_stats and previous_batch is not None: + sample_state = previous_batch.sample_state + sample_state.scheduled_requests.context_requests = previous_batch.scheduled_ctx_reqs self._process_iter_stats(finished_requests, self.active_requests, - previous_batch) + previous_batch, microbatch_id) self.iter_counter += 1 @@ -2485,13 +2504,43 @@ def _pause_requests(self, requests_to_pause): self._terminate_request(req) def _add_inflight_ids(self, scheduled_requests): - """Add reqids of current requests to self.inflight_req_ids.""" - for req in scheduled_requests.all_requests(): + """Add request IDs of current requests to self.inflight_req_ids. + + Non‑final context chunks are not added to the inflight set, so the scheduler can keep scheduling further + context chunks while earlier ones are in the PP pipeline. Only context requests that finish context phase + are inserted into the inflight set and collected into finished_ctx_reqs. + All generation requests are still inserted into the inflight set. + """ + finished_ctx_reqs = [] + for req in scheduled_requests.context_requests: + if req.is_last_context_chunk: + logger.debug( + f"Context request with ID {req.request_id} added to DECODER model inflight set" + ) + self.inflight_req_ids.insert(req.request_id) + finished_ctx_reqs.append(req) + for req in scheduled_requests.generation_requests: + logger.debug( + f"Generation request with ID {req.request_id} added to DECODER model inflight set" + ) self.inflight_req_ids.insert(req.request_id) + return finished_ctx_reqs + + def _remove_inflight_ids(self, batch_state: BatchStatePP): + """Remove request IDs of current requests from self.inflight_req_ids. - def _remove_inflight_ids(self, scheduled_requests): - """Remove reqids of current requests from self.inflight_req_ids.""" - for req in scheduled_requests.all_requests(): + Context IDs are erased from the inflight set using batch_state.finished_ctx_reqs. + Generation IDs are erased using batch_state.sample_state.scheduled_requests.generation_requests. + """ + for req in batch_state.finished_ctx_reqs: + logger.debug( + f"Context request with ID {req.request_id} removed from DECODER model inflight set" + ) + self.inflight_req_ids.erase(req.request_id) + for req in batch_state.sample_state.scheduled_requests.generation_requests: + logger.debug( + f"Generation request with ID {req.request_id} removed from DECODER model inflight set" + ) self.inflight_req_ids.erase(req.request_id) def _handle_speculative_decoding(self, scheduled_batch, previous_tensors, diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py index f576cba4a1c..fb6e24b81a0 100644 --- a/tests/unittest/llmapi/test_llm.py +++ b/tests/unittest/llmapi/test_llm.py @@ -2046,39 +2046,81 @@ def batch_task(): batch_task() -def validate_stats(results, - pytorch_backend, - max_tokens, - enable_iter_req_stats=False): +def validate_stats( + *, + results, + pytorch_backend, + max_tokens, + pp_size=1, + use_overlap=False, + enable_chunked_prefill=False, + enable_iter_req_stats=False, +): assert results - assert len(results) == max_tokens if pytorch_backend else max_tokens + 1 for iter, result in enumerate(results): ifbStats = result["inflightBatchingStats"] - expected_num_scheduled = 1 if (iter < max_tokens) else 0 - assert ifbStats["numScheduledRequests"] == expected_num_scheduled - if iter == 0: - assert ifbStats["numContextRequests"] == 1 - assert ifbStats["numGenRequests"] == 0 - assert result["numActiveRequests"] == 1 - elif iter == max_tokens: - assert ifbStats["numContextRequests"] == 0 - assert ifbStats["numGenRequests"] == 0 - assert result["numActiveRequests"] == 0 + print(f"iter: {iter}, ifbStats: {ifbStats}") + + expected_num_results = max_tokens if pytorch_backend else max_tokens + 1 + if enable_chunked_prefill: + expected_num_results += 1 + assert len(results) == expected_num_results + + context_iterations = 2 if enable_chunked_prefill else 1 + generation_iterations = max_tokens - 1 + microbatch_id = 0 + for iter, result in enumerate(results): + ifbStats = result["inflightBatchingStats"] + + if iter < context_iterations: + assert ifbStats["numScheduledRequests"] == 1, f"iter: {iter}" + assert ifbStats["numContextRequests"] == 1, f"iter: {iter}" + assert ifbStats["numGenRequests"] == 0, f"iter: {iter}" + assert result["numActiveRequests"] == 1, f"iter: {iter}" + assert ifbStats["microBatchId"] == microbatch_id, f"iter: {iter}" + elif iter < (context_iterations + generation_iterations): + assert ifbStats["numScheduledRequests"] == 1, f"iter: {iter}" + assert ifbStats["numContextRequests"] == 0, f"iter: {iter}" + assert ifbStats["numGenRequests"] == 1, f"iter: {iter}" + assert result["numActiveRequests"] == 1, f"iter: {iter}" + assert ifbStats["microBatchId"] == microbatch_id, f"iter: {iter}" else: - assert ifbStats["numContextRequests"] == 0 - assert ifbStats["numGenRequests"] == 1 - assert result["numActiveRequests"] == 1 + assert ifbStats["numScheduledRequests"] == 0, f"iter: {iter}" + assert ifbStats["numContextRequests"] == 0, f"iter: {iter}" + assert ifbStats["numGenRequests"] == 0, f"iter: {iter}" + assert result["numActiveRequests"] == 0, f"iter: {iter}" + assert ifbStats["microBatchId"] == microbatch_id, f"iter: {iter}" + + # In pipeline parallel mode, increment microbatch_id for each context iteration except the last one, + # since the context chunks can be scheduled in each iteration. + if pp_size > 1 and iter < context_iterations - 1: + microbatch_id += 1 if enable_iter_req_stats: - assert "requestStats" in result + assert "requestStats" in result, f"iter: {iter}" req_stats = result["requestStats"] - assert len(req_stats) == 1 + assert len(req_stats) == 1, f"iter: {iter}" req_stat = req_stats[0] - assert req_stat["numGeneratedTokens"] == iter + 1 - assert req_stat["scheduled"] == True - assert req_stat[ - "stage"] == "GENERATION_IN_PROGRESS" if iter + 1 < max_tokens else "GENERATION_COMPLETE" - assert req_stat["contextPrefillPosition"] == 4 + if iter < (context_iterations - 1): + # If use_overlap, the stats are one iteration ahead + assert req_stat[ + "stage"] == "GENERATION_IN_PROGRESS" if use_overlap else "CONTEXT_IN_PROGRESS", f"iter: {iter}" + assert req_stat[ + "contextPrefillPosition"] == 54 if use_overlap else 32, f"iter: {iter}" + assert req_stat["numGeneratedTokens"] == 0, f"iter: {iter}" + elif iter < (context_iterations - 1 + generation_iterations): + assert req_stat[ + "stage"] == "GENERATION_IN_PROGRESS", f"iter: {iter}" + assert req_stat["contextPrefillPosition"] == 54, f"iter: {iter}" + assert req_stat["numGeneratedTokens"] == iter - ( + context_iterations - 1) + 1, f"iter: {iter}" + else: + assert req_stat[ + "stage"] == "GENERATION_COMPLETE", f"iter: {iter}" + assert req_stat["contextPrefillPosition"] == 54, f"iter: {iter}" + assert req_stat[ + "numGeneratedTokens"] == max_tokens, f"iter: {iter}" + assert req_stat["scheduled"] == True, f"iter: {iter}" expected_num_completed = 1 if iter == len(results) - 1 else 0 @@ -2088,9 +2130,11 @@ def validate_stats(results, def llm_get_stats_test_harness(tp_size: int = 1, + pp_size: int = 1, return_context_logits: bool = False, pytorch_backend: bool = False, use_overlap: bool = False, + enable_chunked_prefill: bool = False, enable_iter_req_stats: bool = False): if return_context_logits and pytorch_backend: @@ -2104,6 +2148,7 @@ def llm_get_stats_test_harness(tp_size: int = 1, print("return_context_logits: ", return_context_logits) print("pytorch_backend: ", pytorch_backend) print("use_overlap: ", use_overlap) + print("enable_chunked_prefill: ", enable_chunked_prefill) print("enable_iter_req_stats: ", enable_iter_req_stats) print("-------------") @@ -2114,6 +2159,10 @@ def llm_get_stats_test_harness(tp_size: int = 1, llm_args_extra["gather_generation_logits"] = True sampling_args_extra["return_context_logits"] = True + if enable_chunked_prefill: + llm_args_extra["enable_chunked_prefill"] = True + llm_args_extra["max_num_tokens"] = 32 + if pytorch_backend: llm_args_extra.update( dict(enable_iter_perf_stats=True, @@ -2126,27 +2175,39 @@ def llm_get_stats_test_harness(tp_size: int = 1, if not pytorch_backend: llm_args_extra["fast_build"] = True - llm = LLM_CLASS(model=llama_model_path, - kv_cache_config=global_kvcache_config, - tensor_parallel_size=tp_size, - **llm_args_extra) + with LLM_CLASS(model=llama_model_path, + kv_cache_config=global_kvcache_config, + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + **llm_args_extra) as llm: - max_tokens = 5 - sampling_params = SamplingParams(max_tokens=max_tokens, - **sampling_args_extra) + max_tokens = 5 + sampling_params = SamplingParams(max_tokens=max_tokens, + **sampling_args_extra) - for output in llm.generate(prompts, sampling_params=sampling_params): - print(output) + long_prompts = [ + "A B C D E F G H I J K L M N O P Q R S T U V W X Y Z " * 2 + ] - results = llm.get_stats(2) + for output in llm.generate(long_prompts, + sampling_params=sampling_params): + print(output) - validate_stats(results, pytorch_backend, max_tokens, enable_iter_req_stats) + results = llm.get_stats(2) + + validate_stats(results=results, + pp_size=pp_size, + pytorch_backend=pytorch_backend, + max_tokens=max_tokens, + use_overlap=use_overlap, + enable_chunked_prefill=enable_chunked_prefill, + enable_iter_req_stats=enable_iter_req_stats) - assert not llm.get_stats(2) + assert not llm.get_stats(2) - # test that IterationResult()._done is properly set - _ = llm.generate(prompts, sampling_params=sampling_params) - assert llm.get_stats(2) + # test that IterationResult()._done is properly set + _ = llm.generate(prompts, sampling_params=sampling_params) + assert llm.get_stats(2) @pytest.mark.parametrize("return_context_logits", [True, False]) @@ -2214,9 +2275,11 @@ def test_llm_get_queued_stats(): def llm_get_stats_async_test_harness(tp_size: int = 1, + pp_size: int = 1, return_context_logits: bool = False, pytorch_backend: bool = False, use_overlap: bool = False, + enable_chunked_prefill: bool = False, enable_iter_req_stats: bool = False): if return_context_logits and pytorch_backend: @@ -2230,6 +2293,7 @@ def llm_get_stats_async_test_harness(tp_size: int = 1, print("return_context_logits: ", return_context_logits) print("pytorch_backend: ", pytorch_backend) print("use_overlap: ", use_overlap) + print("enable_chunked_prefill: ", enable_chunked_prefill) print("enable_iter_req_stats: ", enable_iter_req_stats) print("-------------") @@ -2239,6 +2303,10 @@ def llm_get_stats_async_test_harness(tp_size: int = 1, llm_args_extra["build_config"] = BuildConfig(gather_context_logits=True) sampling_args_extra["return_context_logits"] = True + if enable_chunked_prefill: + llm_args_extra["enable_chunked_prefill"] = True + llm_args_extra["max_num_tokens"] = 32 + if pytorch_backend: llm_args_extra.update( dict(enable_iter_perf_stats=True, @@ -2249,38 +2317,51 @@ def llm_get_stats_async_test_harness(tp_size: int = 1, LLM_CLASS = LLM llm_args_extra["fast_build"] = True - llm = LLM_CLASS(model=llama_model_path, - kv_cache_config=global_kvcache_config, - tensor_parallel_size=tp_size, - **llm_args_extra) + with LLM_CLASS(model=llama_model_path, + kv_cache_config=global_kvcache_config, + tensor_parallel_size=tp_size, + **llm_args_extra) as llm: - max_tokens = 6 - sampling_params = SamplingParams(max_tokens=max_tokens, - **sampling_args_extra) + max_tokens = 6 + sampling_params = SamplingParams(max_tokens=max_tokens, + **sampling_args_extra) - async def task0(): - async for output in llm.generate_async(prompts[0], - streaming=True, - sampling_params=sampling_params): - print(output) + long_prompts = [ + "A B C D E F G H I J K L M N O P Q R S T U V W X Y Z " * 2 + ] - async def task1(): - results = [] - await asyncio.sleep( - 3) # ensure there's stats to collect for the assertion - async for stats in llm.get_stats_async(timeout=2): - results.append(stats) + async def task0(): + async for output in llm.generate_async( + long_prompts[0], + streaming=True, + sampling_params=sampling_params): + print(output) - assert results - if not use_overlap: - validate_stats(results, pytorch_backend, max_tokens, - enable_iter_req_stats) + async def task1(repetition_index: int): + results = [] + await asyncio.sleep( + 3) # ensure there's stats to collect for the assertion + async for stats in llm.get_stats_async(timeout=2): + results.append(stats) + + assert results + if not use_overlap: + validate_stats( + results=results, + pp_size=pp_size, + pytorch_backend=pytorch_backend, + max_tokens=max_tokens, + use_overlap=use_overlap, + # After the first repetition, context will be reused and there will be no chunking. + enable_chunked_prefill=enable_chunked_prefill + if repetition_index == 0 else False, + enable_iter_req_stats=enable_iter_req_stats) - async def main(): - for i in range(2): # test recurrent usage - await asyncio.gather(task0(), task1()) + async def main(): + for repetition_index in range(2): # test recurrent usage + await asyncio.gather(task0(), task1(repetition_index)) - asyncio.run(main()) + asyncio.run(main()) @pytest.mark.parametrize("return_context_logits", [True, False]) diff --git a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py index a43d92fa5a4..30815bc5744 100644 --- a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py +++ b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py @@ -11,6 +11,7 @@ check_llama_7b_multi_lora_from_request_test_harness, check_phi3_lora_fused_modules_output_tp2_identical_to_tp1) from .test_llm import (_test_llm_capture_request_error, llama_model_path, + llm_get_stats_test_harness, llm_return_logprobs_test_harness, tinyllama_logits_processor_test_harness) from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness @@ -125,3 +126,45 @@ def test_llm_return_logprobs_streaming_tp2(prompt_logprobs, logprobs, streaming=True, backend="pytorch", tp_size=2) + + +@skip_ray +@pytest.mark.gpu2 +@pytest.mark.parametrize( + "return_context_logits, enable_chunked_prefill, enable_iter_req_stats", + [ + (False, False, True), + (False, True, True), + ], +) +def test_llm_get_stats_pp2(return_context_logits, enable_chunked_prefill, + enable_iter_req_stats): + llm_get_stats_test_harness( + tp_size=1, + pp_size=2, + return_context_logits=return_context_logits, + pytorch_backend=True, + enable_chunked_prefill=enable_chunked_prefill, + enable_iter_req_stats=enable_iter_req_stats, + ) + + +@skip_ray +@pytest.mark.gpu4 +@pytest.mark.parametrize( + "return_context_logits, enable_chunked_prefill, enable_iter_req_stats", + [ + (False, False, True), + (False, True, True), + ], +) +def test_llm_get_stats_pp4(return_context_logits, enable_chunked_prefill, + enable_iter_req_stats): + llm_get_stats_test_harness( + tp_size=1, + pp_size=4, + return_context_logits=return_context_logits, + pytorch_backend=True, + enable_chunked_prefill=enable_chunked_prefill, + enable_iter_req_stats=enable_iter_req_stats, + ) diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 1bdd2dfbeb5..4883a1a1ed4 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -54,36 +54,42 @@ def test_tinyllama_logits_processor(enable_chunked_prefill): @skip_ray @pytest.mark.parametrize( - "return_context_logits, use_overlap, enable_iter_req_stats", [ - (False, False, False), - (False, False, True), - (False, True, False), - (False, True, True), + "return_context_logits, use_overlap, enable_chunked_prefill, enable_iter_req_stats", + [ + (False, False, False, True), + (False, False, True, True), + (False, True, False, True), + (False, True, True, True), ]) def test_llm_get_stats(return_context_logits, use_overlap, - enable_iter_req_stats): + enable_chunked_prefill, enable_iter_req_stats): llm_get_stats_test_harness(tp_size=1, + pp_size=1, return_context_logits=return_context_logits, pytorch_backend=True, use_overlap=use_overlap, + enable_chunked_prefill=enable_chunked_prefill, enable_iter_req_stats=enable_iter_req_stats) @skip_ray @pytest.mark.parametrize( - "return_context_logits, use_overlap, enable_iter_req_stats", [ - (False, False, False), - (False, False, True), - (False, True, False), - (False, True, True), + "return_context_logits, use_overlap, enable_chunked_prefill, enable_iter_req_stats", + [ + (False, False, False, True), + (False, False, True, True), + (False, True, False, True), + (False, True, True, True), ]) def test_llm_get_stats_async(return_context_logits, use_overlap, - enable_iter_req_stats): + enable_chunked_prefill, enable_iter_req_stats): llm_get_stats_async_test_harness( tp_size=1, + pp_size=1, return_context_logits=return_context_logits, pytorch_backend=True, use_overlap=use_overlap, + enable_chunked_prefill=enable_chunked_prefill, enable_iter_req_stats=enable_iter_req_stats)