Commit 9e2eea4: "delete print"

1 parent f215332

7 files changed: +7 -90 lines changed

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 1 addition & 2 deletions

@@ -361,8 +361,7 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
                        const paddle::Tensor &step_idx,
                        const paddle::Tensor &stop_seqs,
                        const paddle::Tensor &stop_seqs_len,
-                       const bool beam_search,
-                       const bool is_pooling);
+                       const bool beam_search);


 void UpdateInputes(const paddle::Tensor &stop_flags,

custom_ops/gpu_ops/stop_generation_multi_ends.cu

Lines changed: 4 additions & 16 deletions

@@ -38,21 +38,12 @@ __global__ void set_value_by_flags(bool *stop_flags,
                                    const int stop_seqs_bs,
                                    const int stop_seqs_max_len,
                                    bool beam_search,
-                                   bool prefill_one_step_stop,
-                                   bool is_pooling) {
+                                   bool prefill_one_step_stop) {
   int tid = threadIdx.x;
   int bid = blockIdx.x;
   if (tid >= stop_seqs_bs) return;
   if (bid < bs) {
     if(tid == 0){
-      if (is_pooling)
-      {
-        if(prefill_one_step_stop)
-        {
-          stop_flags[bid] = true;
-        }
-        return;
-      }
       if (prefill_one_step_stop) {
         stop_flags[bid] = true;
         if (seq_lens[bid] == 0) {
@@ -78,7 +69,6 @@ __global__ void set_value_by_flags(bool *stop_flags,
       }
     }
     // dealing stop_seqs
-    if (is_pooling) return;
     const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid];
     if (stop_seq_len <= 0) return;
     const int64_t *stop_seq_now = stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len;
@@ -111,8 +101,7 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
                        const paddle::Tensor &step_idx,
                        const paddle::Tensor &stop_seqs,
                        const paddle::Tensor &stop_seqs_len,
-                       const bool beam_search,
-                       const bool is_pooling) {
+                       const bool beam_search) {
   PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
   PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
   bool prefill_one_step_stop = false;
@@ -151,13 +140,12 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
       stop_seqs_bs,
       stop_seqs_max_len,
       beam_search,
-      prefill_one_step_stop,
-      is_pooling);
+      prefill_one_step_stop);
 }

 PD_BUILD_STATIC_OP(set_stop_value_multi_ends)
     .Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens", "pre_ids", "step_idx", "stop_seqs", "stop_seqs_len"})
-    .Attrs({"beam_search: bool","is_pooling:bool"})
+    .Attrs({"beam_search: bool"})
    .Outputs({"topk_ids_out", "stop_flags_out", "next_tokens_out"})
    .SetInplaceMap({{"topk_ids", "topk_ids_out"},
                    {"stop_flags", "stop_flags_out"},

custom_ops/gpu_ops/update_inputs_v1.cu

Lines changed: 0 additions & 2 deletions

@@ -118,8 +118,6 @@ __global__ void update_inputs_kernel_v1(bool *not_need_stop,
     int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int);
     if (thread_idx == 0) {
       not_need_stop[0] = stop_sum < stop_nums[0];
-      printf("[CUDA DEBUG] Stop sum: %lld / %lld, not_need_stop=%d\n",
-             stop_sum, stop_nums[0], not_need_stop[0]);
     }
   }
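
Note: the deleted printf fired from device code on every step. If the same diagnostic is still wanted occasionally, it can be read on the host after the step instead; a minimal sketch, assuming the runner's share_inputs dict exposes the not_need_stop and stop_nums buffers (as it does in gpu_model_runner.py) and that a logger is available:

# Hypothetical host-side check, run after the step that launches update_inputs_kernel_v1.
not_need_stop = bool(self.share_inputs["not_need_stop"])  # single-element bool tensor
stop_nums = int(self.share_inputs["stop_nums"])            # threshold the kernel compares against
logger.debug(f"not_need_stop={not_need_stop}, stop_nums={stop_nums}")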

fastdeploy/model_executor/pre_and_post_process.py

Lines changed: 0 additions & 16 deletions

@@ -824,22 +824,6 @@ def post_process_pooling(
         model_output.stop_flags,
     )

-    if current_platform.is_cuda() or current_platform.is_iluvatar() or current_platform.is_dcu():
-        dummy_tokens = paddle.full_like(model_output.next_tokens, -1, dtype="int64")
-        set_stop_value_multi_ends(
-            dummy_tokens,
-            model_output.stop_flags,
-            model_output.seq_lens_this_time,
-            model_output.eos_token_id,
-            model_output.next_tokens,
-            model_output.pre_ids,
-            model_output.step_idx,
-            model_output.stop_token_ids,
-            model_output.stop_seqs_len,
-            False,
-            True,
-        )
-
     with paddle.framework._no_check_dy2st_diff():
         if envs.ENABLE_V1_KVCACHE_SCHEDULER:
             dummy_sampled_tokens = paddle.full_like(model_output.next_tokens, -1, dtype="int64")
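
Note: this was the only caller passing is_pooling=True, so the pooling post-process no longer goes through the custom op at all. If pooling requests still need to be marked finished at this point, the effect the deleted kernel branch had (setting stop_flags under prefill-one-step-stop) can be reproduced directly in Paddle; a minimal sketch, not taken from the repository, where prefill_one_step_stop is a hypothetical host-side flag (the C++ launcher derives it from an environment check):

# Hypothetical replacement for the removed is_pooling branch of set_value_by_flags:
# under prefill-one-step-stop, every pooling request finishes after this single step.
if prefill_one_step_stop:
    model_output.stop_flags[:] = True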

fastdeploy/output/token_processor.py

Lines changed: 0 additions & 1 deletion

@@ -262,7 +262,6 @@ def _process_batch_output_use_zmq(self, receive_datas):

             if task.pooling_params is not None:
                 pooler_output = stream_data.pooler_output
-                llm_logger.info(f"xxxxxxpooler_output:{pooler_output}")
                 if isinstance(pooler_output, np.ndarray):
                     pooler_output = pooler_output.tolist()
                 result = PoolingRequestOutput(

fastdeploy/worker/gpu_model_runner.py

Lines changed: 0 additions & 45 deletions

@@ -2107,51 +2107,6 @@ def _pool(self, hidden_states: paddle.Tensor, num_running_requests: int) -> Opti

         return pooler_output

-    def _schedule_cache_and_update_buffer(
-        self, model_forward_batch: Optional[List[Request]], num_running_request: int
-    ) -> None:
-
-        # Update 'infer_seed' and step_cuda()
-        self.share_inputs["infer_seed"].add_(self.infer_seed_increment)
-        self.share_inputs["infer_seed"][:] %= self.MAX_INFER_SEED
-
-        if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
-            step_cuda(
-                self.share_inputs,
-                self.cache_config.block_size,
-                self.cache_config.enc_dec_block_num,
-                self.speculative_config,
-                self.cache_config.enable_prefix_caching,
-            )
-
-            self._update_chunked_prefill(model_forward_batch)
-            self._add_cache(model_forward_batch)
-        elif self.speculative_decoding:
-            speculate_schedule_cache(
-                self.share_inputs["draft_tokens"],
-                self.share_inputs["block_tables"],
-                self.share_inputs["stop_flags"],
-                self.share_inputs["prompt_lens"],
-                self.share_inputs["seq_lens_this_time"],
-                self.share_inputs["seq_lens_encoder"],
-                self.share_inputs["seq_lens_decoder"],
-                self.share_inputs["step_seq_lens_decoder"],
-                self.share_inputs["step_draft_tokens"],
-                self.share_inputs["step_seq_lens_this_time"],
-                self.share_inputs["accept_num"],
-                self.share_inputs["accept_tokens"],
-                self.share_inputs["is_block_step"],
-                self.share_inputs["not_need_stop"],
-                self.share_inputs["stop_nums"],
-                self.cache_config.block_size,
-                self.speculative_config.num_speculative_tokens,
-            )
-
-        # Copy seq_lens_this_time buffer
-        self.seq_lens_this_time_buffer[:num_running_request].copy_(
-            self.share_inputs["seq_lens_this_time"][:num_running_request], False
-        )
-
     def _add_cache(self, model_forward_batch) -> None:
         """
         Add cache for guided decoding.

fastdeploy/worker/gpu_worker.py

Lines changed: 2 additions & 8 deletions

@@ -16,7 +16,6 @@

 import gc
 import time
-import traceback
 from typing import List, Optional

 import paddle
@@ -191,13 +190,8 @@ def execute_model(
         num_running_request: int = None,
     ) -> Optional[ModelRunnerOutput]:
         """ """
-        try:
-            output = self.model_runner.execute_model(model_forward_batch, num_running_request)
-            return output
-        except Exception as e:
-            traceback.print_exc()
-            logger.error(f"model_runner.execute_model failed, {str(e)}")
-            raise e
+        output = self.model_runner.execute_model(model_forward_batch, num_running_request)
+        return output

     def preprocess_new_task(self, req_dicts: List[Request], num_running_requests: int) -> None:
         """Process new requests and then start the decode loop
