@@ -572,6 +572,8 @@ def __init__(self, fd_config: FDConfig):
         super().__init__()
         if current_platform.is_cuda():
             self.forward = self.forward_cuda
+        elif current_platform.is_xpu():
+            self.forward = self.forward_xpu
         else:
             raise NotImplementedError
         self.logprobs_mode = fd_config.model_config.logprobs_mode
@@ -814,6 +816,80 @@ def forward_cuda(
 
         return sampler_output
 
+    def forward_xpu(
+        self,
+        logits: paddle.Tensor,
+        sampling_metadata: SamplingMetadata,
+        max_model_len: int,
+        share_inputs: Dict[str, paddle.Tensor],
+        accept_all_drafts: bool = False,
+        reject_all_drafts: bool = False,
+    ) -> SamplerOutput:
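+        # Verify speculative draft tokens on XPU: apply sampling penalties to the
+        # logits, build per-position top-p candidate sets, then accept or reject
+        # the drafts with the speculate_verify custom op.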
+        from fastdeploy.model_executor.ops.xpu import speculate_verify, top_p_candidates
+
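+        # Apply the per-request repetition/frequency/presence penalties and
+        # temperature scaling to the logits before sampling.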
+        logits = apply_speculative_penalty_multi_scores(
+            sampling_metadata.pre_token_ids,
+            logits,
+            sampling_metadata.repetition_penalties,
+            sampling_metadata.frequency_penalties,
+            sampling_metadata.presence_penalties,
+            sampling_metadata.temperature,
+            sampling_metadata.bad_words_token_ids,
+            sampling_metadata.step_idx,
+            sampling_metadata.min_dec_lens,
+            sampling_metadata.eos_token_ids,
+            share_inputs["seq_lens_this_time"],
+            share_inputs["output_padding_offset"],
+            share_inputs["output_cum_offsets"],
+            max_model_len,
+        )
+
+        probs = F.softmax(logits)
+
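+        # Collect the top-p candidate tokens (and their scores) for every decoding
+        # position; actual_candidate_len records how many candidates each position got.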
+        verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
+            probs,
+            sampling_metadata.top_p,
+            share_inputs["output_padding_offset"],
+            self.speculative_max_candidate_len,
+            max_model_len,
+        )
+
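+        # Accept or reject the draft tokens against the candidate sets. Accepted
+        # tokens are written into share_inputs["accept_tokens"] and the per-batch
+        # counts into share_inputs["accept_num"].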
+        speculate_verify(
+            share_inputs["accept_tokens"],
+            share_inputs["accept_num"],
+            share_inputs["step_idx"],
+            share_inputs["stop_flags"],
+            share_inputs["seq_lens_encoder"],
+            share_inputs["seq_lens_decoder"],
+            # Both input and output: the last accepted token is written back to position 0.
+            share_inputs["draft_tokens"],
+            share_inputs["seq_lens_this_time"],
+            verify_tokens,
+            verify_scores,
+            share_inputs["max_dec_len"],
+            sampling_metadata.eos_token_ids,
+            share_inputs["is_block_step"],
+            share_inputs["output_cum_offsets"],
+            actual_candidate_len,
+            share_inputs["actual_draft_token_num"],
+            sampling_metadata.top_p,
+            max_model_len,
+            self.speculative_verify_window,
+            True,  # enable_topp
+            (self.speculative_benchmark_mode or reject_all_drafts),
+            accept_all_drafts,
+        )
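+        # accept_num[i] holds the count of tokens accepted for sequence i this step.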
+        # TODO(chenhuan09): support returning logprobs.
+        token_ids = share_inputs["accept_tokens"]
+        sampler_output = SamplerOutput(
+            sampled_token_ids=token_ids,
+            logprobs_tensors=None,
+            token_num_per_batch=share_inputs["accept_num"],
+            cu_batch_token_offset=None,
+        )
+        return sampler_output
+
 
 class MTPSampler(nn.Layer):
     """ """
@@ -823,6 +899,8 @@ def __init__(self, fd_config: FDConfig):
         super().__init__()
         if current_platform.is_cuda():
             self.forward = self.forward_cuda
+        elif current_platform.is_xpu():
+            self.forward = self.forward_xpu
         else:
             raise NotImplementedError
         self.logprobs_mode = fd_config.model_config.logprobs_mode
@@ -1013,3 +1091,44 @@ def forward_cuda(
             cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
         )
         return next_tokens, sampler_output
+
+    def forward_xpu(
+        self,
+        logits: paddle.Tensor,
+        sampling_metadata: SamplingMetadata,
+        max_model_len: int,
+        share_inputs: Dict[str, paddle.Tensor],
+    ) -> Tuple[paddle.Tensor, SamplerOutput]:
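+        # Draft-model (MTP) sampling on XPU: penalize the logits, then draw the
+        # next draft token for each sequence with top-k/top-p sampling.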
+
+        logits = apply_speculative_penalty_multi_scores(
+            sampling_metadata.pre_token_ids,
+            logits,
+            sampling_metadata.repetition_penalties,
+            sampling_metadata.frequency_penalties,
+            sampling_metadata.presence_penalties,
+            sampling_metadata.temperature,
+            sampling_metadata.bad_words_token_ids,
+            sampling_metadata.step_idx,
+            sampling_metadata.min_dec_lens,
+            sampling_metadata.eos_token_ids,
+            share_inputs["seq_lens_this_time"],
+            share_inputs["output_padding_offset"],
+            share_inputs["output_cum_offsets"],
+            max_model_len,
+        )
+        probs = F.softmax(logits)
+
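+        # Sample one token per sequence, filtering with the per-request top-k/top-p.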
+        _, next_tokens = top_k_top_p_sampling(
+            probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
+        )
+        # TODO(chenhuan09): add support for logprobs.
+        token_ids = None
+        logprobs_tensors = None
+
+        sampler_output = SamplerOutput(
+            sampled_token_ids=token_ids,
+            logprobs_tensors=logprobs_tensors,
+            token_num_per_batch=None,
+            cu_batch_token_offset=None,
+        )
+        return next_tokens, sampler_output