@@ -94,7 +94,7 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
9494 if (torch_device == 'maca' ):
9595 indices = torch .zeros ([1 ], dtype = torch .int64 ).to ('cuda' )
9696 else :
97- indices = torch .zeros ([1 ], dtype = torch .uint64 ).to (torch_device )
97+ indices = torch .zeros ([1 ], dtype = torch .int64 ).to (torch_device )
9898 x_tensor = to_tensor (data , lib )
9999 indices_tensor = to_tensor (indices , lib )
100100 indices_tensor .descriptor .contents .dt = U64 # treat int64 as uint64
@@ -170,7 +170,7 @@ def test_ascend(lib, test_cases):
170170 for (voc , random_val , topp , topk , temperature ) in test_cases :
171171 test (lib , handle , "npu" , voc , random_val , topp , topk , temperature )
172172 destroy_handle (lib , handle )
173-
173+
174174def test_maca (lib , test_cases ):
175175 device = DeviceEnum .DEVICE_MACA
176176 handle = create_handle (lib , device )
@@ -179,6 +179,13 @@ def test_maca(lib, test_cases):
179179 destroy_handle (lib , handle )
180180
181181
182+ def test_musa (lib , test_cases ):
183+ import torch_musa
184+ device = DeviceEnum .DEVICE_MUSA
185+ handle = create_handle (lib , device )
186+ for (voc , random_val , topp , topk , temperature ) in test_cases :
187+ test (lib , handle , "musa" , voc , random_val , topp , topk , temperature )
188+ destroy_handle (lib , handle )
182189
183190if __name__ == "__main__" :
184191 test_cases = [
@@ -236,6 +243,8 @@ def test_maca(lib, test_cases):
236243 test_ascend (lib , test_cases )
237244 if args .maca :
238245 test_maca (lib , test_cases )
239- if not (args .cpu or args .cuda or args .bang or args .ascend or args .maca ):
246+ if args .musa :
247+ test_musa (lib , test_cases )
248+ if not (args .cpu or args .cuda or args .bang or args .ascend or args .maca or args .musa ):
240249 test_cpu (lib , test_cases )
241250 print ("\033 [92mTest passed!\033 [0m" )
0 commit comments