@@ -1861,19 +1861,9 @@ class at the server level, which is too granular for ModelRunner.
         )
         if self.use_cudagraph:
             model_output = model_output[: self.real_token_num]
-        hidden_states = rebuild_padding(
-            model_output,
-            self.share_inputs["cu_seqlens_q"],
-            self.share_inputs["seq_lens_this_time"],
-            self.share_inputs["seq_lens_decoder"],
-            self.share_inputs["seq_lens_encoder"],
-            (self.share_inputs["output_padding_offset"] if self.speculative_decoding else None),
-            self.model_config.max_model_len,
-        )
 
-        # 4. Compute logits, Sample
+        hidden_states = model_output
         if self.is_pooling_model:
-            # num_scheduled_tokens = int(self.share_inputs["seq_lens_this_time"][:num_running_requests].sum())
             pooler_output = self._pool(hidden_states, num_running_requests)
 
             model_output_data = ModelOutputData(
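The net effect of this hunk and the one below is easier to see than the interleaved diff suggests: `rebuild_padding` no longer runs unconditionally before the pooling check. Pooling models now consume the packed `model_output` directly (presumably because pooling reduces over all token positions, not just the sampled ones), while the generation path rebuilds per-request hidden states inside the `else` branch before computing logits. A condensed sketch of the resulting control flow (heavily abridged, names as in the diff; not the full method):

```python
# Abridged sketch of the reworked flow, assuming the names used in this diff.
hidden_states = model_output  # packed [total_tokens, hidden_size] output
if self.is_pooling_model:
    # Pooling models pool over the raw packed hidden states and return early.
    pooler_output = self._pool(hidden_states, num_running_requests)
    # ... build ModelOutputData, post-process pooling output ...
    return None
else:
    # Generation models first gather the rows used for sampling.
    hidden_states = rebuild_padding(model_output, ...)  # args as in the diff
    logits = self.model.compute_logits(hidden_states)
    # ... sample, post-process, speculative decode, step ...
```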
@@ -1921,158 +1911,168 @@ class at the server level, which is too granular for ModelRunner.
             )
 
             return None
-
         else:
-            logits = self.model.compute_logits(hidden_states)
-
-        if not self.speculative_decoding:
-            set_value_by_flags_and_idx(
-                self.share_inputs["pre_ids"],
-                self.share_inputs["input_ids"],
+            hidden_states = rebuild_padding(
+                model_output,
+                self.share_inputs["cu_seqlens_q"],
                 self.share_inputs["seq_lens_this_time"],
-                self.share_inputs["seq_lens_encoder"],
                 self.share_inputs["seq_lens_decoder"],
-                self.share_inputs["step_idx"],
-                self.share_inputs["stop_flags"],
-            )
-            sampler_output = self.sampler(
-                logits,
-                self.sampling_metadata,
-                skip_idx_list,
-            )
-            if self.parallel_config.tensor_parallel_size > 1:
-                paddle.distributed.broadcast(
-                    sampler_output.sampled_token_ids,
-                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
-                    group=self.parallel_config.tp_group,
-                )
-        else:
-            sampler_output = self.sampler(
-                logits,
-                self.sampling_metadata,
+                self.share_inputs["seq_lens_encoder"],
+                (self.share_inputs["output_padding_offset"] if self.speculative_decoding else None),
                 self.model_config.max_model_len,
-                self.share_inputs,
             )
-            if self.parallel_config.tensor_parallel_size > 1:
-                paddle.distributed.broadcast(
-                    self.share_inputs["accept_tokens"],
-                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
-                    group=self.parallel_config.tp_group,
-                )
-                paddle.distributed.broadcast(
-                    self.share_inputs["accept_num"],
-                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
-                    group=self.parallel_config.tp_group,
-                )
-                paddle.distributed.broadcast(
+
+            # 4. Compute logits, Sample
+            logits = self.model.compute_logits(hidden_states)
+
+            if not self.speculative_decoding:
+                set_value_by_flags_and_idx(
+                    self.share_inputs["pre_ids"],
+                    self.share_inputs["input_ids"],
+                    self.share_inputs["seq_lens_this_time"],
+                    self.share_inputs["seq_lens_encoder"],
+                    self.share_inputs["seq_lens_decoder"],
                     self.share_inputs["step_idx"],
-                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
-                    group=self.parallel_config.tp_group,
-                )
-                paddle.distributed.broadcast(
                     self.share_inputs["stop_flags"],
-                    self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
-                    group=self.parallel_config.tp_group,
                 )
+                sampler_output = self.sampler(
+                    logits,
+                    self.sampling_metadata,
+                    skip_idx_list,
+                )
+                if self.parallel_config.tensor_parallel_size > 1:
+                    paddle.distributed.broadcast(
+                        sampler_output.sampled_token_ids,
+                        self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                        group=self.parallel_config.tp_group,
+                    )
+            else:
+                sampler_output = self.sampler(
+                    logits,
+                    self.sampling_metadata,
+                    self.model_config.max_model_len,
+                    self.share_inputs,
+                )
+                if self.parallel_config.tensor_parallel_size > 1:
+                    paddle.distributed.broadcast(
+                        self.share_inputs["accept_tokens"],
+                        self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                        group=self.parallel_config.tp_group,
+                    )
+                    paddle.distributed.broadcast(
+                        self.share_inputs["accept_num"],
+                        self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                        group=self.parallel_config.tp_group,
+                    )
+                    paddle.distributed.broadcast(
+                        self.share_inputs["step_idx"],
+                        self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                        group=self.parallel_config.tp_group,
+                    )
+                    paddle.distributed.broadcast(
+                        self.share_inputs["stop_flags"],
+                        self.parallel_config.data_parallel_rank * self.parallel_config.tensor_parallel_size,
+                        group=self.parallel_config.tp_group,
+                    )
 
-        # 5. Post Process
-        model_output_data = ModelOutputData(
-            next_tokens=self.share_inputs["next_tokens"],
-            stop_flags=self.share_inputs["stop_flags"],
-            step_idx=self.share_inputs["step_idx"],
-            max_dec_len=self.share_inputs["max_dec_len"],
-            pre_ids=self.share_inputs["pre_ids"],
-            seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
-            eos_token_id=self.share_inputs["eos_token_id"],
-            not_need_stop=self.share_inputs["not_need_stop"],
-            input_ids=self.share_inputs["input_ids"],
-            stop_nums=self.share_inputs["stop_nums"],
-            seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
-            seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
-            is_block_step=self.share_inputs["is_block_step"],
-            full_hidden_states=model_output,
-            msg_queue_id=self.parallel_config.msg_queue_id,
-            mp_rank=self.parallel_config.tensor_parallel_rank,
-            use_ep=self.parallel_config.use_ep,
-            draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
-            actual_draft_token_num=(
-                self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None
-            ),
-            accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
-            accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
-            stop_token_ids=self.share_inputs["stop_seqs"],
-            stop_seqs_len=self.share_inputs["stop_seqs_len"],
-            prompt_lens=self.share_inputs["prompt_lens"],
-        )
+            # 5. Post Process
+            model_output_data = ModelOutputData(
+                next_tokens=self.share_inputs["next_tokens"],
+                stop_flags=self.share_inputs["stop_flags"],
+                step_idx=self.share_inputs["step_idx"],
+                max_dec_len=self.share_inputs["max_dec_len"],
+                pre_ids=self.share_inputs["pre_ids"],
+                seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
+                eos_token_id=self.share_inputs["eos_token_id"],
+                not_need_stop=self.share_inputs["not_need_stop"],
+                input_ids=self.share_inputs["input_ids"],
+                stop_nums=self.share_inputs["stop_nums"],
+                seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
+                seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
+                is_block_step=self.share_inputs["is_block_step"],
+                full_hidden_states=model_output,
+                msg_queue_id=self.parallel_config.msg_queue_id,
+                mp_rank=self.parallel_config.tensor_parallel_rank,
+                use_ep=self.parallel_config.use_ep,
+                draft_tokens=(self.share_inputs["draft_tokens"] if self.speculative_decoding else None),
+                actual_draft_token_num=(
+                    self.share_inputs["actual_draft_token_num"] if self.speculative_decoding else None
+                ),
+                accept_tokens=(self.share_inputs["accept_tokens"] if self.speculative_decoding else None),
+                accept_num=(self.share_inputs["accept_num"] if self.speculative_decoding else None),
+                stop_token_ids=self.share_inputs["stop_seqs"],
+                stop_seqs_len=self.share_inputs["stop_seqs_len"],
+                prompt_lens=self.share_inputs["prompt_lens"],
+            )
 
-        if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill":
-            skip_save_output = True
-        else:
-            skip_save_output = False
+            if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill":
+                skip_save_output = True
+            else:
+                skip_save_output = False
 
-        post_process(
-            sampler_or_pooler_output=sampler_output,
-            model_output=model_output_data,
-            share_inputs=self.share_inputs,
-            block_size=self.cache_config.block_size,
-            save_each_rank=self.parallel_config.use_ep,
-            speculative_decoding=self.speculative_decoding,
-            skip_save_output=skip_save_output,
-            async_output_queue=self.async_output_queue,
-            think_end_id=self.model_config.think_end_id,
-            line_break_id=self.model_config.line_break_id,
-        )
-        if self.guided_backend is not None and sampler_output is not None:
-            self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list)
+            post_process(
+                sampler_or_pooler_output=sampler_output,
+                model_output=model_output_data,
+                share_inputs=self.share_inputs,
+                block_size=self.cache_config.block_size,
+                save_each_rank=self.parallel_config.use_ep,
+                speculative_decoding=self.speculative_decoding,
+                skip_save_output=skip_save_output,
+                async_output_queue=self.async_output_queue,
+                think_end_id=self.model_config.think_end_id,
+                line_break_id=self.model_config.line_break_id,
+            )
+            if self.guided_backend is not None and sampler_output is not None:
+                self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list)
+
+            # 6. Speculative decode
+            if self.speculative_decoding:
+                if self.speculative_method == "mtp":
+                    self.proposer.run(
+                        full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
+                    )
+                else:
+                    self.proposer.run(share_inputs=self.share_inputs)
 
-        # 6. Speculative decode
-        if self.speculative_decoding:
-            if self.speculative_method == "mtp":
-                self.proposer.run(
-                    full_hidden_states=model_output, step_use_cudagraph=self.forward_meta.step_use_cudagraph
+            # 7. Update 'infer_seed' and step_cuda()
+            self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
+            self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
+            if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                step_cuda(
+                    self.share_inputs,
+                    self.cache_config.block_size,
+                    self.cache_config.enc_dec_block_num,
+                    self.speculative_config,
+                    self.cache_config.enable_prefix_caching,
                 )
-            else:
-                self.proposer.run(share_inputs=self.share_inputs)
 
-        # 7. Update 'infer_seed' and step_cuda()
-        self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
-        self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
-        if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
-            step_cuda(
-                self.share_inputs,
-                self.cache_config.block_size,
-                self.cache_config.enc_dec_block_num,
-                self.speculative_config,
-                self.cache_config.enable_prefix_caching,
-            )
+                self._update_chunked_prefill(model_forward_batch)
+                self._add_cache(model_forward_batch)
+            elif self.speculative_decoding:
+                speculate_schedule_cache(
+                    self.share_inputs["draft_tokens"],
+                    self.share_inputs["block_tables"],
+                    self.share_inputs["stop_flags"],
+                    self.share_inputs["prompt_lens"],
+                    self.share_inputs["seq_lens_this_time"],
+                    self.share_inputs["seq_lens_encoder"],
+                    self.share_inputs["seq_lens_decoder"],
+                    self.share_inputs["step_seq_lens_decoder"],
+                    self.share_inputs["step_draft_tokens"],
+                    self.share_inputs["step_seq_lens_this_time"],
+                    self.share_inputs["accept_num"],
+                    self.share_inputs["accept_tokens"],
+                    self.share_inputs["is_block_step"],
+                    self.share_inputs["not_need_stop"],
+                    self.share_inputs["stop_nums"],
+                    self.cache_config.block_size,
+                    self.speculative_config.num_speculative_tokens,
+                )
 
-            self._update_chunked_prefill(model_forward_batch)
-            self._add_cache(model_forward_batch)
-        elif self.speculative_decoding:
-            speculate_schedule_cache(
-                self.share_inputs["draft_tokens"],
-                self.share_inputs["block_tables"],
-                self.share_inputs["stop_flags"],
-                self.share_inputs["prompt_lens"],
-                self.share_inputs["seq_lens_this_time"],
-                self.share_inputs["seq_lens_encoder"],
-                self.share_inputs["seq_lens_decoder"],
-                self.share_inputs["step_seq_lens_decoder"],
-                self.share_inputs["step_draft_tokens"],
-                self.share_inputs["step_seq_lens_this_time"],
-                self.share_inputs["accept_num"],
-                self.share_inputs["accept_tokens"],
-                self.share_inputs["is_block_step"],
-                self.share_inputs["not_need_stop"],
-                self.share_inputs["stop_nums"],
-                self.cache_config.block_size,
-                self.speculative_config.num_speculative_tokens,
+            self.seq_lens_this_time_buffer[:num_running_requests].copy_(
+                self.share_inputs["seq_lens_this_time"][:num_running_requests], False
             )
-
-        self.seq_lens_this_time_buffer[:num_running_requests].copy_(
-            self.share_inputs["seq_lens_this_time"][:num_running_requests], False
-        )
-        return None
+            return None
 
     def _pool(self, hidden_states: paddle.Tensor, num_running_requests: int) -> Optional[ModelRunnerOutput]:
 
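For context on the op being relocated: `rebuild_padding` takes the packed per-token forward output and produces the per-request rows that `compute_logits` consumes. The snippet below is only a rough functional sketch of that gather under stated assumptions; `rebuild_padding_sketch` is a hypothetical name, and the real FastDeploy kernel additionally handles decode rows, the speculative `output_padding_offset`, and `max_model_len`, none of which is modeled here.

```python
import paddle

def rebuild_padding_sketch(model_output: paddle.Tensor, cu_seqlens_q: paddle.Tensor) -> paddle.Tensor:
    """Rough stand-in for the custom rebuild_padding op (illustration only).

    Assumes model_output is packed as [total_tokens, hidden_size] and
    cu_seqlens_q holds cumulative query lengths [0, n0, n0 + n1, ...].
    """
    # The last token of request i lives at row cu_seqlens_q[i + 1] - 1.
    last_token_idx = cu_seqlens_q[1:] - 1
    # Gather one hidden-state row per request: shape [batch_size, hidden_size].
    return paddle.gather(model_output, last_token_idx, axis=0)
```

Read this way, the hunks above are consistent: sampling needs only one row per request, while `_pool` reduces over all token rows, so the pooling path can skip the gather entirely.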