Commit b4e6a16

[https://nvbugs/5451280][fix] Reduce memory fraction problem by warmu… (#7999)
Signed-off-by: Jin Li <[email protected]>
Parent: ef8e217

2 files changed: 82 additions, 23 deletions

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 13 additions & 0 deletions
@@ -207,6 +207,19 @@ def _get_token_num_for_estimation(self) -> int:
             # Requests cannot share KV cache blocks. Round up to nearest integer multiple of block size.
             num_cache_blocks += (num_req_tokens + self._tokens_per_block -
                                  1) // self._tokens_per_block
+
+        # Max cuda graph warmup required tokens
+        max_cuda_graph_bs = min(self._model_engine.batch_size,
+                                self._model_engine._max_cuda_graph_batch_size)
+        cuda_graph_warmup_block = (
+            self._model_engine.max_seq_len +
+            1) // self._tokens_per_block + max_cuda_graph_bs - 1
+        num_cache_blocks = max(cuda_graph_warmup_block, num_cache_blocks)
+
+        # This is the minimal blocks required to run with max bs
+        # If not able to allocate self._model_engine.batch_size blocks, the max batch size should be adjusted.
+        num_cache_blocks = max(num_cache_blocks, self._model_engine.batch_size)
+
         # Multiply by beam width, to prevent rescaling of the max_seq_len caused by the influence of beam width during the preparation for kv_cache_estimation
         return num_cache_blocks * self._tokens_per_block * self._dummy_reqs[
             0].sampling_config.beam_width
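
For a concrete sense of the new estimate, here is a minimal standalone sketch of the same arithmetic with made-up numbers; the plain variables stand in for the self._model_engine and self._tokens_per_block attributes used above, and none of the values come from the commit:

# Assumed example values, for illustration only.
max_seq_len = 4096               # stands in for self._model_engine.max_seq_len
tokens_per_block = 32            # stands in for self._tokens_per_block
batch_size = 256                 # stands in for self._model_engine.batch_size
max_cuda_graph_batch_size = 128  # stands in for self._model_engine._max_cuda_graph_batch_size
num_cache_blocks = 100           # assumed estimate accumulated from the dummy requests

# CUDA graph warmup requirement, mirroring the expression added in the diff.
max_cuda_graph_bs = min(batch_size, max_cuda_graph_batch_size)
cuda_graph_warmup_block = (max_seq_len + 1) // tokens_per_block + max_cuda_graph_bs - 1
num_cache_blocks = max(cuda_graph_warmup_block, num_cache_blocks)

# At least one block per request at the maximum batch size, so the max batch
# size does not have to be adjusted downward.
num_cache_blocks = max(num_cache_blocks, batch_size)

print(num_cache_blocks)  # 4097 // 32 + 128 - 1 = 255, then max(255, 100) = 255, then max(255, 256) = 256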

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 69 additions & 23 deletions
@@ -527,12 +527,16 @@ def get_cuda_graph_warmup_request(batch_size, draft_len):
                 result = None
             return result
 
-        def get_warmup_request(num_tokens: int, num_gen_tokens: int):
+        def get_warmup_request(num_tokens: int,
+                               num_gen_tokens: int,
+                               least_requests: bool = True):
             available_tokens = kv_cache_manager.get_num_available_tokens(
                 self.runtime_draft_len)
             available_blocks = kv_cache_manager.get_num_free_blocks()
             if num_tokens > self.max_num_tokens or num_tokens > available_tokens:
                 return None
+            if num_gen_tokens > self.batch_size:
+                return None
 
             num_extra_decoding_steps = get_num_extra_decoding_steps()
             if num_extra_decoding_steps > 0:
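
Because least_requests defaults to True, existing callers of get_warmup_request keep the old behavior; only call sites that pass least_requests=False (the piecewise warmup added further below) opt into the new packing. A hypothetical stand-in just to show the two call shapes:

# Hypothetical stand-in for the nested helper; it only echoes its arguments.
def get_warmup_request(num_tokens: int,
                       num_gen_tokens: int,
                       least_requests: bool = True):
    return num_tokens, num_gen_tokens, least_requests

print(get_warmup_request(64, 1))                        # existing call sites: least_requests stays True
print(get_warmup_request(64, 0, least_requests=False))  # opt-in used by the piecewise warmup below
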
@@ -550,14 +554,28 @@ def get_warmup_request(num_tokens: int, num_gen_tokens: int):
             num_full_seqs = 0
             num_left_over_tokens = 0
 
+            max_context_requests = self.batch_size - num_gen_tokens
+            if max_context_requests * max_seq_len < num_ctx_tokens:
+                return None
+
             if num_ctx_tokens > 0:
-                # We will try to assign as less context requests as possible to
-                # fill the num_ctx_tokens.
+                if least_requests:
+                    # We will try to assign as few context requests as possible to
+                    # fill the num_ctx_tokens.
 
-                # Num full sequences:
-                num_full_seqs = num_ctx_tokens // max_seq_len
-                num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len
+                    # Num full sequences:
+                    num_full_seqs = num_ctx_tokens // max_seq_len
+                    num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len
 
+                else:
+                    max_bs = min(num_ctx_tokens,
+                                 self.batch_size - num_gen_tokens)
+                    if num_ctx_tokens % max_bs == 0:
+                        num_full_seqs = max_bs
+                    else:
+                        num_full_seqs = max_bs - 1
+                    max_seq_len = num_ctx_tokens // num_full_seqs
+                    num_left_over_tokens = num_ctx_tokens - max_seq_len * num_full_seqs
             num_ctx_requests = num_full_seqs + (1 if num_left_over_tokens
                                                 > 0 else 0)
 
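
To see how the two packing strategies differ, here is a hedged, self-contained sketch; the helper name, signature, and numbers below are invented for illustration, while the real code operates on the enclosing method's state:

def split_ctx_tokens(num_ctx_tokens: int, max_seq_len: int, max_requests: int,
                     least_requests: bool = True):
    """Return (num_full_seqs, seq_len, num_left_over_tokens) for a warmup batch."""
    if least_requests:
        # Fewest requests: fill whole max_seq_len-long sequences first.
        num_full_seqs = num_ctx_tokens // max_seq_len
        left_over = num_ctx_tokens - num_full_seqs * max_seq_len
        return num_full_seqs, max_seq_len, left_over
    # Most requests: spread the context tokens over as many requests as allowed.
    max_bs = min(num_ctx_tokens, max_requests)
    num_full_seqs = max_bs if num_ctx_tokens % max_bs == 0 else max_bs - 1
    seq_len = num_ctx_tokens // num_full_seqs
    left_over = num_ctx_tokens - seq_len * num_full_seqs
    return num_full_seqs, seq_len, left_over

print(split_ctx_tokens(1000, 256, 7))                        # (3, 256, 232): 3 full + 1 partial = 4 requests
print(split_ctx_tokens(1000, 256, 7, least_requests=False))  # (6, 166, 4): 6 full + 1 partial = 7 requests
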
@@ -633,33 +651,38 @@ def release_batch(result: ScheduledRequests | None):
         if cp_type == CpType.STAR:
             return
 
-        if self._torch_compile_enabled:
-
+        def general_warmup(reverse: bool = False):
             warmup_requests = set([
                 (1, 1),  # Specialize for 1 token.
                 (self.batch_size,
                  self.batch_size),  # max_batch_size, pure generation
                 (2, 0),  # Non-one, pure context
                 (curr_max_num_tokens, 0),  # max_num_tokens, pure context
             ])
+            if reverse:
+                warmup_requests = sorted(list(warmup_requests), reverse=reverse)
+
+            for warmup_num_tokens, warmup_num_gen_tokens in warmup_requests:
+                with release_batch(
+                        get_warmup_request(warmup_num_tokens,
+                                           warmup_num_gen_tokens)) as batch:
+                    if batch is None:
+                        # No KV cache space!
+                        continue
+                    logger.info(
+                        f"Run warmup with {warmup_num_tokens} tokens, include {warmup_num_gen_tokens} generation tokens"
+                    )
+                    self.forward(batch,
+                                 new_tensors_device=None,
+                                 resource_manager=resource_manager)
+                torch.cuda.synchronize()
 
+        if self._torch_compile_enabled:
             # Disable cuda graph capture here so that we can properly capture it later
             with self.no_cuda_graph():
-                for warmup_num_tokens, warmup_num_gen_tokens in warmup_requests:
-
-                    with release_batch(
-                            get_warmup_request(warmup_num_tokens,
-                                               warmup_num_gen_tokens)) as batch:
-                        if batch is None:
-                            # No KV cache space!
-                            continue
-                        logger.info(
-                            f"Run warmup with {warmup_num_tokens} tokens, include {warmup_num_gen_tokens} generation tokens"
-                        )
-                        self.forward(batch,
-                                     new_tensors_device=None,
-                                     resource_manager=resource_manager)
-                    torch.cuda.synchronize()
+                # From small case to large to make sure the 1 token case is run first.
+                # If the first graph is not the 1 token case, dynamo will specialize the non-1 token case.
+                general_warmup()
 
         if self.pytorch_backend_config.enable_autotuner:
             with self.no_cuda_graph(), autotune():
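
In the reverse=True pass, the (num_tokens, num_gen_tokens) tuples are sorted from largest to smallest, so the biggest warmup case runs first. A quick standalone illustration with placeholder values:

# Placeholder (num_tokens, num_gen_tokens) warmup cases, not the real ones.
warmup_requests = {(1, 1), (256, 256), (2, 0), (8192, 0)}

# As in general_warmup(reverse=True): tuples sort lexicographically, so the
# largest token counts come first.
print(sorted(warmup_requests, reverse=True))
# [(8192, 0), (256, 256), (2, 0), (1, 1)]
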
@@ -763,6 +786,29 @@ def _update_draft_inference_state(is_first_draft: bool,
             gc.collect()
             torch.cuda.empty_cache()
 
+        # When using piecewise cuda graph, the logits may suffer from a severe memory fragmentation problem.
+        # When the number of requests grows, the blocks allocated by torch cannot be reused.
+        # So after piecewise cuda graph capture, a warmup batch with the most requests is triggered to make
+        # sure that large enough blocks are allocated and can be correctly reused.
+        for num_tokens in piecewise_cuda_graph_num_tokens:
+            batch = get_warmup_request(num_tokens, 0, least_requests=False)
+            if batch is None:
+                continue
+            with release_batch(batch) as batch:
+                logger.info(
+                    f"Run piecewise CUDA graph warmup for num tokens={num_tokens} with most requests"
+                )
+                self.forward(batch,
+                             new_tensors_device=None,
+                             resource_manager=resource_manager)
+
+            torch.cuda.synchronize()
+
+        # Also, we run a general warmup from large to small to make sure that blocks are allocated well.
+        # The cuda graph and piecewise cuda graph capture call torch.cuda.empty_cache(), and blocks may already
+        # have been freed even though we call general_warmup for torch compile.
+        general_warmup(reverse=True)
+
         # Set the value back to the original value
         self.enable_spec_decode = self.is_spec_decode
 
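
The comments above rely on a property of PyTorch's caching allocator: a large cached block can be split and reused for later, smaller allocations, but many small cached blocks cannot be merged back into one large block. A rough standalone illustration of that reuse, not part of the commit (exact reserved sizes depend on the allocator configuration and the GPU):

import torch

assert torch.cuda.is_available()
torch.cuda.empty_cache()

# Allocate the "largest batch" first so one big block ends up in the allocator cache.
big = torch.empty(64 * 1024 * 1024, device="cuda")  # ~256 MiB of float32
del big
reserved_after_big = torch.cuda.memory_reserved()

# Smaller allocations can typically be carved out of that cached block,
# so the reserved pool barely grows.
small = [torch.empty(4 * 1024 * 1024, device="cuda") for _ in range(8)]  # 8 x ~16 MiB
print(torch.cuda.memory_reserved() - reserved_after_big)  # usually 0, or close to it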