@@ -527,12 +527,16 @@ def get_cuda_graph_warmup_request(batch_size, draft_len):
                 result = None
             return result

-        def get_warmup_request(num_tokens: int, num_gen_tokens: int):
+        def get_warmup_request(num_tokens: int,
+                               num_gen_tokens: int,
+                               least_requests: bool = True):
             available_tokens = kv_cache_manager.get_num_available_tokens(
                 self.runtime_draft_len)
             available_blocks = kv_cache_manager.get_num_free_blocks()
             if num_tokens > self.max_num_tokens or num_tokens > available_tokens:
                 return None
+            if num_gen_tokens > self.batch_size:
+                return None

             num_extra_decoding_steps = get_num_extra_decoding_steps()
             if num_extra_decoding_steps > 0:
@@ -550,14 +554,28 @@ def get_warmup_request(num_tokens: int, num_gen_tokens: int):
             num_full_seqs = 0
             num_left_over_tokens = 0

+            max_context_requests = self.batch_size - num_gen_tokens
+            if max_context_requests * max_seq_len < num_ctx_tokens:
+                return None
+
             if num_ctx_tokens > 0:
-                # We will try to assign as less context requests as possible to
-                # fill the num_ctx_tokens.
+                if least_requests:
+                    # Try to assign as few context requests as possible to
+                    # fill num_ctx_tokens.

-                # Num full sequences:
-                num_full_seqs = num_ctx_tokens // max_seq_len
-                num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len
+                    # Num full sequences:
+                    num_full_seqs = num_ctx_tokens // max_seq_len
+                    num_left_over_tokens = num_ctx_tokens - num_full_seqs * max_seq_len

+                else:
+                    max_bs = min(num_ctx_tokens,
+                                 self.batch_size - num_gen_tokens)
+                    if num_ctx_tokens % max_bs == 0:
+                        num_full_seqs = max_bs
+                    else:
+                        num_full_seqs = max_bs - 1
+                    max_seq_len = num_ctx_tokens // num_full_seqs
+                    num_left_over_tokens = num_ctx_tokens - max_seq_len * num_full_seqs
             num_ctx_requests = num_full_seqs + (1 if num_left_over_tokens
                                                 > 0 else 0)

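As a side note on the new least_requests flag, here is a minimal standalone sketch of the two context-token splitting strategies; the helper name split_ctx_tokens and the example numbers are illustrative only and not part of the change:

# Hypothetical helper mirroring the splitting logic in the hunk above (illustrative only).
def split_ctx_tokens(num_ctx_tokens: int,
                     max_seq_len: int,
                     max_context_requests: int,
                     least_requests: bool = True):
    """Return (num_full_seqs, seq_len, num_left_over_tokens)."""
    if least_requests:
        # Pack tokens into as few context requests as possible:
        # full-length sequences plus one shorter remainder request.
        num_full_seqs = num_ctx_tokens // max_seq_len
        seq_len = max_seq_len
    else:
        # Spread tokens over as many context requests as the batch allows:
        # (nearly) equal-length sequences plus an optional remainder request.
        max_bs = min(num_ctx_tokens, max_context_requests)
        num_full_seqs = max_bs if num_ctx_tokens % max_bs == 0 else max_bs - 1
        seq_len = num_ctx_tokens // num_full_seqs
    num_left_over_tokens = num_ctx_tokens - num_full_seqs * seq_len
    return num_full_seqs, seq_len, num_left_over_tokens

# 1000 context tokens, max_seq_len=512, room for 8 context requests:
print(split_ctx_tokens(1000, 512, 8, least_requests=True))   # (1, 512, 488) -> 2 requests
print(split_ctx_tokens(1000, 512, 8, least_requests=False))  # (8, 125, 0)   -> 8 requests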
@@ -633,33 +651,38 @@ def release_batch(result: ScheduledRequests | None):
         if cp_type == CpType.STAR:
             return

-        if self._torch_compile_enabled:
-
+        def general_warmup(reverse: bool = False):
             warmup_requests = set([
                 (1, 1),  # Specialize for 1 token.
                 (self.batch_size,
                  self.batch_size),  # max_batch_size, pure generation
                 (2, 0),  # Non-one, pure context
                 (curr_max_num_tokens, 0),  # max_num_tokens, pure context
             ])
+            # Sort so the warmup order is deterministic: small to large by default.
+            warmup_requests = sorted(warmup_requests, reverse=reverse)
+
+            for warmup_num_tokens, warmup_num_gen_tokens in warmup_requests:
+                with release_batch(
+                        get_warmup_request(warmup_num_tokens,
+                                           warmup_num_gen_tokens)) as batch:
+                    if batch is None:
+                        # No KV cache space!
+                        continue
+                    logger.info(
+                        f"Run warmup with {warmup_num_tokens} tokens, include {warmup_num_gen_tokens} generation tokens"
+                    )
+                    self.forward(batch,
+                                 new_tensors_device=None,
+                                 resource_manager=resource_manager)
+                    torch.cuda.synchronize()

+        if self._torch_compile_enabled:
             # Disable cuda graph capture here so that we can properly capture it later
             with self.no_cuda_graph():
-                for warmup_num_tokens, warmup_num_gen_tokens in warmup_requests:
-
-                    with release_batch(
-                            get_warmup_request(warmup_num_tokens,
-                                               warmup_num_gen_tokens)) as batch:
-                        if batch is None:
-                            # No KV cache space!
-                            continue
-                        logger.info(
-                            f"Run warmup with {warmup_num_tokens} tokens, include {warmup_num_gen_tokens} generation tokens"
-                        )
-                        self.forward(batch,
-                                     new_tensors_device=None,
-                                     resource_manager=resource_manager)
-                        torch.cuda.synchronize()
+                # Warm up from the smallest case to the largest so the 1-token case runs first.
+                # If the first graph is not the 1-token case, dynamo will specialize for the non-1-token case.
+                general_warmup()

         if self.pytorch_backend_config.enable_autotuner:
             with self.no_cuda_graph(), autotune():
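Since the comments above rely on the iteration order (1-token case first for dynamo, largest first for the reverse pass), here is a quick illustration of how the sorted (num_tokens, num_gen_tokens) pairs are ordered, assuming batch_size=8 and curr_max_num_tokens=2048 as illustrative values:

# Illustrative only: the warmup pairs general_warmup() iterates over, sorted as above.
pairs = {(1, 1), (8, 8), (2, 0), (2048, 0)}
print(sorted(pairs))                 # [(1, 1), (2, 0), (8, 8), (2048, 0)]  small to large
print(sorted(pairs, reverse=True))   # [(2048, 0), (8, 8), (2, 0), (1, 1)]  large to small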
@@ -763,6 +786,29 @@ def _update_draft_inference_state(is_first_draft: bool,
                         gc.collect()
                         torch.cuda.empty_cache()

+            # When using piecewise CUDA graphs, the logits may suffer from severe memory fragmentation.
+            # As the number of requests grows, the blocks allocated by torch cannot be reused.
+            # So after the piecewise CUDA graph capture, a warmup request with the most requests is
+            # triggered to make sure that large enough blocks are allocated and can be reused correctly.
+            for num_tokens in piecewise_cuda_graph_num_tokens:
+                batch = get_warmup_request(num_tokens, 0, least_requests=False)
+                if batch is None:
+                    continue
+                with release_batch(batch) as batch:
+                    logger.info(
+                        f"Run piecewise CUDA graph warmup for num tokens={num_tokens} with the most requests"
+                    )
+                    self.forward(batch,
+                                 new_tensors_device=None,
+                                 resource_manager=resource_manager)
+
+                torch.cuda.synchronize()
+
+            # Also, run a general warmup from large to small to make sure that blocks are allocated well.
+            # The CUDA graph and piecewise CUDA graph captures call torch.cuda.empty_cache(), so blocks may
+            # already have been freed even though general_warmup was already run for torch.compile.
+            general_warmup(reverse=True)
+
         # Set the value back to the original value
         self.enable_spec_decode = self.is_spec_decode

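The fragmentation argument in the new comments can be sanity-checked with a rough, heavily hedged sketch; it assumes default PyTorch caching-allocator behaviour, a CUDA device, and made-up tensor sizes, and is not part of the change:

import torch

torch.cuda.empty_cache()
base = torch.cuda.memory_reserved()

# Allocate the largest logits-like tensor first, then free it so the block stays cached.
big = torch.empty(8 * 2048, 8192, device="cuda")
del big
after_big = torch.cuda.memory_reserved()

# A smaller allocation can typically be carved out of the cached block,
# so the reserved pool should not need to grow further.
small = torch.empty(2048, 8192, device="cuda")
print(base, after_big, torch.cuda.memory_reserved())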