Skip to content

Commit a560438

Browse files
Merge pull request #153 from qinyiqun/musa
feat: 增加摩尔线程大模型算子和部分传统模型算子
2 parents 93db3ad + c9ade4d commit a560438

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+2807
-12
lines changed

operatorspy/tests/add.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,16 @@ def test_bang(lib, test_cases):
115115
test(lib, handle, "mlu", c_shape, a_shape, b_shape, tensor_dtype=torch.float32, inplace=inplace)
116116
destroy_handle(lib, handle)
117117

118+
def test_musa(lib, test_cases):
119+
import torch_musa
120+
121+
device = DeviceEnum.DEVICE_MUSA
122+
handle = create_handle(lib, device)
123+
for c_shape, a_shape, b_shape, inplace in test_cases:
124+
test(lib, handle, "musa", c_shape, a_shape, b_shape, tensor_dtype=torch.float16, inplace=inplace)
125+
test(lib, handle, "musa", c_shape, a_shape, b_shape, tensor_dtype=torch.float32, inplace=inplace)
126+
destroy_handle(lib, handle)
127+
118128

119129
if __name__ == "__main__":
120130
test_cases = [
@@ -163,6 +173,8 @@ def test_bang(lib, test_cases):
163173
test_cuda(lib, test_cases)
164174
if args.bang:
165175
test_bang(lib, test_cases)
166-
if not (args.cpu or args.cuda or args.bang):
176+
if args.musa:
177+
test_musa(lib, test_cases)
178+
if not (args.cpu or args.cuda or args.bang or args.musa):
167179
test_cpu(lib, test_cases)
168180
print("\033[92mTest passed!\033[0m")

operatorspy/tests/causal_softmax.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,16 @@ def test_maca(lib, test_cases):
119119

120120
destroy_handle(lib, handle)
121121

122+
def test_musa(lib, test_cases):
123+
import torch_musa
124+
device = DeviceEnum.DEVICE_MUSA
125+
126+
handle = create_handle(lib, device)
127+
for x_shape, x_stride in test_cases:
128+
test(lib, handle, "musa", x_shape, x_stride)
129+
130+
destroy_handle(lib, handle)
131+
122132
if __name__ == "__main__":
123133
test_cases = [
124134
# x_shape, x_stride
@@ -161,6 +171,8 @@ def test_maca(lib, test_cases):
161171
test_ascend(lib, test_cases)
162172
if args.maca:
163173
test_maca(lib, test_cases)
164-
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
174+
if args.musa:
175+
test_musa(lib, test_cases)
176+
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa):
165177
test_cpu(lib, test_cases)
166178
print("\033[92mTest passed!\033[0m")

operatorspy/tests/expand.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,16 @@ def test_bang(lib, test_cases):
133133
test(lib, handle, "mlu", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32)
134134
destroy_handle(lib, handle)
135135

136+
def test_musa(lib, test_cases):
137+
import torch_musa
138+
139+
device = DeviceEnum.DEVICE_MUSA
140+
handle = create_handle(lib, device)
141+
for y_shape, x_shape, y_stride, x_stride in test_cases:
142+
test(lib, handle, "musa", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float16)
143+
test(lib, handle, "musa", y_shape, x_shape, y_stride, x_stride, tensor_dtype=torch.float32)
144+
destroy_handle(lib, handle)
145+
136146

137147
if __name__ == "__main__":
138148
test_cases = [
@@ -174,6 +184,8 @@ def test_bang(lib, test_cases):
174184
test_cuda(lib, test_cases)
175185
if args.bang:
176186
test_bang(lib, test_cases)
177-
if not (args.cpu or args.cuda or args.bang):
187+
if args.musa:
188+
test_musa(lib, test_cases)
189+
if not (args.cpu or args.cuda or args.bang or args.musa):
178190
test_cpu(lib, test_cases)
179191
print("\033[92mTest passed!\033[0m")

operatorspy/tests/matmul.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,37 @@ def test_maca(lib, test_cases):
325325

326326
destroy_handle(lib, handle)
327327

328+
def test_musa(lib, test_cases):
329+
import torch_musa
330+
331+
device = DeviceEnum.DEVICE_MUSA
332+
handle = create_handle(lib, device)
333+
for (
334+
alpha,
335+
beta,
336+
a_shape,
337+
b_shape,
338+
c_shape,
339+
a_stride,
340+
b_stride,
341+
c_stride,
342+
dtype,
343+
) in test_cases:
344+
test(
345+
lib,
346+
handle,
347+
"musa",
348+
alpha,
349+
beta,
350+
a_shape,
351+
b_shape,
352+
c_shape,
353+
a_stride,
354+
b_stride,
355+
c_stride,
356+
dtype,
357+
)
358+
328359
if __name__ == "__main__":
329360
test_cases = [
330361
# alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride, dtype
@@ -387,6 +418,8 @@ def test_maca(lib, test_cases):
387418
test_ascend(lib, test_cases)
388419
if args.maca:
389420
test_maca(lib, test_cases)
390-
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
421+
if args.musa:
422+
test_musa(lib, test_cases)
423+
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa):
391424
test_cpu(lib, test_cases)
392425
print("\033[92mTest passed!\033[0m")

operatorspy/tests/random_sample.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
9494
if(torch_device == 'maca'):
9595
indices = torch.zeros([1], dtype = torch.int64).to('cuda')
9696
else:
97-
indices = torch.zeros([1], dtype = torch.uint64).to(torch_device)
97+
indices = torch.zeros([1], dtype = torch.int64).to(torch_device)
9898
x_tensor = to_tensor(data, lib)
9999
indices_tensor = to_tensor(indices, lib)
100100
indices_tensor.descriptor.contents.dt = U64 # treat int64 as uint64
@@ -170,7 +170,7 @@ def test_ascend(lib, test_cases):
170170
for (voc, random_val, topp, topk, temperature) in test_cases:
171171
test(lib, handle, "npu", voc, random_val, topp, topk, temperature)
172172
destroy_handle(lib, handle)
173-
173+
174174
def test_maca(lib, test_cases):
175175
device = DeviceEnum.DEVICE_MACA
176176
handle = create_handle(lib, device)
@@ -179,6 +179,13 @@ def test_maca(lib, test_cases):
179179
destroy_handle(lib, handle)
180180

181181

182+
def test_musa(lib, test_cases):
183+
import torch_musa
184+
device = DeviceEnum.DEVICE_MUSA
185+
handle = create_handle(lib, device)
186+
for (voc, random_val, topp, topk, temperature) in test_cases:
187+
test(lib, handle, "musa", voc, random_val, topp, topk, temperature)
188+
destroy_handle(lib, handle)
182189

183190
if __name__ == "__main__":
184191
test_cases = [
@@ -236,6 +243,8 @@ def test_maca(lib, test_cases):
236243
test_ascend(lib, test_cases)
237244
if args.maca:
238245
test_maca(lib, test_cases)
239-
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
246+
if args.musa:
247+
test_musa(lib, test_cases)
248+
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa):
240249
test_cpu(lib, test_cases)
241250
print("\033[92mTest passed!\033[0m")

operatorspy/tests/rearrange.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,26 @@ def test_maca(lib, test_cases):
117117
test(lib, handle, "cuda", x_shape, x_stride, y_shape, y_stride)
118118
destroy_handle(lib, handle)
119119

120+
def test_musa(lib, test_cases):
121+
import torch_musa
122+
device = DeviceEnum.DEVICE_MUSA
123+
handle = create_handle(lib, device)
124+
for test_case in test_cases:
125+
x_shape, x_stride = test_case[0]
126+
y_shape, y_stride = test_case[1]
127+
test(lib, handle, "musa", x_shape, x_stride, y_shape, y_stride)
128+
destroy_handle(lib, handle)
129+
130+
def test_musa(lib, test_cases):
131+
import torch_musa
132+
device = DeviceEnum.DEVICE_MUSA
133+
handle = create_handle(lib, device)
134+
for test_case in test_cases:
135+
x_shape, x_stride = test_case[0]
136+
y_shape, y_stride = test_case[1]
137+
test(lib, handle, "musa", x_shape, x_stride, y_shape, y_stride)
138+
destroy_handle(lib, handle)
139+
120140
if __name__ == "__main__":
121141
args = get_args()
122142
test_cases = [
@@ -156,4 +176,6 @@ def test_maca(lib, test_cases):
156176
test_ascend(lib, test_cases)
157177
if args.maca:
158178
test_maca(lib, test_cases)
179+
if args.musa:
180+
test_musa(lib, test_cases)
159181
print("\033[92mTest passed!\033[0m")

operatorspy/tests/relu.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,16 @@ def test_bang(lib, test_cases):
132132
test(lib, handle, "mlu", tensor_shape, tensor_dtype=torch.float32, inplace=inplace)
133133
destroy_handle(lib, handle)
134134

135+
def test_musa(lib, test_cases):
136+
import torch_musa
137+
138+
device = DeviceEnum.DEVICE_MUSA
139+
handle = create_handle(lib, device)
140+
for tensor_shape, inplace in test_cases:
141+
test(lib, handle, "musa", tensor_shape, tensor_dtype=torch.float16, inplace=inplace)
142+
test(lib, handle, "musa", tensor_shape, tensor_dtype=torch.float32, inplace=inplace)
143+
destroy_handle(lib, handle)
144+
135145

136146
if __name__ == "__main__":
137147
test_cases = [
@@ -172,6 +182,8 @@ def test_bang(lib, test_cases):
172182
test_cuda(lib, test_cases)
173183
if args.bang:
174184
test_bang(lib, test_cases)
175-
if not (args.cpu or args.cuda or args.bang):
185+
if args.musa:
186+
test_musa(lib, test_cases)
187+
if not (args.cpu or args.cuda or args.bang or args.musa):
176188
test_cpu(lib, test_cases)
177189
print("\033[92mTest passed!\033[0m")

operatorspy/tests/rms_norm.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,14 @@ def test_maca(lib, test_cases):
125125

126126
destroy_handle(lib, handle)
127127

128+
def test_musa(lib, test_cases):
129+
import torch_musa
130+
device = DeviceEnum.DEVICE_MUSA
131+
handle = create_handle(lib, device)
132+
for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases:
133+
test(lib, handle, "musa", y_shape, x_shape, w_shape, dtype, w_dtype)
134+
destroy_handle(lib, handle)
135+
128136
if __name__ == "__main__":
129137
test_cases = [
130138
# y_shape, x_shape, w_shape, dtype, w_dtype
@@ -174,6 +182,8 @@ def test_maca(lib, test_cases):
174182
test_ascend(lib, test_cases)
175183
if args.maca:
176184
test_maca(lib, test_cases)
177-
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
185+
if args.musa:
186+
test_musa(lib, test_cases)
187+
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa):
178188
test_cpu(lib, test_cases)
179189
print("\033[92mTest passed!\033[0m")

operatorspy/tests/rotary_embedding.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
7777
pos[2 * i] = posTmp[i]
7878
pos[2 * i + 1] = 0
7979
theta = 1e4
80-
if torch_device == 'mlu' or torch_device == 'npu':
80+
if torch_device == 'mlu' or torch_device == 'npu' or torch_device == 'musa':
8181
ans = rotary_embedding(t, posTmp, theta, "cpu").to(torch_device)
8282
pos = pos.to(torch_device)
8383
t = t.to(torch_device)
@@ -181,6 +181,14 @@ def test_maca(lib, test_cases) :
181181
test(lib, handle, "maca", shape, strides, dtype)
182182
destroy_handle(lib, handle)
183183

184+
def test_musa(lib, test_cases) :
185+
import torch_musa
186+
device = DeviceEnum.DEVICE_MUSA
187+
handle = create_handle(lib, device)
188+
for shape, strides, dtype in test_cases:
189+
test(lib, handle, "musa", shape, strides, dtype)
190+
destroy_handle(lib, handle)
191+
184192
if __name__ == "__main__":
185193
test_cases = [
186194
((1, 32, 128), None, torch.float16),
@@ -233,6 +241,8 @@ def test_maca(lib, test_cases) :
233241
test_ascend(lib, test_cases)
234242
if args.maca:
235243
test_maca(lib, test_cases)
236-
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
244+
if args.musa:
245+
test_musa(lib, test_cases)
246+
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca or args.musa):
237247
test_cpu(lib, test_cases)
238248
print("\033[92mTest passed!\033[0m")

operatorspy/tests/swiglu.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,20 @@ def test_maca(lib, test_cases):
262262

263263
destroy_handle(lib, handle)
264264

265+
def test_musa(lib, test_cases):
266+
import torch_musa
267+
device = DeviceEnum.DEVICE_MUSA
268+
handle = create_handle(lib, device)
269+
270+
for shape, a_stride, b_stride, c_stride, dtype in test_cases:
271+
test_out_of_place(
272+
lib, handle, "musa", shape, a_stride, b_stride, c_stride, dtype
273+
)
274+
test_in_place1(lib, handle, "musa", shape, a_stride, b_stride, dtype)
275+
test_in_place2(lib, handle, "musa", shape, a_stride, b_stride, dtype)
276+
277+
destroy_handle(lib, handle)
278+
265279

266280
if __name__ == "__main__":
267281
test_cases = [
@@ -307,4 +321,6 @@ def test_maca(lib, test_cases):
307321
test_ascend(lib, test_cases)
308322
if args.maca:
309323
test_maca(lib, test_cases)
324+
if args.musa:
325+
test_musa(lib, test_cases)
310326
print("\033[92mTest passed!\033[0m")

0 commit comments

Comments
 (0)