
Commit e91196b

liji-nv authored and mikeiovine committed
[https://nvbugs/5451280][fix] Reduce memory fraction problem by warmu… (NVIDIA#7999)
Signed-off-by: Jin Li <[email protected]>
1 parent 2f17009 commit e91196b

2 files changed: 82 additions, 23 deletions

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 13 additions & 0 deletions
@@ -216,6 +216,19 @@ def _get_token_num_for_estimation(self) -> int:
             # Requests cannot share KV cache blocks. Round up to nearest integer multiple of block size.
             num_cache_blocks += (num_req_tokens + self._tokens_per_block -
                                  1) // self._tokens_per_block
+
+        # Max cuda graph warmup required tokens
+        max_cuda_graph_bs = min(self._model_engine.batch_size,
+                                self._model_engine._max_cuda_graph_batch_size)
+        cuda_graph_warmup_block = (
+            self._model_engine.max_seq_len +
+            1) // self._tokens_per_block + max_cuda_graph_bs - 1
+        num_cache_blocks = max(cuda_graph_warmup_block, num_cache_blocks)
+
+        # This is the minimal blocks required to run with max bs
+        # If not able to allocate self._model_engine.batch_size blocks, the max batch size should be adjusted.
+        num_cache_blocks = max(num_cache_blocks, self._model_engine.batch_size)
+
         # Multiply by beam width, to prevent rescaling of the max_seq_len caused by the influence of beam width during the preparation for kv_cache_estimation
         return num_cache_blocks * self._tokens_per_block * self._dummy_reqs[
             0].sampling_config.beam_width
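For intuition, here is a small standalone sketch of the arithmetic the added lines perform. The concrete numbers (tokens per block, max sequence length, batch sizes) are illustrative assumptions, not values from any real config:

```python
# Standalone sketch of the added estimation logic; all numbers are illustrative.
tokens_per_block = 32
max_seq_len = 4096
batch_size = 64                    # assumed engine max batch size
max_cuda_graph_batch_size = 128    # assumed CUDA graph batch-size limit

num_cache_blocks = 10  # whatever the per-request loop above accumulated

# Blocks needed to warm up the largest CUDA graph batch: roughly enough for one
# max-length sequence plus one extra block per additional request in the batch.
max_cuda_graph_bs = min(batch_size, max_cuda_graph_batch_size)
cuda_graph_warmup_block = (max_seq_len + 1) // tokens_per_block + max_cuda_graph_bs - 1
num_cache_blocks = max(cuda_graph_warmup_block, num_cache_blocks)

# At minimum one block per request, so the max batch size can actually run.
num_cache_blocks = max(num_cache_blocks, batch_size)

print(num_cache_blocks)  # (4097 // 32) + 63 = 128 + 63 = 191
```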

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 69 additions & 23 deletions
@@ -526,12 +526,16 @@ def get_cuda_graph_warmup_request(batch_size, draft_len):
                 result = None
             return result

-        def get_warmup_request(num_tokens: int, num_gen_tokens: int):
+        def get_warmup_request(num_tokens: int,
+                               num_gen_tokens: int,
+                               least_requests: bool = True):
             available_tokens = kv_cache_manager.get_num_available_tokens(
                 self.runtime_draft_len)
             available_blocks = kv_cache_manager.get_num_free_blocks()
             if num_tokens > self.max_num_tokens or num_tokens > available_tokens:
                 return None
+            if num_gen_tokens > self.batch_size:
+                return None

             num_extra_decoding_steps = get_num_extra_decoding_steps()
             if num_extra_decoding_steps > 0:
@@ -549,14 +553,28 @@ def get_warmup_request(num_tokens: int, num_gen_tokens: int):
             num_full_seqs = 0
             num_left_over_tokens = 0

+            max_context_requests = self.batch_size - num_gen_tokens
+            if max_context_requests * max_seq_len < num_ctx_tokens:
+                return None
+
             if num_ctx_tokens > 0:
-                # We will try to assign as less context requests as possible to
-                # fill the num_ctx_tokens.
+                if least_requests:
+                    # We will try to assign as less context requests as possible to
+                    # fill the num_ctx_tokens.

-                # Num full sequences:
-                num_full_seqs = num_ctx_tokens // max_seq_len
-                num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len
+                    # Num full sequences:
+                    num_full_seqs = num_ctx_tokens // max_seq_len
+                    num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len

+                else:
+                    max_bs = min(num_ctx_tokens,
+                                 self.batch_size - num_gen_tokens)
+                    if num_ctx_tokens % max_bs == 0:
+                        num_full_seqs = max_bs
+                    else:
+                        num_full_seqs = max_bs - 1
+                    max_seq_len = num_ctx_tokens // num_full_seqs
+                    num_left_over_tokens = num_ctx_tokens - max_seq_len * num_full_seqs
             num_ctx_requests = num_full_seqs + (1 if num_left_over_tokens
                                                 > 0 else 0)

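The new `least_requests` flag selects between two ways of packing the warmup context tokens into requests. Below is a self-contained sketch of the two strategies, simplified from the diff above; the function name, the `max_context_requests` parameter (standing in for `self.batch_size - num_gen_tokens`), and the example numbers are illustrative:

```python
def pack_context_tokens(num_ctx_tokens: int,
                        max_seq_len: int,
                        max_context_requests: int,
                        least_requests: bool = True):
    """Simplified re-statement of the context packing in get_warmup_request."""
    if least_requests:
        # Use as few context requests as possible: full-length sequences
        # first, plus one request for any remainder.
        num_full_seqs = num_ctx_tokens // max_seq_len
        num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len
    else:
        # Spread the tokens over as many requests as allowed, so the warmup
        # batch exercises the per-request allocations of the worst case.
        max_bs = min(num_ctx_tokens, max_context_requests)
        num_full_seqs = max_bs if num_ctx_tokens % max_bs == 0 else max_bs - 1
        max_seq_len = num_ctx_tokens // num_full_seqs
        num_left_over_tokens = num_ctx_tokens - max_seq_len * num_full_seqs
    num_ctx_requests = num_full_seqs + (1 if num_left_over_tokens > 0 else 0)
    return num_ctx_requests, max_seq_len, num_left_over_tokens

# 1000 context tokens, max_seq_len=256, up to 8 context requests:
print(pack_context_tokens(1000, 256, 8, least_requests=True))   # (4, 256, 232)
print(pack_context_tokens(1000, 256, 8, least_requests=False))  # (8, 125, 0)
```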
@@ -634,33 +652,38 @@ def release_batch(result: ScheduledRequests | None):
                             cp_type.name)
                 return

-        if self._torch_compile_enabled:
-
+        def general_warmup(reverse: bool = False):
             warmup_requests = set([
                 (1, 1),  # Specialize for 1 token.
                 (self.batch_size,
                  self.batch_size),  # max_batch_size, pure generation
                 (2, 0),  # Non-one, pure context
                 (curr_max_num_tokens, 0),  # max_num_tokens, pure context
             ])
+            if reverse:
+                warmup_requests = sorted(list(warmup_requests), reverse=reverse)
+
+            for warmup_num_tokens, warmup_num_gen_tokens in warmup_requests:
+                with release_batch(
+                        get_warmup_request(warmup_num_tokens,
+                                           warmup_num_gen_tokens)) as batch:
+                    if batch is None:
+                        # No KV cache space!
+                        continue
+                    logger.info(
+                        f"Run warmup with {warmup_num_tokens} tokens, include {warmup_num_gen_tokens} generation tokens"
+                    )
+                    self.forward(batch,
+                                 new_tensors_device=None,
+                                 resource_manager=resource_manager)
+                    torch.cuda.synchronize()

+        if self._torch_compile_enabled:
             # Disable cuda graph capture here so that we can properly capture it later
             with self.no_cuda_graph():
-                for warmup_num_tokens, warmup_num_gen_tokens in warmup_requests:
-
-                    with release_batch(
-                            get_warmup_request(warmup_num_tokens,
-                                               warmup_num_gen_tokens)) as batch:
-                        if batch is None:
-                            # No KV cache space!
-                            continue
-                        logger.info(
-                            f"Run warmup with {warmup_num_tokens} tokens, include {warmup_num_gen_tokens} generation tokens"
-                        )
-                        self.forward(batch,
-                                     new_tensors_device=None,
-                                     resource_manager=resource_manager)
-                        torch.cuda.synchronize()
+                # From small case to large to make sure the 1 token case is run first.
+                # If the first graph is not the 1 token case, dynamo will specialize the non-1 token case.
+                general_warmup()

         if self.pytorch_backend_config.enable_autotuner:
             # handle multiple rank issue
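The new `general_warmup` helper iterates a small set of `(num_tokens, num_gen_tokens)` shapes. The diff's comment notes that under torch.compile the 1-token case should be exercised first (otherwise dynamo specializes on a non-1 shape), and `reverse=True` replays the shapes from largest to smallest after graph capture. A tiny sketch of that ordering, where `batch_size` and `curr_max_num_tokens` are placeholder values:

```python
batch_size = 64              # placeholder for self.batch_size
curr_max_num_tokens = 8192   # placeholder for the engine's max_num_tokens

warmup_requests = {
    (1, 1),                      # specialize for 1 token
    (batch_size, batch_size),    # max batch size, pure generation
    (2, 0),                      # non-one, pure context
    (curr_max_num_tokens, 0),    # max_num_tokens, pure context
}

# Python set iteration order is arbitrary, so an explicit sort is what gives a
# deterministic small-to-large (or, with reverse=True, large-to-small) pass.
print(sorted(warmup_requests))                # [(1, 1), (2, 0), (64, 64), (8192, 0)]
print(sorted(warmup_requests, reverse=True))  # [(8192, 0), (64, 64), (2, 0), (1, 1)]
```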
@@ -767,6 +790,29 @@ def _update_draft_inference_state(is_first_draft: bool,
             gc.collect()
             torch.cuda.empty_cache()

+        # When using piecewise cuda graph, the logits may suffer severe memory faction problem.
+        # When the num of requests is growing, the block allocated by torch cannot be reused.
+        # So after piecewise cuda graph capture, a request with most requests is triggered to make
+        # sure that large enough blocks are allocated and can be correctly reused.
+        for num_tokens in piecewise_cuda_graph_num_tokens:
+            batch = get_warmup_request(num_tokens, 0, least_requests=False)
+            if batch is None:
+                continue
+            with release_batch(batch) as batch:
+                logger.info(
+                    f"Run piecewise CUDA graph warmup for num tokens={num_tokens} with most requests"
+                )
+                self.forward(batch,
+                             new_tensors_device=None,
+                             resource_manager=resource_manager)
+
+            torch.cuda.synchronize()
+
+        # Also, we run a general warmup from large to small to make sure that blocks are allocated well.
+        # The cudagraph and piecewise cuda graph capture calls torch.cuda.empty_cache() and block may already
+        # be freed even we calls general_warmup for torch compile.
+        general_warmup(reverse=True)
+
         # Set the value back to the original value
         self.enable_spec_decode = self.is_spec_decode

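The comments in this last hunk describe the underlying issue: with piecewise CUDA graphs, logits buffers allocated in ever-growing sizes fragment PyTorch's caching allocator, because a block cached for a smaller batch cannot back a later, larger request. Warming up once with the largest shape lets later, smaller allocations be carved out of the already-cached block. The snippet below is a generic illustration of that allocator behavior under default settings, not TensorRT-LLM code; the buffer sizes are arbitrary assumptions:

```python
# Illustrative only: PyTorch's caching allocator can typically split one large
# cached block to serve later, smaller allocations, which is the behavior the
# warmup ordering above tries to exploit.
import torch

def reserved_mib() -> float:
    return torch.cuda.memory_reserved() / 2**20

if torch.cuda.is_available():
    torch.cuda.empty_cache()

    # Warm up with the largest buffer first (analogous to the "most requests"
    # warmup above), then release it back to the caching allocator.
    big = torch.empty(256, 1024, 1024, dtype=torch.uint8, device="cuda")  # 256 MiB
    after_big = reserved_mib()
    del big

    # The smaller buffers below can usually be served by splitting the cached
    # 256 MiB block, so reserved memory should not need to grow much further.
    small = [torch.empty(8, 1024, 1024, dtype=torch.uint8, device="cuda")
             for _ in range(4)]
    print(f"reserved after big warmup: {after_big:.0f} MiB, "
          f"after smaller allocations: {reserved_mib():.0f} MiB")
```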