Commit 9f4977e

[xpu] support mtp for xpu(mix) (#5274)
* [XPU] support kernel for mtp(base)
* [XPU] support kernel for mtp(base)
* format
* format
* format
* fix gather next token
* fix step && add test
* fix
* mv pre/post process
* add adjust batch / gather next token for mtp
* fix code style
* fix mtp kenrel name
* fix mtp kernel test
* mv xpu pre/post process
* mv xpu pre/post process
* [xpu] support mtp
* fix code style
1 parent 8aec3ac commit 9f4977e

File tree

8 files changed: +691 -106 lines changed

fastdeploy/model_executor/layers/sample/ops/apply_penalty_multi_scores.py

Lines changed: 20 additions & 16 deletions
@@ -182,24 +182,28 @@ def apply_speculative_penalty_multi_scores(
         from fastdeploy.model_executor.ops.gpu import (
             speculate_get_token_penalty_multi_scores,
         )
-
-        speculate_get_token_penalty_multi_scores(
-            pre_token_ids,
-            logits,
-            repetition_penalties,
-            frequency_penalties,
-            presence_penalties,
-            temperature,
-            bad_words_token_ids,
-            step_idx,
-            min_dec_lens,
-            eos_token_ids,
-            seq_lens_this_time,
-            output_padding_offset,
-            output_cum_offsets,
-            max_len,
+    elif current_platform.is_xpu():
+        from fastdeploy.model_executor.ops.xpu import (
+            speculate_get_token_penalty_multi_scores,
         )
+
     else:
         raise NotImplementedError
+    speculate_get_token_penalty_multi_scores(
+        pre_token_ids,
+        logits,
+        repetition_penalties,
+        frequency_penalties,
+        presence_penalties,
+        temperature,
+        bad_words_token_ids,
+        step_idx,
+        min_dec_lens,
+        eos_token_ids,
+        seq_lens_this_time,
+        output_padding_offset,
+        output_cum_offsets,
+        max_len,
+    )
     # inplace
     return logits
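
The net effect of this hunk is easier to see without the diff interleaving: the penalty op is imported from the gpu or xpu namespace depending on the platform, and the call itself is hoisted below the branch so both backends share it. A minimal sketch of that shape, assuming the enclosing "if current_platform.is_cuda():" branch and the fastdeploy.platforms import path (neither appears in this hunk), with the argument list taken from the diff:

from fastdeploy.platforms import current_platform  # assumed import path


def apply_speculative_penalty_multi_scores(
    pre_token_ids, logits, repetition_penalties, frequency_penalties,
    presence_penalties, temperature, bad_words_token_ids, step_idx,
    min_dec_lens, eos_token_ids, seq_lens_this_time, output_padding_offset,
    output_cum_offsets, max_len,
):
    if current_platform.is_cuda():  # assumed: the branch the gpu import above sits in
        from fastdeploy.model_executor.ops.gpu import speculate_get_token_penalty_multi_scores
    elif current_platform.is_xpu():
        from fastdeploy.model_executor.ops.xpu import speculate_get_token_penalty_multi_scores
    else:
        raise NotImplementedError
    # One call shared by both platforms; the op rewrites `logits` in place.
    speculate_get_token_penalty_multi_scores(
        pre_token_ids, logits, repetition_penalties, frequency_penalties,
        presence_penalties, temperature, bad_words_token_ids, step_idx,
        min_dec_lens, eos_token_ids, seq_lens_this_time, output_padding_offset,
        output_cum_offsets, max_len,
    )
    return logits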

fastdeploy/model_executor/layers/sample/sampler.py

Lines changed: 119 additions & 0 deletions
@@ -572,6 +572,8 @@ def __init__(self, fd_config: FDConfig):
         super().__init__()
         if current_platform.is_cuda():
             self.forward = self.forward_cuda
+        elif current_platform.is_xpu():
+            self.forward = self.forward_xpu
         else:
             raise NotImplementedError
         self.logprobs_mode = fd_config.model_config.logprobs_mode
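
This hunk (and the matching one added to MTPSampler.__init__ further down) binds the platform-specific method to self.forward once at construction time, so each decode step skips the branch entirely and unsupported platforms fail immediately. A self-contained sketch of that dispatch pattern, using a hypothetical stand-in class and a plain string in place of current_platform:

class PlatformDispatchedSampler:  # hypothetical stand-in for the samplers in this file
    def __init__(self, platform: str):
        # Bind the right implementation once; later calls go straight to it.
        if platform == "cuda":
            self.forward = self.forward_cuda
        elif platform == "xpu":
            self.forward = self.forward_xpu
        else:
            raise NotImplementedError(f"no sampler implementation for {platform!r}")

    def forward_cuda(self, logits):
        return ("cuda", logits)

    def forward_xpu(self, logits):
        return ("xpu", logits)


print(PlatformDispatchedSampler("xpu").forward([0.1, 0.9]))  # -> ('xpu', [0.1, 0.9])
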
@@ -814,6 +816,80 @@ def forward_cuda(
 
         return sampler_output
 
+    def forward_xpu(
+        self,
+        logits: paddle.Tensor,
+        sampling_metadata: SamplingMetadata,
+        max_model_len: int,
+        share_inputs: List[paddle.Tensor],
+        accept_all_drafts: bool = False,
+        reject_all_drafts: bool = False,
+    ) -> paddle.Tensor:
+        from fastdeploy.model_executor.ops.xpu import speculate_verify, top_p_candidates
+
+        logits = apply_speculative_penalty_multi_scores(
+            sampling_metadata.pre_token_ids,
+            logits,
+            sampling_metadata.repetition_penalties,
+            sampling_metadata.frequency_penalties,
+            sampling_metadata.presence_penalties,
+            sampling_metadata.temperature,
+            sampling_metadata.bad_words_token_ids,
+            sampling_metadata.step_idx,
+            sampling_metadata.min_dec_lens,
+            sampling_metadata.eos_token_ids,
+            share_inputs["seq_lens_this_time"],
+            share_inputs["output_padding_offset"],
+            share_inputs["output_cum_offsets"],
+            max_model_len,
+        )
+
+        probs = F.softmax(logits)
+
+        verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
+            probs,
+            sampling_metadata.top_p,
+            share_inputs["output_padding_offset"],
+            self.speculative_max_candidate_len,
+            max_model_len,
+        )
+
+        speculate_verify(
+            share_inputs["accept_tokens"],
+            share_inputs["accept_num"],
+            share_inputs["step_idx"],
+            share_inputs["stop_flags"],
+            share_inputs["seq_lens_encoder"],
+            share_inputs["seq_lens_decoder"],
+            share_inputs[
+                "draft_tokens"
+            ],  # Both input and output, need to write the last 1 token accepted to position 0.
+            share_inputs["seq_lens_this_time"],
+            verify_tokens,
+            verify_scores,
+            share_inputs["max_dec_len"],
+            sampling_metadata.eos_token_ids,
+            share_inputs["is_block_step"],
+            share_inputs["output_cum_offsets"],
+            actual_candidate_len,
+            share_inputs["actual_draft_token_num"],
+            sampling_metadata.top_p,
+            max_model_len,
+            self.speculative_verify_window,
+            True,  # enable_topp
+            (self.speculative_benchmark_mode or reject_all_drafts),
+            accept_all_drafts,
+        )
+        # TODO(chenhuan09): support return logprobs
+        token_ids = share_inputs["accept_tokens"]
+        sampler_output = SamplerOutput(
+            sampled_token_ids=token_ids,
+            logprobs_tensors=None,
+            token_num_per_batch=share_inputs["accept_num"],
+            cu_batch_token_offset=None,
+        )
+        return sampler_output
+
 
 class MTPSampler(nn.Layer):
     """ """
@@ -823,6 +899,8 @@ def __init__(self, fd_config: FDConfig):
         super().__init__()
         if current_platform.is_cuda():
             self.forward = self.forward_cuda
+        elif current_platform.is_xpu():
+            self.forward = self.forward_xpu
         else:
             raise NotImplementedError
         self.logprobs_mode = fd_config.model_config.logprobs_mode
@@ -1013,3 +1091,44 @@ def forward_cuda(
             cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
         )
         return next_tokens, sampler_output
+
+    def forward_xpu(
+        self,
+        logits: paddle.Tensor,
+        sampling_metadata: SamplingMetadata,
+        max_model_len: int,
+        share_inputs: List[paddle.Tensor],
+    ) -> paddle.Tensor:
+
+        logits = apply_speculative_penalty_multi_scores(
+            sampling_metadata.pre_token_ids,
+            logits,
+            sampling_metadata.repetition_penalties,
+            sampling_metadata.frequency_penalties,
+            sampling_metadata.presence_penalties,
+            sampling_metadata.temperature,
+            sampling_metadata.bad_words_token_ids,
+            sampling_metadata.step_idx,
+            sampling_metadata.min_dec_lens,
+            sampling_metadata.eos_token_ids,
+            share_inputs["seq_lens_this_time"],
+            share_inputs["output_padding_offset"],
+            share_inputs["output_cum_offsets"],
+            max_model_len,
+        )
+        probs = F.softmax(logits)
+
+        _, next_tokens = top_k_top_p_sampling(
+            probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
+        )
+        # TODO(chenhuan09): add support for logprobs
+        token_ids = None
+        logprobs_tensors = None
+
+        sampler_output = SamplerOutput(
+            sampled_token_ids=token_ids,
+            logprobs_tensors=logprobs_tensors,
+            token_num_per_batch=None,
+            cu_batch_token_offset=None,
+        )
+        return next_tokens, sampler_output
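
Unlike the verify path above, MTPSampler.forward_xpu draws the draft-model tokens directly with top_k_top_p_sampling after applying the penalties. For readers unfamiliar with that step, a small, CPU-only illustration of the nucleus (top-p) part of the sampling; the real op runs batched on device and also applies top_k and the per-request top_k_list:

import numpy as np


def top_p_sample(probs, top_p, rng=np.random.default_rng(0)):
    """Sample one token id from a single probability row using nucleus (top-p) filtering."""
    order = np.argsort(probs)[::-1]                 # token ids, most probable first
    sorted_probs = probs[order]
    cumulative = np.cumsum(sorted_probs)
    keep = (cumulative - sorted_probs) < top_p      # keep tokens until top_p mass is covered
    keep[0] = True                                  # the most probable token always survives
    kept_ids, kept_probs = order[keep], sorted_probs[keep]
    kept_probs = kept_probs / kept_probs.sum()      # renormalize over the surviving nucleus
    return int(rng.choice(kept_ids, p=kept_probs))


probs = np.array([0.5, 0.3, 0.15, 0.05])
print(top_p_sample(probs, top_p=0.8))  # only token ids 0 and 1 can be drawn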
