Merged

38 commits
73384a6
delete impl
yuanlehome Oct 13, 2025
fe92435
delete min_length&max_length
yuanlehome Oct 13, 2025
1b289b6
support limit thinking content strategy
yuanlehome Oct 13, 2025
1912e72
fix
yuanlehome Oct 13, 2025
fe0fee8
fix
yuanlehome Oct 13, 2025
81674a5
fix
yuanlehome Oct 13, 2025
3282a2f
update
yuanlehome Oct 14, 2025
406676d
Merge branch 'develop' into upgrade_limit_think_length
yuanlehome Oct 14, 2025
02281b7
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
yuanlehome Oct 14, 2025
6f1f082
fix set_value_by_flags_and_idx
yuanlehome Oct 14, 2025
31aa8ee
fix
yuanlehome Oct 15, 2025
8421f31
Merge branch 'develop' into upgrade_limit_think_length
yuanlehome Oct 15, 2025
bc60b26
fix
yuanlehome Oct 15, 2025
fe43788
Merge branch 'upgrade_limit_think_length' of https://github.com/yuanl…
yuanlehome Oct 15, 2025
61d9b72
fix
yuanlehome Oct 15, 2025
06b5441
fix
yuanlehome Oct 15, 2025
dcc8dca
Merge branch 'develop' into upgrade_limit_think_length
yuanlehome Oct 15, 2025
db20c22
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
yuanlehome Oct 16, 2025
948555d
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
yuanlehome Oct 16, 2025
36ed90d
update
yuanlehome Oct 16, 2025
2f8aa11
fix
yuanlehome Oct 16, 2025
324d17e
fix
yuanlehome Oct 16, 2025
9caf6f3
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
yuanlehome Oct 16, 2025
0710f34
fix typo
yuanlehome Oct 16, 2025
141608f
fix ci
yuanlehome Oct 17, 2025
e247d2a
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
yuanlehome Oct 17, 2025
565c4d1
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
yuanlehome Oct 17, 2025
9bb4629
fix
yuanlehome Oct 17, 2025
41ef32c
fix
yuanlehome Oct 17, 2025
bc43254
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
yuanlehome Oct 20, 2025
849eaa6
support mtp
yuanlehome Oct 20, 2025
1fff8f3
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
yuanlehome Oct 20, 2025
393d830
fix
yuanlehome Oct 20, 2025
2e0f607
fix
yuanlehome Oct 20, 2025
4fd1dde
update
yuanlehome Oct 20, 2025
0a0571c
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
yuanlehome Oct 20, 2025
db4e279
update
yuanlehome Oct 20, 2025
ba3cf37
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
yuanlehome Oct 20, 2025
40 changes: 40 additions & 0 deletions custom_ops/gpu_ops/cpp_extensions.cc
@@ -912,6 +912,38 @@ void SaveOutMmsgStatic(const paddle::Tensor& x,
int64_t rank_id,
bool save_each_rank);

void LimitThinkingContentLengthV1(const paddle::Tensor &next_tokens,
const paddle::Tensor &max_think_lens,
const paddle::Tensor &step_idx,
const paddle::Tensor &limit_think_status,
const int64_t think_end_id);

void LimitThinkingContentLengthV2(const paddle::Tensor &next_tokens,
const paddle::Tensor &max_think_lens,
const paddle::Tensor &step_idx,
const paddle::Tensor &limit_think_status,
const int64_t think_end_id,
const int64_t line_break_id);

void SpeculateLimitThinkingContentLengthV1(
const paddle::Tensor& next_tokens,
const paddle::Tensor& max_think_lens,
const paddle::Tensor& step_idx,
const paddle::Tensor& limit_think_status,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const int64_t think_end_id);

void SpeculateLimitThinkingContentLengthV2(
const paddle::Tensor& next_tokens,
const paddle::Tensor& max_think_lens,
const paddle::Tensor& step_idx,
const paddle::Tensor& limit_think_status,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const int64_t think_end_id,
const int64_t line_break_id);

void SpeculateGetLogits(const paddle::Tensor &draft_logits,
const paddle::Tensor &next_token_num,
const paddle::Tensor &batch_token_num,
@@ -1320,6 +1352,14 @@ PYBIND11_MODULE(fastdeploy_ops, m) {

m.def("save_output", &SaveOutMmsgStatic, "save_output function");

m.def("limit_thinking_content_length_v1", &LimitThinkingContentLengthV1, "limit_thinking_content_length_v1 function");

m.def("limit_thinking_content_length_v2", &LimitThinkingContentLengthV2, "limit_thinking_content_length_v2 function");

m.def("speculate_limit_thinking_content_length_v1", &SpeculateLimitThinkingContentLengthV1, "speculate limit thinking content length function");

m.def("speculate_limit_thinking_content_length_v2", &SpeculateLimitThinkingContentLengthV2, "speculate limit thinking content length function");

m.def("speculate_get_logits", &SpeculateGetLogits, "speculate_get_logits function");

m.def("speculate_insert_first_token", &SpeculateInsertFirstToken, "speculate_insert_first_token function");
88 changes: 88 additions & 0 deletions custom_ops/gpu_ops/limit_thinking_content_length_v1.cu
@@ -0,0 +1,88 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/extension.h"

__global__ void limit_thinking_content_length_kernel_v1(
int64_t *next_tokens,
const int *max_think_lens,
const int64_t *step_idx,
int *limit_think_status,
const int64_t think_end_id,
const int bs) {
int bid = threadIdx.x;
if (bid >= bs) return;

// If thinking-length control is not enabled for this sequence, return immediately; the default value -1 means the thinking length is unlimited.
const int max_think_len = max_think_lens[bid];
if (max_think_len < 0) return;
int current_limit_think_status = limit_think_status[bid];
// If we are already in the response phase and the stop flag has been triggered, return immediately; no further work is needed.
if (current_limit_think_status == 2) {
return;
}

int64_t next_token = next_tokens[bid];
const int64_t step = step_idx[bid];

// ======================= Thinking-phase control =======================
// Phase 1: still thinking (status == 0); check whether thinking must be forcibly ended.
if (current_limit_think_status < 1) {
// With thinking-length control enabled, check whether the thinking budget is exhausted.
if (step >= max_think_len) {
// Forcibly replace the current token with the end-of-thinking token.
next_token = think_end_id;
// Advance the status to 1, meaning "ending thinking".
current_limit_think_status = 1;
}
}
// ======================= End-of-thinking handling =======================
// Phase 2: check whether the end-of-thinking condition has been met (status < 2).
// This covers two scenarios:
// 1. status == 0: the model generated think_end_id by itself
// 2. status == 1: think_end_id was forcibly injected in the previous phase
if (current_limit_think_status < 2) {
if (next_token == think_end_id) {
// Confirm that thinking has ended and advance the status to 2 (response phase).
current_limit_think_status = 2;
}
}
// Write back the updated token.
next_tokens[bid] = next_token;
// Update the global status.
limit_think_status[bid] = current_limit_think_status;
}

void LimitThinkingContentLengthV1(const paddle::Tensor &next_tokens,
const paddle::Tensor &max_think_lens,
const paddle::Tensor &step_idx,
const paddle::Tensor &limit_think_status,
const int64_t think_end_id) {
const int batch_size = next_tokens.shape()[0];
limit_thinking_content_length_kernel_v1<<<1, 1024>>>(
const_cast<int64_t *>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
step_idx.data<int64_t>(),
const_cast<int *>(limit_think_status.data<int>()),
think_end_id,
batch_size);
}

PD_BUILD_OP(limit_thinking_content_length_v1)
.Inputs({"next_tokens", "max_think_lens", "step_idx", "limit_think_status"})
.Attrs({"think_end_id: int64_t"})
.Outputs({"next_tokens_out"})
.SetInplaceMap({{"next_tokens", "next_tokens_out"}})
.SetKernelFn(PD_KERNEL(LimitThinkingContentLengthV1));
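As a reading aid (not part of the diff): the kernel above is a small per-sequence state machine in which status 0 means still thinking, status 1 means think_end_id was just forced in, and status 2 means the response phase has begun. A minimal CPU sketch of that logic for one sequence follows; the token ids and the helper name limit_v1_step are invented for illustration.

#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors limit_thinking_content_length_kernel_v1 for a single sequence:
// status 0 = still thinking, 1 = think_end_id was just forced, 2 = responding.
int64_t limit_v1_step(int64_t next_token, int64_t step, int max_think_len,
                      int64_t think_end_id, int &status) {
    if (max_think_len < 0 || status == 2) return next_token;   // feature off / already responding
    if (status < 1 && step >= max_think_len) {                 // thinking budget exhausted
        next_token = think_end_id;                             // force the end-of-thinking token
        status = 1;
    }
    if (status < 2 && next_token == think_end_id) status = 2;  // enter the response phase
    return next_token;
}

int main() {
    const int64_t think_end_id = 100;  // hypothetical id for </think>
    int status = 0;
    std::vector<int64_t> sampled = {7, 8, 9, 10, 11};
    for (int64_t step = 0; step < static_cast<int64_t>(sampled.size()); ++step) {
        int64_t tok = limit_v1_step(sampled[step], step, /*max_think_len=*/3, think_end_id, status);
        std::cout << "step " << step << " -> token " << tok << " status " << status << "\n";
    }
    // Tokens 7, 8, 9 pass through; at step 3 the sampled token is overwritten by
    // think_end_id and status moves 0 -> 1 -> 2 within the same step.
    return 0;
}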
111 changes: 111 additions & 0 deletions custom_ops/gpu_ops/limit_thinking_content_length_v2.cu
@@ -0,0 +1,111 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/extension.h"

// status == 0: normal generation phase
// status == 1: replacement phase
// status == 2: replacement finished
// status == 3: thinking ended (response phase)
__global__ void limit_thinking_content_length_kernel_v2(
int64_t *next_tokens,
const int *max_think_lens,
const int64_t *step_idx,
int *limit_think_status,
const int64_t think_end_id,
const int64_t line_break_id,
const int bs) {
int bid = threadIdx.x;
if (bid >= bs) return;
// If thinking-length control is not enabled for this sequence, return immediately; the default value -1 means the thinking length is unlimited.
const int max_think_len = max_think_lens[bid];
if (max_think_len < 0) return;
int current_limit_think_status = limit_think_status[bid];
// If we are already in the response phase and the stop flag has been triggered, return immediately; no further work is needed.
if (current_limit_think_status == 3) {
return;
}

int64_t next_token = next_tokens[bid];
const int64_t step = step_idx[bid];

// ======================= Thinking-phase control =======================
// Phase 1: still thinking (status == 0); check whether thinking must be forcibly ended.
// Phase 2: replacing (status == 1); check whether the replacement sequence is finished.
if (current_limit_think_status <= 1) {
// With thinking-length control enabled, check whether the thinking budget is exhausted.
if (step == max_think_len) {
// Force the leading line break of the closing sequence.
next_token = line_break_id;
current_limit_think_status = 1;
} else if (step == max_think_len + 1) {
// Force the end-of-thinking token (</think>).
next_token = think_end_id;
current_limit_think_status = 1;
} else if (step == max_think_len + 2) {
// Force the first trailing line break.
next_token = line_break_id;
current_limit_think_status = 1;
} else if (step == max_think_len + 3) {
// Force the second trailing line break.
next_token = line_break_id;
// Advance the status to 2, meaning the forced replacement is finished.
current_limit_think_status = 2;
}
}
// ======================= End-of-thinking handling =======================
// Phase 3: check whether the end-of-thinking condition has been met (status == 0 || status == 2).
// This covers two scenarios:
// 1. status == 0: the model may have generated </think> by itself
// 2. status == 2: "\n</think>\n\n" was forcibly injected in the previous phase
if (current_limit_think_status == 0) {
if (next_token == think_end_id) {
// Confirm that thinking has ended and advance the status to 3 (response phase).
current_limit_think_status = 3;
}
}
if (current_limit_think_status == 2) {
// Confirm that thinking has ended and advance the status to 3 (response phase).
current_limit_think_status = 3;
}
// Write back the updated token.
next_tokens[bid] = next_token;
// Update the global status.
limit_think_status[bid] = current_limit_think_status;
}

void LimitThinkingContentLengthV2(const paddle::Tensor &next_tokens,
const paddle::Tensor &max_think_lens,
const paddle::Tensor &step_idx,
const paddle::Tensor &limit_think_status,
const int64_t think_end_id,
const int64_t line_break_id) {
const int batch_size = next_tokens.shape()[0];
limit_thinking_content_length_kernel_v2<<<1, 1024>>>(
const_cast<int64_t *>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
step_idx.data<int64_t>(),
const_cast<int *>(limit_think_status.data<int>()),
think_end_id,
line_break_id,
batch_size);
}

PD_BUILD_OP(limit_thinking_content_length_v2)
.Inputs({"next_tokens", "max_think_lens", "step_idx", "limit_think_status"})
.Attrs({"think_end_id: int64_t", "line_break_id: int64_t"})
.Outputs({"next_tokens_out"})
.SetInplaceMap({{"next_tokens", "next_tokens_out"}})
.SetKernelFn(PD_KERNEL(LimitThinkingContentLengthV2));
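For orientation (not part of the diff): v2 differs from v1 by injecting a fixed four-token closing sequence (line break, think_end_id, line break, line break, i.e. "\n</think>\n\n") over steps max_think_len through max_think_len + 3 instead of a single token. A minimal CPU sketch of that schedule for one sequence, with invented ids and a hypothetical helper limit_v2_step:

#include <cstdint>
#include <iostream>

// Mirrors limit_thinking_content_length_kernel_v2 for a single sequence:
// status 0 = generating, 1 = injecting the closing sequence,
// 2 = injection finished, 3 = responding.
int64_t limit_v2_step(int64_t next_token, int64_t step, int max_think_len,
                      int64_t think_end_id, int64_t line_break_id, int &status) {
    if (max_think_len < 0 || status == 3) return next_token;
    if (status <= 1) {
        // Force the four-token tail "\n</think>\n\n" at steps L .. L+3.
        if (step == max_think_len)          { next_token = line_break_id; status = 1; }
        else if (step == max_think_len + 1) { next_token = think_end_id;  status = 1; }
        else if (step == max_think_len + 2) { next_token = line_break_id; status = 1; }
        else if (step == max_think_len + 3) { next_token = line_break_id; status = 2; }
    }
    if (status == 0 && next_token == think_end_id) status = 3;  // model closed thinking by itself
    if (status == 2) status = 3;                                // injection done -> response phase
    return next_token;
}

int main() {
    const int64_t think_end_id = 100, line_break_id = 200;  // hypothetical token ids
    int status = 0;
    for (int64_t step = 0; step < 7; ++step) {
        int64_t tok = limit_v2_step(/*sampled=*/step + 1, step, /*max_think_len=*/2,
                                    think_end_id, line_break_id, status);
        std::cout << "step " << step << ": token " << tok << " status " << status << "\n";
    }
    // Steps 0-1 pass through; steps 2-5 emit 200, 100, 200, 200 ("\n</think>\n\n");
    // from step 6 on the sampled token is used unchanged (status == 3).
    return 0;
}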
2 changes: 1 addition & 1 deletion custom_ops/gpu_ops/set_value_by_flags_and_idx.cu
@@ -35,7 +35,7 @@ __global__ void set_value_by_flag_and_id(const bool *stop_flags,
const int seq_len_dec = seq_lens_decoder[tid];
const int seq_len_enc = seq_lens_encoder[tid];
if (seq_len_dec == 0 && seq_len_enc == 0) return; // stopped
if (step_idx[tid] >= 0) {
if (step_idx[tid] > 0) {
if (seq_len_enc > 0) { // encoder, get last token according to seq_lens_encoder
pre_ids_all_now[step_idx[tid]] = input_ids_now[seq_len_enc - 1];
} else { // decoder, get first token
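The functional change in this hunk tightens the guard from step_idx[tid] >= 0 to step_idx[tid] > 0, so pre_ids_all is no longer written at step 0. A toy CPU-only illustration of the boundary difference (values are hypothetical, not from the kernel):

#include <cstdio>

int main() {
    long long pre_ids_old[4] = {-1, -1, -1, -1};
    long long pre_ids_new[4] = {-1, -1, -1, -1};
    long long tokens[4] = {42, 43, 44, 45};               // token recorded at each step
    for (long long step = 0; step < 4; ++step) {
        if (step >= 0) pre_ids_old[step] = tokens[step];  // old guard: slot 0 is written
        if (step > 0)  pre_ids_new[step] = tokens[step];  // new guard: slot 0 stays untouched
    }
    for (int i = 0; i < 4; ++i)
        std::printf("slot %d: old=%lld new=%lld\n", i, pre_ids_old[i], pre_ids_new[i]);
    return 0;
}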
@@ -0,0 +1,132 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "paddle/extension.h"

__global__ void speculate_limit_thinking_content_length_kernel_v1(
int64_t* next_tokens,
const int* max_think_lens,
int64_t* step_idx,
int* limit_think_status,
int* accept_num,
int* seq_lens_decoder,
const int64_t think_end_id,
const int tokens_per_step,
const int bs) {
int bid = threadIdx.x;
if (bid >= bs) return;

const int original_accept_num = accept_num[bid];
if (original_accept_num <= 0) return;

// If thinking-length control is not enabled for this sequence, return immediately; the default value -1 means the thinking length is unlimited.
const int max_think_len = max_think_lens[bid];
if (max_think_len < 0) return;
int current_limit_think_status = limit_think_status[bid];
// If we are already in the response phase and the stop flag has been triggered, return immediately; no further work is needed.
if (current_limit_think_status == 3) {
return;
}

int new_accept_num = original_accept_num;

const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;

for (int token_offset = 0; token_offset < original_accept_num;
token_offset++) {
const int token_idx = bid * tokens_per_step + token_offset;
int64_t next_token = next_tokens[token_idx];
const int64_t current_step = current_base_step + token_offset;

bool condition_triggered = false;

// ======================= Thinking-phase control =======================
// Phase 1: still thinking (status == 0); check whether thinking must be forcibly ended.
if (current_limit_think_status < 1) {
// With thinking-length control enabled, check whether the thinking budget is exhausted.
if (current_step >= max_think_len) {
// Forcibly replace the current token with the end-of-thinking token.
next_token = think_end_id;
current_limit_think_status = 1;
condition_triggered = true; // the token was modified, so the remaining draft tokens must be truncated
}
}

// ======================= End-of-thinking handling =======================
// Phase 2: check whether the end-of-thinking condition has been met (status < 2).
// This covers two scenarios:
// 1. status == 0: the model may have generated </think> by itself
// 2. status == 1: </think> was forcibly injected above
if (current_limit_think_status < 2) {
if (next_token == think_end_id) {
// Confirm that thinking has ended and advance the status to 2 (response phase).
current_limit_think_status = 2;
}
}

next_tokens[token_idx] = next_token;

if (condition_triggered) {
new_accept_num = token_offset + 1;
break;
}
}

// Update the global state.
int discarded_tokens = original_accept_num - new_accept_num;
if (discarded_tokens > 0) {
step_idx[bid] -= discarded_tokens;
seq_lens_decoder[bid] -= discarded_tokens;
}

accept_num[bid] = new_accept_num;
limit_think_status[bid] = current_limit_think_status;
}

void SpeculateLimitThinkingContentLengthV1(
const paddle::Tensor& next_tokens,
const paddle::Tensor& max_think_lens,
const paddle::Tensor& step_idx,
const paddle::Tensor& limit_think_status,
const paddle::Tensor& accept_num,
const paddle::Tensor& seq_lens_decoder,
const int64_t think_end_id) {
const int batch_size = next_tokens.shape()[0];
const int tokens_per_step = next_tokens.shape()[1];

speculate_limit_thinking_content_length_kernel_v1<<<1, 1024>>>(
const_cast<int64_t*>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
const_cast<int64_t*>(step_idx.data<int64_t>()),
const_cast<int*>(limit_think_status.data<int>()),
const_cast<int*>(accept_num.data<int>()),
const_cast<int*>(seq_lens_decoder.data<int>()),
think_end_id,
tokens_per_step,
batch_size);
}

PD_BUILD_STATIC_OP(speculate_limit_thinking_content_length_v1)
.Inputs({"next_tokens",
"max_think_lens",
"step_idx",
"limit_think_status",
"accept_num",
"seq_lens_decoder"})
.Attrs({"think_end_id: int64_t"})
.Outputs({"next_tokens_out"})
.SetInplaceMap({{"next_tokens", "next_tokens_out"}})
.SetKernelFn(PD_KERNEL(SpeculateLimitThinkingContentLengthV1));
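In this speculative-decoding variant a whole window of accepted draft tokens is validated at once: as soon as the thinking budget forces think_end_id, every later draft token in the window is discarded, and accept_num, step_idx and seq_lens_decoder are rolled back to match. A simplified CPU sketch of that per-sequence truncation (names and values invented for illustration; statuses collapsed to 0/1/2):

#include <cstdint>
#include <iostream>
#include <vector>

struct SeqState {
    int64_t step_idx;     // step index of the last accepted token
    int seq_len_decoder;  // decoded-length bookkeeping
    int accept_num;       // accepted draft tokens in this window
    int status;           // 0 thinking, 1 forced </think>, 2 responding
};

void speculate_limit_v1(std::vector<int64_t> &tokens, int max_think_len,
                        int64_t think_end_id, SeqState &s) {
    if (s.accept_num <= 0 || max_think_len < 0 || s.status == 2) return;
    const int original_accept = s.accept_num;
    const int64_t base_step = s.step_idx - original_accept + 1;  // step of the first token in the window
    int new_accept = original_accept;
    for (int i = 0; i < original_accept; ++i) {
        int64_t tok = tokens[i];
        const int64_t step = base_step + i;
        bool truncated = false;
        if (s.status < 1 && step >= max_think_len) {
            tok = think_end_id;  // budget exhausted: force </think>
            s.status = 1;
            truncated = true;    // later draft tokens must be dropped
        }
        if (s.status < 2 && tok == think_end_id) s.status = 2;
        tokens[i] = tok;
        if (truncated) { new_accept = i + 1; break; }
    }
    const int discarded = original_accept - new_accept;
    s.step_idx -= discarded;         // roll back bookkeeping for the dropped tokens
    s.seq_len_decoder -= discarded;
    s.accept_num = new_accept;
}

int main() {
    std::vector<int64_t> window = {7, 8, 9, 10};  // 4 accepted draft tokens
    SeqState s{/*step_idx=*/5, /*seq_len_decoder=*/6, /*accept_num=*/4, /*status=*/0};
    speculate_limit_v1(window, /*max_think_len=*/3, /*think_end_id=*/100, s);
    std::cout << "accept_num=" << s.accept_num << " step_idx=" << s.step_idx
              << " seq_len_decoder=" << s.seq_len_decoder << "\n";
    // The window starts at step 2 (5 - 4 + 1); step 3 hits the budget, so the second
    // token becomes 100 and the last two draft tokens are discarded.
    return 0;
}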