@@ -292,7 +292,7 @@ void extApplyPenalty(at::Tensor& logits, const at::Tensor& presence_penalty,
292292 callDiopi (diopiApplyPenalty, logits, presence_penalty, frequency_penalty,
293293 p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch);
294294}
295-
295+ # if 0
// Check whether a corresponding DIOPI implementation exists:
// if it does, bind it directly via pybind;
// otherwise skip registration and let the Python layer handle it.
@@ -363,6 +363,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
363363 "deeplink ext_scaled_masked_softmax_bwd");
364364 }
365365}
366+ #endif
366367
367368at::Tensor& apply_penalty (at::Tensor& logits, const at::Tensor& presence_penalty,
368369 const at::Tensor& frequency_penalty,
@@ -381,40 +382,27 @@ at::Tensor& dest_index_copy_kv(const at::Tensor& k, const at::Tensor& dest_loc,
381382 return out;
382383}
383384
384- TORCH_LIBRARY (ops, m) {
385- // m.def("adamw(Tensor(a!) input, Tensor(b!) grad, Tensor(c!) exp_avg, Tensor(d!) exp_avg_sq, Tensor(e!) max_exp_avg_sq, float lr, float beta1, float beta2, float eps, float weight_decay, int step, bool amsgrad)->(Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))");
386- m.def (" apply_penalty(Tensor(a!) logits, Tensor presence_penalty, Tensor frequency_penalty, Tensor p_token_ids, Tensor p_token_counts, Tensor p_cumsum_seq_len, int p_max_len_in_batch)->Tensor(a!)" );
387- m.def (" dest_index_copy_kv(Tensor(a!) out, Tensor k, Tensor dest_loc)->Tensor(a!)" );
385+ at::Tensor& example_for_all_backend (at::Tensor& inout) {
386+ std::cout << __FUNCTION__ << " : " << inout.options () << std::endl;
387+ return inout;
388388}
389389
390- // impl for dipu
391- TORCH_LIBRARY_IMPL (ops, XPU, m) {
392- if (reinterpret_cast <void *>(diopiApplyPenalty) != nullptr ) {
393- m.impl (" apply_penalty" , apply_penalty);
394- }
395- if (reinterpret_cast <void *>(diopiDestIndexCopyKV) != nullptr ) {
396- m.impl (" dest_index_copy_kv" , dest_index_copy_kv);
397- }
390+ at::Tensor& example_only_for_xpu (at::Tensor& inout) {
391+ std::cout << __FUNCTION__ << " : " << inout.options () << std::endl;
392+ return inout;
398393}
399394
400- // impl for torch
401- TORCH_LIBRARY_IMPL (ops, CUDA, m) {
402- if (reinterpret_cast <void *>(diopiApplyPenalty) != nullptr ) {
403- m.impl (" apply_penalty" , apply_penalty);
404- }
405- if (reinterpret_cast <void *>(diopiDestIndexCopyKV) != nullptr ) {
406- m.impl (" dest_index_copy_kv" , dest_index_copy_kv);
407- }
// By default, ops declared here are visible to every dispatch key
// (XPU, AutocastXPU, AutogradXPU, CUDA, PrivateUse1, AutogradPrivateUse1,
// etc.). If a backend needs its own kernel, register it separately with
// TORCH_LIBRARY_IMPL for that backend's key.
TORCH_LIBRARY(deeplink_ext_, m) {
  // Schema-only declarations: no kernel is attached here for adamw /
  // apply_penalty / dest_index_copy_kv — presumably their implementations
  // are registered elsewhere (TODO(review): confirm where the kernels for
  // these schemas are bound).
  m.def("adamw(Tensor(a!) input, Tensor(b!) grad, Tensor(c!) exp_avg, Tensor(d!) exp_avg_sq, Tensor(e!) max_exp_avg_sq, float lr, float beta1, float beta2, float eps, float weight_decay, int step, bool amsgrad)->(Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))");
  m.def("apply_penalty(Tensor(a!) logits, Tensor presence_penalty, Tensor frequency_penalty, Tensor p_token_ids, Tensor p_token_counts, Tensor p_cumsum_seq_len, int p_max_len_in_batch)->Tensor(a!)");
  m.def("dest_index_copy_kv(Tensor(a!) out, Tensor k, Tensor dest_loc)->Tensor(a!)");
  // def() with a function attaches example_for_all_backend as the
  // catch-all kernel, used by any backend without a more specific impl.
  m.def("example(Tensor(a!) inout)->Tensor(a!)", example_for_all_backend);
}
409402
410- // impl for torch_npu
411- TORCH_LIBRARY_IMPL (ops, PrivateUse1, m) {
412- if (reinterpret_cast <void *>(diopiApplyPenalty) != nullptr ) {
413- m.impl (" apply_penalty" , apply_penalty);
414- }
415- if (reinterpret_cast <void *>(diopiDestIndexCopyKV) != nullptr ) {
416- m.impl (" dest_index_copy_kv" , dest_index_copy_kv);
417- }
// Backend-specific registrations for DIPU only: a kernel registered here
// under the XPU dispatch key overrides the catch-all registration from
// TORCH_LIBRARY above when the tensors live on the DIPU (XPU) device.
TORCH_LIBRARY_IMPL(deeplink_ext_, XPU, m) {
  // Example of an XPU-only override; left disabled so the catch-all
  // example_for_all_backend kernel is used instead.
  // m.impl("example", example_only_for_xpu);
}
419407
420- } // namespace dipu::dipu_ext
408+ } // namespace dipu::dipu_ext
0 commit comments