@@ -83,12 +83,18 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
8383 )
8484 data = torch .arange (voc ).float () * 0.0001
8585 _perm = torch .randperm (voc )
86- data = data [_perm ].to (x_dtype ).to (torch_device )
86+ if (torch_device == 'maca' ):
87+ data = data [_perm ].to (x_dtype ).to ('cuda' )
88+ else :
89+ data = data [_perm ].to (x_dtype ).to (torch_device )
8790 if (topp > 0 and topk > 1 ):
8891 ans = random_sample (data .to ("cpu" ), random_val , topp , topk , voc , temperature , "cpu" )
8992 else :
9093 ans = random_sample_0 (data )
91- indices = torch .zeros ([1 ], dtype = torch .int64 ).to (torch_device )
94+ if (torch_device == 'maca' ):
95+ indices = torch .zeros ([1 ], dtype = torch .int64 ).to ('cuda' )
96+ else :
97+ indices = torch .zeros ([1 ], dtype = torch .uint64 ).to (torch_device )
9298 x_tensor = to_tensor (data , lib )
9399 indices_tensor = to_tensor (indices , lib )
94100 indices_tensor .descriptor .contents .dt = U64 # treat int64 as uint64
@@ -163,7 +169,15 @@ def test_ascend(lib, test_cases):
163169 handle = create_handle (lib , device )
164170 for (voc , random_val , topp , topk , temperature ) in test_cases :
165171 test (lib , handle , "npu" , voc , random_val , topp , topk , temperature )
166- destroy_handle (lib , handle )
172+ destroy_handle (lib , handle )
173+
174+ def test_maca (lib , test_cases ):
175+ device = DeviceEnum .DEVICE_MACA
176+ handle = create_handle (lib , device )
177+ for (voc , random_val , topp , topk , temperature ) in test_cases :
178+ test (lib , handle , "maca" , voc , random_val , topp , topk , temperature )
179+ destroy_handle (lib , handle )
180+
167181
168182
169183if __name__ == "__main__" :
@@ -220,6 +234,8 @@ def test_ascend(lib, test_cases):
220234 test_bang (lib , test_cases )
221235 if args .ascend :
222236 test_ascend (lib , test_cases )
223- if not (args .cpu or args .cuda or args .bang or args .ascend ):
237+ if args .maca :
238+ test_maca (lib , test_cases )
239+ if not (args .cpu or args .cuda or args .bang or args .ascend or args .maca ):
224240 test_cpu (lib , test_cases )
225241 print ("\033 [92mTest passed!\033 [0m" )
0 commit comments