@@ -4,6 +4,7 @@ use std::cmp;
44use libc:: c_uint;
55use rustc_abi:: { BackendRepr , HasDataLayout , Primitive , Reg , RegKind , Size } ;
66use rustc_codegen_ssa:: MemFlags ;
7+ use rustc_codegen_ssa:: common:: TypeKind ;
78use rustc_codegen_ssa:: mir:: operand:: { OperandRef , OperandValue } ;
89use rustc_codegen_ssa:: mir:: place:: { PlaceRef , PlaceValue } ;
910use rustc_codegen_ssa:: traits:: * ;
@@ -331,20 +332,37 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
331332 let args =
332333 if self . c_variadic { & self . args [ ..self . fixed_count as usize ] } else { & self . args } ;
333334
335+ let adjust_ty = |ty| { // Returns `x86amx` in place of `ty` when the checks below match; otherwise returns `ty` unchanged.
336+ // TODO: make this heuristic more selective (help wanted). It currently only requires the C calling convention, no unwinding and no C-variadic arguments, on an x86_64 target, so it is a guess rather than a precise test (hence the `probably_` names).
337+ let probably_unadjusted = self . conv == Conv :: C && !self . can_unwind && !self . c_variadic ;
338+ let probably_amx_intrinsic = probably_unadjusted && cx. tcx . sess . target . arch == "x86_64" ;
339+ // For x86_64 AMX intrinsics, change the type from `i32x256` (a 256-element vector of 32-bit integers) to `x86amx`.
340+ if probably_amx_intrinsic
341+ && cx. type_kind ( ty) == TypeKind :: Vector
342+ && cx. vector_length ( ty) == 256
343+ {
344+ let element_ty = cx. element_type ( ty) ;
345+ if cx. type_kind ( element_ty) == TypeKind :: Integer && cx. int_width ( element_ty) == 32 {
346+ return cx. type_x86amx ( ) ;
347+ }
348+ }
349+ ty // Not a 256 x i32 vector (or the heuristic did not match): keep the type as-is.
350+ } ;
351+
334352 // This capacity calculation is approximate.
335353 let mut llargument_tys = Vec :: with_capacity (
336354 self . args . len ( ) + if let PassMode :: Indirect { .. } = self . ret . mode { 1 } else { 0 } ,
337355 ) ;
338356
339- let llreturn_ty = match & self . ret . mode {
357+ let llreturn_ty = adjust_ty ( match & self . ret . mode {
340358 PassMode :: Ignore => cx. type_void ( ) ,
341359 PassMode :: Direct ( _) | PassMode :: Pair ( ..) => self . ret . layout . immediate_llvm_type ( cx) ,
342360 PassMode :: Cast { cast, pad_i32 : _ } => cast. llvm_type ( cx) ,
343361 PassMode :: Indirect { .. } => {
344362 llargument_tys. push ( cx. type_ptr ( ) ) ;
345363 cx. type_void ( )
346364 }
347- } ;
365+ } ) ;
348366
349367 for arg in args {
350368 // Note that the exact number of arguments pushed here is carefully synchronized with
@@ -388,7 +406,7 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
388406 cast. llvm_type ( cx)
389407 }
390408 } ;
391- llargument_tys. push ( llarg_ty) ;
409+ llargument_tys. push ( adjust_ty ( llarg_ty) ) ;
392410 }
393411
394412 if self . c_variadic {
0 commit comments