@@ -78,6 +78,7 @@ pub(crate) const ARGUMENT_BUFFER_WRAPPER_STRUCT: &str = "NagaArgumentBufferWrapp
7878/// allowing them to be conveniently passed to user-defined or wrapper
7979/// functions. The struct is declared in [`Writer::write_type_defs`].
8080pub ( crate ) const EXTERNAL_TEXTURE_WRAPPER_STRUCT : & str = "NagaExternalTextureWrapper" ;
81+ pub ( crate ) const COOPERATIVE_LOAD_FUNCTION : & str = "NagaCooperativeLoad" ;
8182pub ( crate ) const COOPERATIVE_MULTIPLY_ADD_FUNCTION : & str = "NagaCooperativeMultiplyAdd" ;
8283
8384/// Write the Metal name for a Naga numeric type: scalar, vector, or matrix.
@@ -484,6 +485,12 @@ enum WrappedFunction {
484485 ImageQuerySize {
485486 class : crate :: ImageClass ,
486487 } ,
488+ CooperativeLoad {
489+ space : crate :: AddressSpace ,
490+ columns : crate :: CooperativeSize ,
491+ rows : crate :: CooperativeSize ,
492+ scalar : crate :: Scalar ,
493+ } ,
487494 CooperativeMultiplyAdd {
488495 space : crate :: AddressSpace ,
489496 columns : crate :: CooperativeSize ,
@@ -2842,6 +2849,17 @@ impl<W: Write> Writer<W> {
28422849 }
28432850 write ! ( self . out, "}}" ) ?;
28442851 }
2852+ crate :: Expression :: CooperativeLoad { ref data, .. } => {
2853+ if context. lang_version < ( 2 , 3 ) {
2854+ return Err ( Error :: UnsupportedCooperativeMatrix ) ;
2855+ }
2856+ write ! ( self . out, "{COOPERATIVE_LOAD_FUNCTION}(" ) ?;
2857+ write ! ( self . out, "&" ) ?;
2858+ self . put_access_chain ( data. pointer , context. policies . index , context) ?;
2859+ write ! ( self . out, ", " ) ?;
2860+ self . put_expression ( data. stride , context, true ) ?;
2861+ write ! ( self . out, ", {})" , data. row_major) ?;
2862+ }
28452863 crate :: Expression :: CooperativeMultiplyAdd { a, b, c } => {
28462864 if context. lang_version < ( 2 , 3 ) {
28472865 return Err ( Error :: UnsupportedCooperativeMatrix ) ;
@@ -4235,25 +4253,18 @@ impl<W: Write> Writer<W> {
42354253 }
42364254 writeln ! ( self . out, ");" ) ?;
42374255 }
4238- crate :: Statement :: CooperativeLoadStore {
4239- store,
4240- target,
4241- pointer,
4242- stride,
4243- row_major,
4244- } => {
4245- let op_str = if store { "store" } else { "load" } ;
4246- write ! ( self . out, "{level}simdgroup_{op_str}(" ) ?;
4256+ crate :: Statement :: CooperativeStore { target, ref data } => {
4257+ write ! ( self . out, "{level}simdgroup_store(" ) ?;
42474258 self . put_expression ( target, & context. expression , true ) ?;
42484259 write ! ( self . out, ", &" ) ?;
42494260 self . put_access_chain (
4250- pointer,
4261+ data . pointer ,
42514262 context. expression . policies . index ,
42524263 & context. expression ,
42534264 ) ?;
42544265 write ! ( self . out, ", " ) ?;
4255- self . put_expression ( stride, & context. expression , true ) ?;
4256- if row_major {
4266+ self . put_expression ( data . stride , & context. expression , true ) ?;
4267+ if data . row_major {
42574268 let matrix_origin = "0" ;
42584269 let transpose = true ;
42594270 write ! ( self . out, ", {matrix_origin}, {transpose}" ) ?;
@@ -6316,6 +6327,55 @@ template <typename A>
63166327 Ok ( ( ) )
63176328 }
63186329
6330+ fn write_wrapped_cooperative_load (
6331+ & mut self ,
6332+ module : & crate :: Module ,
6333+ func_ctx : & back:: FunctionCtx ,
6334+ columns : crate :: CooperativeSize ,
6335+ rows : crate :: CooperativeSize ,
6336+ pointer : Handle < crate :: Expression > ,
6337+ ) -> BackendResult {
6338+ let ptr_ty = func_ctx. resolve_type ( pointer, & module. types ) ;
6339+ let space = ptr_ty. pointer_space ( ) . unwrap ( ) ;
6340+ let scalar = ptr_ty
6341+ . pointer_base_type ( )
6342+ . unwrap ( )
6343+ . inner_with ( & module. types )
6344+ . scalar ( )
6345+ . unwrap ( ) ;
6346+ let wrapped = WrappedFunction :: CooperativeLoad {
6347+ space,
6348+ columns,
6349+ rows,
6350+ scalar,
6351+ } ;
6352+ if !self . wrapped_functions . insert ( wrapped) {
6353+ return Ok ( ( ) ) ;
6354+ }
6355+ let space_name = space. to_msl_name ( ) . unwrap_or_default ( ) ;
6356+ let scalar_name = scalar. to_msl_name ( ) ;
6357+ writeln ! (
6358+ self . out,
6359+ "{NAMESPACE}::simdgroup_{scalar_name}{}x{} {COOPERATIVE_LOAD_FUNCTION}(const {space_name} {scalar_name}* ptr, int stride, bool is_row_major) {{" ,
6360+ columns as u32 , rows as u32 ,
6361+ ) ?;
6362+ let l1 = back:: Level ( 1 ) ;
6363+ writeln ! (
6364+ self . out,
6365+ "{l1}{NAMESPACE}::simdgroup_{scalar_name}{}x{} m;" ,
6366+ columns as u32 , rows as u32
6367+ ) ?;
6368+ let matrix_origin = "0" ;
6369+ writeln ! (
6370+ self . out,
6371+ "{l1}simdgroup_load(m, ptr, stride, {matrix_origin}, is_row_major);"
6372+ ) ?;
6373+ writeln ! ( self . out, "{l1}return m;" ) ?;
6374+ writeln ! ( self . out, "}}" ) ?;
6375+ writeln ! ( self . out) ?;
6376+ Ok ( ( ) )
6377+ }
6378+
63196379 fn write_wrapped_cooperative_multiply_add (
63206380 & mut self ,
63216381 module : & crate :: Module ,
@@ -6441,6 +6501,20 @@ template <typename A>
64416501 crate :: Expression :: ImageQuery { image, query } => {
64426502 self . write_wrapped_image_query ( module, func_ctx, image, query) ?;
64436503 }
6504+ crate :: Expression :: CooperativeLoad {
6505+ columns,
6506+ rows,
6507+ role : _,
6508+ ref data,
6509+ } => {
6510+ self . write_wrapped_cooperative_load (
6511+ module,
6512+ func_ctx,
6513+ columns,
6514+ rows,
6515+ data. pointer ,
6516+ ) ?;
6517+ }
64446518 crate :: Expression :: CooperativeMultiplyAdd { a, b, c : _ } => {
64456519 let space = crate :: AddressSpace :: Private ;
64466520 self . write_wrapped_cooperative_multiply_add ( module, func_ctx, space, a, b) ?;
0 commit comments