@@ -4,6 +4,7 @@ use std::cmp;
44use libc:: c_uint;
55use rustc_abi:: { BackendRepr , HasDataLayout , Primitive , Reg , RegKind , Size } ;
66use rustc_codegen_ssa:: MemFlags ;
7+ use rustc_codegen_ssa:: common:: TypeKind ;
78use rustc_codegen_ssa:: mir:: operand:: { OperandRef , OperandValue } ;
89use rustc_codegen_ssa:: mir:: place:: { PlaceRef , PlaceValue } ;
910use rustc_codegen_ssa:: traits:: * ;
@@ -308,7 +309,7 @@ impl<'ll, 'tcx> ArgAbiBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
308309}
309310
310311pub ( crate ) trait FnAbiLlvmExt < ' ll , ' tcx > {
311- fn llvm_type ( & self , cx : & CodegenCx < ' ll , ' tcx > ) -> & ' ll Type ;
312+ fn llvm_type ( & self , cx : & CodegenCx < ' ll , ' tcx > , name : & [ u8 ] ) -> & ' ll Type ;
312313 fn ptr_to_llvm_type ( & self , cx : & CodegenCx < ' ll , ' tcx > ) -> & ' ll Type ;
313314 fn llvm_cconv ( & self , cx : & CodegenCx < ' ll , ' tcx > ) -> llvm:: CallConv ;
314315
@@ -325,26 +326,45 @@ pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
325326}
326327
327328impl < ' ll , ' tcx > FnAbiLlvmExt < ' ll , ' tcx > for FnAbi < ' tcx , Ty < ' tcx > > {
328- fn llvm_type ( & self , cx : & CodegenCx < ' ll , ' tcx > ) -> & ' ll Type {
329+ fn llvm_type ( & self , cx : & CodegenCx < ' ll , ' tcx > , name : & [ u8 ] ) -> & ' ll Type {
329330 // Ignore "extra" args from the call site for C variadic functions.
330331 // Only the "fixed" args are part of the LLVM function signature.
331332 let args =
332333 if self . c_variadic { & self . args [ ..self . fixed_count as usize ] } else { & self . args } ;
333334
335+ // todo(sayantn): a better way is to look at the `link_name` instead of the function name, because function name can be "faked" using `#[export_name]`
336+ let llvm_intrinsic = name. starts_with ( b"llvm." )
337+ && !self . c_variadic
338+ && self . conv == Conv :: C
339+ && !self . can_unwind ;
340+ let amx_intrinsic =
341+ llvm_intrinsic && name. starts_with ( b"llvm.x86." ) && name. ends_with ( b".internal" ) ;
342+ let adjust_ty = |ty| {
343+ // Change type to `x86amx` from `i32x256` for x86_64 AMX intrinsics
344+ if amx_intrinsic && cx. type_kind ( ty) == TypeKind :: Vector && cx. vector_length ( ty) == 256
345+ {
346+ let element_ty = cx. element_type ( ty) ;
347+ if cx. type_kind ( element_ty) == TypeKind :: Integer && cx. int_width ( element_ty) == 32 {
348+ return cx. type_x86amx ( ) ;
349+ }
350+ }
351+ ty
352+ } ;
353+
334354 // This capacity calculation is approximate.
335355 let mut llargument_tys = Vec :: with_capacity (
336356 self . args . len ( ) + if let PassMode :: Indirect { .. } = self . ret . mode { 1 } else { 0 } ,
337357 ) ;
338358
339- let llreturn_ty = match & self . ret . mode {
359+ let llreturn_ty = adjust_ty ( match & self . ret . mode {
340360 PassMode :: Ignore => cx. type_void ( ) ,
341361 PassMode :: Direct ( _) | PassMode :: Pair ( ..) => self . ret . layout . immediate_llvm_type ( cx) ,
342362 PassMode :: Cast { cast, pad_i32 : _ } => cast. llvm_type ( cx) ,
343363 PassMode :: Indirect { .. } => {
344364 llargument_tys. push ( cx. type_ptr ( ) ) ;
345365 cx. type_void ( )
346366 }
347- } ;
367+ } ) ;
348368
349369 for arg in args {
350370 // Note that the exact number of arguments pushed here is carefully synchronized with
@@ -388,7 +408,7 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
388408 cast. llvm_type ( cx)
389409 }
390410 } ;
391- llargument_tys. push ( llarg_ty) ;
411+ llargument_tys. push ( adjust_ty ( llarg_ty) ) ;
392412 }
393413
394414 if self . c_variadic {
0 commit comments