1515from  .. import  exc 
1616from  .._compat  import  get_tensor_descriptor_fn_name 
1717from  .ast_extension  import  expr_from_string 
18+ from  .ast_extension  import  statement_from_string 
1819from  .compile_environment  import  CompileEnvironment 
1920from  .device_function  import  DeviceFunction 
2021from  .host_function  import  HostFunction 
@@ -353,7 +354,6 @@ def codegen_load(
353354            )
354355        assert  extra_mask  is  None 
355356        indexing  =  BlockedSubscriptIndexing .create (state , fake_tensor , subscript )
356- 
357357        # Load from tensor descriptor with permuted offsets 
358358        load_expr  =  expr_from_string (
359359            f"{ indexing .tensor_descriptor (state )}  .load({ indexing .offsets_str_permuted (state )}  )" 
@@ -383,10 +383,12 @@ def codegen_store(
383383            )
384384        assert  extra_mask  is  None 
385385        indexing  =  BlockedSubscriptIndexing .create (state , fake_tensor , subscript )
386+         store_value  =  indexing .reshape_store (state , value )
386387
388+         config  =  DeviceFunction .current ().config 
389+         epilogue_subtiles  =  state .config .epilogue_subtiling 
387390        # Apply permutation to the value being stored if needed 
388391        desc_arg  =  indexing .tensor_descriptor_arg (state )
389-         store_value  =  indexing .reshape_store (state , value )
390392
391393        if  desc_arg .permutation  is  not   None :
392394            # Apply permutation to the value 
@@ -395,11 +397,110 @@ def codegen_store(
395397                store_val = store_value ,
396398            )
397399
400+         if  (idx  :=  state .device_function .device_store_index ) <  len (epilogue_subtiles ):
401+             subtile_split  =  epilogue_subtiles [idx ]
402+             state .device_function .device_store_index  +=  1 
403+ 
404+             subtile_codegen  =  self ._codegen_epilogue_subtile_store (
405+                 state , fake_tensor , indexing , store_value , subtile_split , config 
406+             )
407+             if  subtile_codegen  is  not   None :
408+                 return  subtile_codegen 
409+ 
398410        return  expr_from_string (
399411            f"{ indexing .tensor_descriptor (state )}  .store({ indexing .offsets_str_permuted (state )}  , {{value}})" ,
400412            value = store_value ,
401413        )
402414
415+     def  _codegen_epilogue_subtile_store (
416+         self ,
417+         state : CodegenState ,
418+         fake_tensor : torch .Tensor ,
419+         indexing : BlockedSubscriptIndexing ,
420+         store_value : ast .AST ,
421+         subtile_split : int ,
422+         config : Config ,
423+     ) ->  ast .AST  |  None :
424+         # Currently support 2D tiles without permutations 
425+         if  (
426+             len (indexing .block_shape ) !=  2 
427+             or  len (indexing .offsets ) !=  2 
428+             or  subtile_split  ==  0 
429+         ):
430+             return  None 
431+ 
432+         env  =  CompileEnvironment .current ()
433+         block_m , block_n  =  indexing .block_shape 
434+         try :
435+             block_n_hint  =  env .size_hint (block_n )
436+             block_idx  =  env .get_block_id (block_n )
437+             block_size  =  env .block_sizes [block_idx ].from_config (config )
438+         except  Exception :
439+             return  None 
440+ 
441+         if  block_n_hint  %  2  !=  0  or  block_size  <=  16 :
442+             return  None 
443+ 
444+         device_fn  =  state .device_function 
445+         codegen  =  state .codegen 
446+ 
447+         block_m_str  =  device_fn .literal_expr (block_m )
448+         block_n_str  =  device_fn .literal_expr (block_n )
449+         indexing .block_shape [1 ] //=  subtile_split 
450+ 
451+         # TODO(PaulZhang12): Support more epilogue subtile configs besides 2 
452+         block_n_half_str  =  f"({ block_n_str }   // { subtile_split }  )" 
453+ 
454+         # Lift the store value into a temporary variable for reuse 
455+         acc_var  =  codegen .lift (store_value , prefix = "acc" )
456+ 
457+         reshape_expr  =  expr_from_string (
458+             "tl.reshape({acc}, [{dim_m}, 2, {dim_half}]).permute(0, 2, 1)" ,
459+             acc = acc_var ,
460+             dim_m = expr_from_string (block_m_str ),
461+             dim_half = expr_from_string (block_n_half_str ),
462+         )
463+         reshape_var  =  codegen .lift (reshape_expr , prefix = "acc" )
464+ 
465+         acc0_name  =  codegen .tmpvar (prefix = "acc" )
466+         acc1_name  =  codegen .tmpvar (prefix = "acc" )
467+         codegen .add_statement (
468+             statement_from_string (
469+                 f"{ acc0_name }  , { acc1_name }   = tl.split({{acc}})" ,
470+                 acc = reshape_var ,
471+             )
472+         )
473+         acc0  =  expr_from_string (acc0_name )
474+         acc1  =  expr_from_string (acc1_name )
475+ 
476+         desc_name  =  indexing .tensor_descriptor (state )
477+         offset0  =  expr_from_string (indexing .offsets [0 ])
478+         offset1  =  expr_from_string (indexing .offsets [1 ])
479+ 
480+         # First subtile store 
481+         codegen .add_statement (
482+             statement_from_string (
483+                 f"{ desc_name }  .store([{{off0}}, {{off1}}], {{value}})" ,
484+                 off0 = offset0 ,
485+                 off1 = offset1 ,
486+                 value = acc0 ,
487+             )
488+         )
489+ 
490+         offset1_shifted  =  expr_from_string (
491+             "({offset} + {half})" ,
492+             offset = expr_from_string (indexing .offsets [1 ]),
493+             half = expr_from_string (block_n_half_str ),
494+         )
495+ 
496+         # Emit second subtile store as the expression returned to the caller 
497+         return  expr_from_string (
498+             f"{ desc_name }  .store([{{off0}}, {{off1}}], {{value}})" ,
499+             off0 = offset0 ,
500+             off1 = offset1_shifted ,
501+             value = acc1 ,
502+         )
503+ 
403504
404505class  StackIndexingStrategy :
405506    """ 
0 commit comments