Skip to content

Commit d0c203a

Browse files
support op dispatch
1 parent cf4b44f commit d0c203a

2 files changed

Lines changed: 42 additions & 30 deletions

File tree

csrc/extensions.cpp

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ void extApplyPenalty(at::Tensor& logits, const at::Tensor& presence_penalty,
292292
callDiopi(diopiApplyPenalty, logits, presence_penalty, frequency_penalty,
293293
p_token_ids, p_token_counts, p_cumsum_seq_len, p_max_len_in_batch);
294294
}
295-
295+
#if 0
296296
// 判断是否有对应的 diopi 实现:
297297
// 如果有, 则直接 pybind 上去;
298298
// 否则不注册, 等到 python 层处理.
@@ -363,6 +363,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
363363
"deeplink ext_scaled_masked_softmax_bwd");
364364
}
365365
}
366+
#endif
366367

367368
at::Tensor& apply_penalty(at::Tensor& logits, const at::Tensor& presence_penalty,
368369
const at::Tensor& frequency_penalty,
@@ -381,40 +382,27 @@ at::Tensor& dest_index_copy_kv(const at::Tensor& k, const at::Tensor& dest_loc,
381382
return out;
382383
}
383384

384-
TORCH_LIBRARY(ops, m) {
385-
//m.def("adamw(Tensor(a!) input, Tensor(b!) grad, Tensor(c!) exp_avg, Tensor(d!) exp_avg_sq, Tensor(e!) max_exp_avg_sq, float lr, float beta1, float beta2, float eps, float weight_decay, int step, bool amsgrad)->(Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))");
386-
m.def("apply_penalty(Tensor(a!) logits, Tensor presence_penalty, Tensor frequency_penalty, Tensor p_token_ids, Tensor p_token_counts, Tensor p_cumsum_seq_len, int p_max_len_in_batch)->Tensor(a!)");
387-
m.def("dest_index_copy_kv(Tensor(a!) out, Tensor k, Tensor dest_loc)->Tensor(a!)");
385+
at::Tensor& example_for_all_backend(at::Tensor& inout) {
386+
std::cout << __FUNCTION__ << ": "<< inout.options() << std::endl;
387+
return inout;
388388
}
389389

390-
// impl for dipu
391-
TORCH_LIBRARY_IMPL(ops, XPU, m) {
392-
if (reinterpret_cast<void*>(diopiApplyPenalty) != nullptr) {
393-
m.impl("apply_penalty", apply_penalty);
394-
}
395-
if (reinterpret_cast<void*>(diopiDestIndexCopyKV) != nullptr) {
396-
m.impl("dest_index_copy_kv", dest_index_copy_kv);
397-
}
390+
at::Tensor& example_only_for_xpu(at::Tensor& inout) {
391+
std::cout << __FUNCTION__ << ": " << inout.options() << std::endl;
392+
return inout;
398393
}
399394

400-
// impl for torch
401-
TORCH_LIBRARY_IMPL(ops, CUDA, m) {
402-
if (reinterpret_cast<void*>(diopiApplyPenalty) != nullptr) {
403-
m.impl("apply_penalty", apply_penalty);
404-
}
405-
if (reinterpret_cast<void*>(diopiDestIndexCopyKV) != nullptr) {
406-
m.impl("dest_index_copy_kv", dest_index_copy_kv);
407-
}
395+
// By default, all backends (XPU, AutocastXPU, AutoGradXPU, CUDA, PrivateUse1, AutogradPrivateUse1 etc) are registered. If you need to register separately for a certain backend, separate registration for a certain backend is also supported.
396+
TORCH_LIBRARY(deeplink_ext_, m) {
397+
m.def("adamw(Tensor(a!) input, Tensor(b!) grad, Tensor(c!) exp_avg, Tensor(d!) exp_avg_sq, Tensor(e!) max_exp_avg_sq, float lr, float beta1, float beta2, float eps, float weight_decay, int step, bool amsgrad)->(Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))");
398+
m.def("apply_penalty(Tensor(a!) logits, Tensor presence_penalty, Tensor frequency_penalty, Tensor p_token_ids, Tensor p_token_counts, Tensor p_cumsum_seq_len, int p_max_len_in_batch)->Tensor(a!)");
399+
m.def("dest_index_copy_kv(Tensor(a!) out, Tensor k, Tensor dest_loc)->Tensor(a!)");
400+
m.def("example(Tensor(a!) inout)->Tensor(a!)", example_for_all_backend);
408401
}
409402

410-
// impl for torch_npu
411-
TORCH_LIBRARY_IMPL(ops, PrivateUse1, m) {
412-
if (reinterpret_cast<void*>(diopiApplyPenalty) != nullptr) {
413-
m.impl("apply_penalty", apply_penalty);
414-
}
415-
if (reinterpret_cast<void*>(diopiDestIndexCopyKV) != nullptr) {
416-
m.impl("dest_index_copy_kv", dest_index_copy_kv);
417-
}
403+
// only impl for dipu
404+
TORCH_LIBRARY_IMPL(deeplink_ext_, XPU, m) {
405+
// m.impl("example", example_only_for_xpu);
418406
}
419407

420-
} // namespace dipu::dipu_ext
408+
} // namespace dipu::dipu_ext

test_dispatch.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import torch
2+
import torch_dipu
3+
import deeplink_ext
4+
torch.ops.load_library(deeplink_ext.__path__[0] + "/cpp_extensions.cpython-39-x86_64-linux-gnu.so")
5+
print(f"torch.ops.loaded_libraries:{torch.ops.loaded_libraries}")
6+
7+
#print(torch.ops.deeplink_ext_.dest_index_copy_kv)
8+
9+
def code_to_profile():
10+
x = torch.randn(3,4)
11+
y = torch.ops.deeplink_ext_.example(x)
12+
y = torch.ops.deeplink_ext_.example(x.cuda())
13+
14+
15+
with torch.profiler.profile(
16+
activities=[
17+
torch.profiler.ProfilerActivity.CPU,
18+
torch.profiler.ProfilerActivity.CUDA,
19+
]
20+
) as p:
21+
code_to_profile()
22+
print(p.key_averages().table(
23+
sort_by="self_cuda_time_total", row_limit=-1))
24+

0 commit comments

Comments
 (0)