@@ -526,12 +526,16 @@ def get_cuda_graph_warmup_request(batch_size, draft_len):
             result = None
             return result

-        def get_warmup_request(num_tokens: int, num_gen_tokens: int):
+        def get_warmup_request(num_tokens: int,
+                               num_gen_tokens: int,
+                               least_requests: bool = True):
             available_tokens = kv_cache_manager.get_num_available_tokens(
                 self.runtime_draft_len)
             available_blocks = kv_cache_manager.get_num_free_blocks()
             if num_tokens > self.max_num_tokens or num_tokens > available_tokens:
                 return None
+            if num_gen_tokens > self.batch_size:
+                return None

             num_extra_decoding_steps = get_num_extra_decoding_steps()
             if num_extra_decoding_steps > 0:
@@ -549,14 +553,28 @@ def get_warmup_request(num_tokens: int, num_gen_tokens: int):
             num_full_seqs = 0
             num_left_over_tokens = 0

+            max_context_requests = self.batch_size - num_gen_tokens
+            if max_context_requests * max_seq_len < num_ctx_tokens:
+                return None
+
             if num_ctx_tokens > 0:
-                # We will try to assign as less context requests as possible to
-                # fill the num_ctx_tokens.
+                if least_requests:
+                    # We will try to assign as few context requests as possible
+                    # to fill the num_ctx_tokens.

-                # Num full sequences:
-                num_full_seqs = num_ctx_tokens // max_seq_len
-                num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len
+                    # Num full sequences:
+                    num_full_seqs = num_ctx_tokens // max_seq_len
+                    num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len

+                else:
+                    max_bs = min(num_ctx_tokens,
+                                 self.batch_size - num_gen_tokens)
+                    if num_ctx_tokens % max_bs == 0:
+                        num_full_seqs = max_bs
+                    else:
+                        num_full_seqs = max_bs - 1
+                    max_seq_len = num_ctx_tokens // num_full_seqs
+                    num_left_over_tokens = num_ctx_tokens - max_seq_len * num_full_seqs
             num_ctx_requests = num_full_seqs + (1 if num_left_over_tokens
                                                 > 0 else 0)

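For readers skimming the diff, the sketch below restates the two context-splitting strategies of `get_warmup_request` as a standalone helper. The name `split_ctx_tokens` and its plain arguments are hypothetical stand-ins for the values the real code reads from `self` and the KV-cache manager; this only illustrates the arithmetic, it is not engine code.

```python
def split_ctx_tokens(num_ctx_tokens: int, num_gen_tokens: int,
                     batch_size: int, max_seq_len: int,
                     least_requests: bool = True):
    """Mirror of the splitting logic above: returns (num_full_seqs, seq_len, leftover)."""
    num_full_seqs, num_left_over_tokens, seq_len = 0, 0, max_seq_len
    if num_ctx_tokens > 0:
        if least_requests:
            # Pack the tokens into as few context requests as possible,
            # each max_seq_len long, plus one request for the remainder.
            num_full_seqs = num_ctx_tokens // max_seq_len
            num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len
        else:
            # Spread the tokens over as many context requests as the batch allows.
            max_bs = min(num_ctx_tokens, batch_size - num_gen_tokens)
            num_full_seqs = max_bs if num_ctx_tokens % max_bs == 0 else max_bs - 1
            seq_len = num_ctx_tokens // num_full_seqs
            num_left_over_tokens = num_ctx_tokens - seq_len * num_full_seqs
    return num_full_seqs, seq_len, num_left_over_tokens


# 10 context tokens, no generation tokens, batch size 8, max_seq_len 4:
print(split_ctx_tokens(10, 0, 8, 4, least_requests=True))   # (2, 4, 2) -> 3 requests
print(split_ctx_tokens(10, 0, 8, 4, least_requests=False))  # (7, 1, 3) -> 8 requests
```

With `least_requests=False` the tokens are spread over as many context requests as the remaining batch slots allow, which is what the piecewise CUDA graph warmup added later in this diff relies on.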
@@ -634,33 +652,38 @@ def release_batch(result: ScheduledRequests | None):
                 cp_type.name)
             return

-        if self._torch_compile_enabled:
-
+        def general_warmup(reverse: bool = False):
             warmup_requests = set([
                 (1, 1),  # Specialize for 1 token.
                 (self.batch_size,
                  self.batch_size),  # max_batch_size, pure generation
                 (2, 0),  # Non-one, pure context
                 (curr_max_num_tokens, 0),  # max_num_tokens, pure context
             ])
+            if reverse:
+                warmup_requests = sorted(list(warmup_requests), reverse=reverse)
+
+            for warmup_num_tokens, warmup_num_gen_tokens in warmup_requests:
+                with release_batch(
+                        get_warmup_request(warmup_num_tokens,
+                                           warmup_num_gen_tokens)) as batch:
+                    if batch is None:
+                        # No KV cache space!
+                        continue
+                    logger.info(
+                        f"Run warmup with {warmup_num_tokens} tokens, include {warmup_num_gen_tokens} generation tokens"
+                    )
+                    self.forward(batch,
+                                 new_tensors_device=None,
+                                 resource_manager=resource_manager)
+                    torch.cuda.synchronize()

+        if self._torch_compile_enabled:
             # Disable cuda graph capture here so that we can properly capture it later
             with self.no_cuda_graph():
-                for warmup_num_tokens, warmup_num_gen_tokens in warmup_requests:
-
-                    with release_batch(
-                            get_warmup_request(warmup_num_tokens,
-                                               warmup_num_gen_tokens)) as batch:
-                        if batch is None:
-                            # No KV cache space!
-                            continue
-                        logger.info(
-                            f"Run warmup with {warmup_num_tokens} tokens, include {warmup_num_gen_tokens} generation tokens"
-                        )
-                        self.forward(batch,
-                                     new_tensors_device=None,
-                                     resource_manager=resource_manager)
-                        torch.cuda.synchronize()
+                # Warm up from the smallest case to the largest so that the 1-token case runs first.
+                # If the first graph is not the 1-token case, dynamo will specialize on the non-1-token shape.
+                general_warmup()

         if self.pytorch_backend_config.enable_autotuner:
             # handle multiple rank issue
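The comment about running the 1-token case first matches dynamo's behaviour of specializing on sizes 0 and 1. Below is a toy illustration of that point, not the engine's actual compile setup; the function, shapes, and the `dynamic=True` flag are chosen only for the example.

```python
import torch


@torch.compile(dynamic=True)
def toy_step(x):
    # Stand-in for one decode step; only the input shape matters here.
    return x * 2 + 1


# Size-1 dimensions are specialized by dynamo, so warming up the 1-token shape
# first gives it its own graph; the larger shapes then share a dynamic-shape graph
# instead of being re-specialized around whichever size happened to come first.
toy_step(torch.randn(1, 8))    # 1-token case: specialized graph
toy_step(torch.randn(32, 8))   # larger case: dynamic-shape graph
toy_step(torch.randn(64, 8))   # typically reuses the dynamic graph, no recompile
```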
@@ -767,6 +790,29 @@ def _update_draft_inference_state(is_first_draft: bool,
         gc.collect()
         torch.cuda.empty_cache()

+        # When using piecewise CUDA graphs, the logits may suffer from severe memory fragmentation.
+        # As the number of requests grows, the blocks allocated by torch cannot be reused.
+        # So after the piecewise CUDA graph capture, a warmup batch with the most requests is triggered
+        # to make sure that large enough blocks are allocated and can be correctly reused.
+        for num_tokens in piecewise_cuda_graph_num_tokens:
+            batch = get_warmup_request(num_tokens, 0, least_requests=False)
+            if batch is None:
+                continue
+            with release_batch(batch) as batch:
+                logger.info(
+                    f"Run piecewise CUDA graph warmup for num tokens={num_tokens} with most requests"
+                )
+                self.forward(batch,
+                             new_tensors_device=None,
+                             resource_manager=resource_manager)
+
+                torch.cuda.synchronize()
+
+        # Also, we run a general warmup from large to small to make sure that blocks are allocated well.
+        # The CUDA graph and piecewise CUDA graph captures call torch.cuda.empty_cache(), so blocks may
+        # already have been freed even though we called general_warmup for torch compile.
+        general_warmup(reverse=True)
+
         # Set the value back to the original value
         self.enable_spec_decode = self.is_spec_decode

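The fragmentation argument in the last hunk can be seen in isolation with the CUDA caching allocator: when the largest buffer is allocated first, the freed block can be split and reused for every smaller shape, so the reserved pool does not keep growing. A rough standalone sketch of that idea follows; it assumes a CUDA device, the sizes are arbitrary, and exact allocator behaviour can vary with configuration.

```python
import torch


def fake_logits(num_tokens: int, vocab_size: int = 4096) -> torch.Size:
    # Allocate and immediately drop a logits-sized buffer, as a warmup step would.
    return torch.empty(num_tokens, vocab_size, device="cuda").shape


torch.cuda.empty_cache()
for n in (8192, 2048, 256, 1):  # large -> small, mirroring general_warmup(reverse=True)
    fake_logits(n)
    # The reserved pool typically stops growing after the first (largest) allocation,
    # because later, smaller requests are served from the cached block.
    print(n, torch.cuda.memory_reserved())
```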