@@ -157,7 +157,11 @@ class last {
157157 using data_type_c = fp16;
158158};
159159
160- template <class Test >
160+ template <class Test , gpu_arch x, mma_engine y>
161+ class KernalName {
162+
163+ };
164+ template <class Test , gpu_arch x, mma_engine y>
161165void dequantize_gemm_run (uint32_t iter) {
162166 using namespace gpu ;
163167 // Accept incoming parameters
@@ -238,16 +242,16 @@ void dequantize_gemm_run(uint32_t iter) {
238242 data_type_scale,
239243 data_type_zero_pt,
240244 quant_info,
241- mma_engine::xmx ,
242- gpu_arch::XeHpg >;
245+ y ,
246+ x >;
243247 using gemm_t = xetla::group::
244248 gemm_t <compute_policy, tile_shape, mem_desc_a_t , mem_desc_b_t >;
245249
246250 using epilogue_t = xetla::group::epilogue_t <
247- xetla::group::epilogue_policy_default<gpu_arch::XeHpg >,
251+ xetla::group::epilogue_policy_default<x >,
248252 tile_shape,
249253 mem_desc_c_t >;
250- using group_swizzle = xetla::kernel::group_swizzle_default<gpu_arch::XeHpg >;
254+ using group_swizzle = xetla::kernel::group_swizzle_default<x >;
251255 using gemm_op_t = xetla::kernel::gemm_universal_t <
252256 gpu::xetla::kernel::dispatch_policy_int4_dequantize_kslicing<
253257 group_swizzle,
@@ -366,7 +370,7 @@ void dequantize_gemm_run(uint32_t iter) {
366370 for (uint32_t i = 0 ; i < iter; i++) {
367371 prof.cpu_start ();
368372 auto e_esimd = queue.submit ([&](handler& cgh) {
369- cgh.parallel_for <Test>(nd_range, [=](nd_item<3 > item) KERNEL_MAIN {
373+ cgh.parallel_for <KernalName< Test,x,y> >(nd_range, [=](nd_item<3 > item) KERNEL_MAIN {
370374 // allocate slm and nbarrier resource
371375 slm_barrier_init<gemm_op_t >();
372376 gemm_op_t gemm_op;
@@ -433,8 +437,94 @@ template <typename T>
433437class dequantize_gemm_test : public ::testing::Test {};
434438TYPED_TEST_SUITE_P (dequantize_gemm_test);
435439
440+ template <template <gpu_arch, mma_engine, class T > class F , class G >
441+ class dispatch_arch_test
442+ {
443+ using T_RET = std::invoke_result_t <decltype (F<gpu_arch::XeHpc, mma_engine::xmx, G>::exec)>;
444+
445+ public:
446+ template <typename ... Args>
447+ static T_RET exec (Args&&... args) {
448+ // save default formatting
449+ std::ios fmt_bak (nullptr );
450+ fmt_bak.copyfmt (std::cout);
451+
452+ sycl::device device;
453+ if (!device.has (aspect::ext_intel_device_id))
454+ throw std::runtime_error (" Can not get device ID" );
455+ auto deviceID = device.get_info <ext::intel::info::device::device_id>();
456+ std::cout << " deviceID: 0x" << std::hex //
457+ << std::right << std::setfill (' 0' ) << deviceID << " \n " ;
458+
459+ // restore default formatting
460+ std::cout.copyfmt (fmt_bak);
461+ #if defined(SYCL_EXT_ONEAPI_DEVICE_ARCHITECTURE) && \
462+ SYCL_EXT_ONEAPI_DEVICE_ARCHITECTURE
463+ // https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_device_architecture.asciidoc#feature-test-macro
464+ try {
465+ namespace ENS = sycl::ext::oneapi::experimental;
466+ auto deviceArch = device.get_info <ENS::info::device::architecture>();
467+ switch (deviceArch) {
468+ case ENS::architecture::intel_gpu_pvc:
469+ return F<gpu_arch::XeHpc, mma_engine::xmx, G>::exec (std::forward<Args>(args)...);
470+ return ;
471+ case ENS::architecture::intel_gpu_dg2_g10:
472+ case ENS::architecture::intel_gpu_dg2_g11:
473+ case ENS::architecture::intel_gpu_dg2_g12:
474+ return F<gpu_arch::XeHpg, mma_engine::xmx, G>::exec (std::forward<Args>(args)...);
475+ default :
476+ break ;
477+ }
478+ }
479+ catch (...) {
480+ std::cout << " Execption occurred! Please check one api versions." ;
481+ }
482+ #endif
483+ std::cout << " No matching architecture, checking device ID ...\n " ;
484+ switch (deviceID) {
485+ // MTL devices
486+ case 0x7d55 : // Intel® Arc ™ Graphics
487+ std::cout << " MTL devices identified!" << std::endl;
488+ return F<gpu_arch::XeLpg, mma_engine::fpu, G>::exec (std::forward<Args>(args)...);
489+ // DG2 devices
490+ case 0x56a0 : // Intel® Arc ™ A770 Graphics
491+ case 0x56a1 : // Intel® Arc ™ A750 Graphics
492+ case 0x56a2 : // Intel® Arc ™ A580 Graphics
493+ case 0x5690 : // Intel® Arc ™ A770M Graphics
494+ case 0x5691 : // Intel® Arc ™ A730M Graphics
495+ case 0x5692 : // Intel® Arc ™ A550M Graphics
496+ return F<gpu_arch::XeHpg, mma_engine::xmx, G>::exec (std::forward<Args>(args)...);
497+ // PVC devices
498+ case 0x0bda : //
499+ return F<gpu_arch::XeHpc, mma_engine::xmx, G>::exec (std::forward<Args>(args)...);
500+ default :
501+ std::cout << " Unknown device ID \n " ;
502+ break ;
503+ }
504+
505+ if (device.has (aspect::ext_intel_gpu_eu_simd_width))
506+ throw std::runtime_error (" Can not get eu_simd_width" );
507+ auto eu_simd_width =
508+ device.get_info <ext::intel::info::device::gpu_eu_simd_width>();
509+ if (eu_simd_width == 8 ) {
510+ return F<gpu_arch::XeHpg, mma_engine::xmx, G>::exec (std::forward<Args>(args)...);
511+ } else if (eu_simd_width == 16 ) {
512+ return F<gpu_arch::XeHpc, mma_engine::xmx, G>::exec (std::forward<Args>(args)...);
513+ } else {
514+ throw std::runtime_error (" Can not get device ID" );
515+ }
516+ }
517+ };
518+
519+ template <gpu_arch arch_tag, mma_engine engine_tag, typename T>
520+ struct main_wrapper {
521+ static constexpr auto exec = []() {
522+ dequantize_gemm_run<T, arch_tag, engine_tag>(ITER);
523+ };
524+ };
525+
436526TYPED_TEST_P (dequantize_gemm_test, esimd) {
437- dequantize_gemm_run< TypeParam>(ITER );
527+ dispatch_arch_test<main_wrapper, TypeParam>:: exec ( );
438528}
439529
440530REGISTER_TYPED_TEST_SUITE_P (dequantize_gemm_test, esimd);
0 commit comments