1515from  .. import  exc 
1616from  .._compat  import  get_tensor_descriptor_fn_name 
1717from  .ast_extension  import  expr_from_string 
18+ from  .ast_extension  import  statement_from_string 
1819from  .compile_environment  import  CompileEnvironment 
1920from  .device_function  import  DeviceFunction 
2021from  .host_function  import  HostFunction 
@@ -353,7 +354,6 @@ def codegen_load(
353354            )
354355        assert  extra_mask  is  None 
355356        indexing  =  BlockedSubscriptIndexing .create (state , fake_tensor , subscript )
356- 
357357        # Load from tensor descriptor with permuted offsets 
358358        load_expr  =  expr_from_string (
359359            f"{ indexing .tensor_descriptor (state )} { indexing .offsets_str_permuted (state )}  
@@ -383,10 +383,12 @@ def codegen_store(
383383            )
384384        assert  extra_mask  is  None 
385385        indexing  =  BlockedSubscriptIndexing .create (state , fake_tensor , subscript )
386+         store_value  =  indexing .reshape_store (state , value )
386387
388+         config  =  DeviceFunction .current ().config 
389+         epilogue_subtiles  =  state .config .epilogue_subtiling 
387390        # Apply permutation to the value being stored if needed 
388391        desc_arg  =  indexing .tensor_descriptor_arg (state )
389-         store_value  =  indexing .reshape_store (state , value )
390392
391393        if  desc_arg .permutation  is  not None :
392394            # Apply permutation to the value 
@@ -395,11 +397,204 @@ def codegen_store(
395397                store_val = store_value ,
396398            )
397399
400+         if  (idx  :=  state .device_function .device_store_index ) <  len (epilogue_subtiles ):
401+             subtile_split  =  epilogue_subtiles [idx ]
402+             state .device_function .device_store_index  +=  1 
403+ 
404+             # Check if we should fuse a pointwise operation into the epilogue store 
405+             fused_pointwise_node  =  self ._get_fusable_pointwise_node (state )
406+ 
407+             subtile_codegen  =  self ._codegen_epilogue_subtile_store (
408+                 state ,
409+                 fake_tensor ,
410+                 indexing ,
411+                 store_value ,
412+                 subtile_split ,
413+                 config ,
414+                 fused_pointwise_node ,
415+             )
416+             if  subtile_codegen  is  not None :
417+                 return  subtile_codegen 
418+ 
398419        return  expr_from_string (
399420            f"{ indexing .tensor_descriptor (state )} { indexing .offsets_str_permuted (state )}  ,
400421            value = store_value ,
401422        )
402423
424+     def  _get_fusable_pointwise_node (self , state : CodegenState ) ->  torch .fx .Node  |  None :
425+         """Find a pointwise node feeding into this store that can be fused. 
426+ 
427+         Returns the pointwise FX node if found, None otherwise. 
428+         """ 
429+         if  state .fx_node  is  None :
430+             return  None 
431+ 
432+         # Get the value being stored (3rd argument to store) 
433+         if  len (state .fx_node .args ) <  3 :
434+             return  None 
435+ 
436+         value_node  =  state .fx_node .args [2 ]
437+         if  not  isinstance (value_node , torch .fx .Node ):
438+             return  None 
439+ 
440+         # Check if this is a pointwise node 
441+         from  .inductor_lowering  import  PointwiseLowering 
442+ 
443+         lowering  =  value_node .meta .get ("lowering" )
444+         if  not  isinstance (lowering , PointwiseLowering ):
445+             return  None 
446+ 
447+         # Check if this node only has one user (the store) 
448+         if  len (list (value_node .users )) !=  1 :
449+             return  None 
450+ 
451+         return  value_node 
452+ 
453+     def  _apply_pointwise_to_subtile (
454+         self , state : CodegenState , pointwise_node : torch .fx .Node , subtile_value : ast .AST 
455+     ) ->  ast .AST :
456+         """Apply a pointwise operation to a subtile value. 
457+ 
458+         Args: 
459+             state: The codegen state 
460+             pointwise_node: The FX node representing the pointwise operation 
461+             subtile_value: The AST for the subtile value to apply the operation to 
462+ 
463+         Returns: 
464+             AST for the result after applying the pointwise operation 
465+         """ 
466+         from  torch ._inductor  import  ir 
467+ 
468+         from  .inductor_lowering  import  PointwiseLowering 
469+         from  .inductor_lowering  import  install_inductor_kernel_handlers 
470+ 
471+         lowering  =  pointwise_node .meta ["lowering" ]
472+         assert  isinstance (lowering , PointwiseLowering )
473+ 
474+         # Get the pointwise buffer 
475+         buffer  =  lowering .buffer 
476+         assert  isinstance (buffer .data , ir .Pointwise )
477+ 
478+         # Create a temporary variable for the subtile 
479+         codegen  =  state .codegen 
480+         subtile_var  =  codegen .lift (subtile_value , prefix = "subtile" )
481+ 
482+         # Set up the inductor kernel handlers with the subtile as input 
483+         with  install_inductor_kernel_handlers (
484+             codegen , {lowering .input_names [0 ]: subtile_var }
485+         ):
486+             # Generate the pointwise operation 
487+             indices  =  [sympy .Symbol (f"i{ n }  ) for  n  in  range (len (buffer .data .ranges ))]
488+             from  .inductor_lowering  import  _unpack_opsvalue 
489+ 
490+             result_name  =  _unpack_opsvalue (buffer .data .inner_fn (indices ))
491+             return  expr_from_string (result_name )
492+ 
493+     def  _codegen_epilogue_subtile_store (
494+         self ,
495+         state : CodegenState ,
496+         fake_tensor : torch .Tensor ,
497+         indexing : BlockedSubscriptIndexing ,
498+         store_value : ast .AST ,
499+         subtile_split : int ,
500+         config : Config ,
501+         fused_pointwise_node : torch .fx .Node  |  None  =  None ,
502+     ) ->  ast .AST  |  None :
503+         # Currently support 2D tiles without permutations 
504+         if  (
505+             len (indexing .block_shape ) !=  2 
506+             or  len (indexing .offsets ) !=  2 
507+             or  subtile_split  ==  0 
508+         ):
509+             return  None 
510+ 
511+         env  =  CompileEnvironment .current ()
512+         block_m , block_n  =  indexing .block_shape 
513+         try :
514+             block_n_hint  =  env .size_hint (block_n )
515+             block_idx  =  env .get_block_id (block_n )
516+             block_size  =  env .block_sizes [block_idx ].from_config (config )
517+         except  Exception :
518+             return  None 
519+ 
520+         if  block_n_hint  %  2  !=  0  or  block_size  <=  16 :
521+             return  None 
522+ 
523+         device_fn  =  state .device_function 
524+         codegen  =  state .codegen 
525+ 
526+         block_m_str  =  device_fn .literal_expr (block_m )
527+         block_n_str  =  device_fn .literal_expr (block_n )
528+         indexing .block_shape [1 ] //=  subtile_split 
529+ 
530+         # TODO(PaulZhang12): Support more epilogue subtile configs besides 2 
531+         block_n_half_str  =  f"({ block_n_str } { subtile_split }  
532+ 
533+         # If we have a fused pointwise operation, mark it to skip normal codegen 
534+         # and get its input value instead 
535+         if  fused_pointwise_node  is  not None :
536+             fused_pointwise_node .meta ["fused_into_store" ] =  True 
537+ 
538+         # Lift the store value into a temporary variable for reuse 
539+         acc_var  =  codegen .lift (store_value , prefix = "acc" )
540+ 
541+         reshape_expr  =  expr_from_string (
542+             "tl.reshape({acc}, [{dim_m}, 2, {dim_half}]).permute(0, 2, 1)" ,
543+             acc = acc_var ,
544+             dim_m = expr_from_string (block_m_str ),
545+             dim_half = expr_from_string (block_n_half_str ),
546+         )
547+         reshape_var  =  codegen .lift (reshape_expr , prefix = "acc" )
548+ 
549+         acc0_name  =  codegen .tmpvar (prefix = "acc" )
550+         acc1_name  =  codegen .tmpvar (prefix = "acc" )
551+         codegen .add_statement (
552+             statement_from_string (
553+                 f"{ acc0_name } { acc1_name }  ,
554+                 acc = reshape_var ,
555+             )
556+         )
557+ 
558+         # Now apply the pointwise operation per-subtile if we have one 
559+         if  fused_pointwise_node  is  not None :
560+             acc0  =  self ._apply_pointwise_to_subtile (
561+                 state , fused_pointwise_node , expr_from_string (acc0_name )
562+             )
563+             acc1  =  self ._apply_pointwise_to_subtile (
564+                 state , fused_pointwise_node , expr_from_string (acc1_name )
565+             )
566+         else :
567+             acc0  =  expr_from_string (acc0_name )
568+             acc1  =  expr_from_string (acc1_name )
569+ 
570+         desc_name  =  indexing .tensor_descriptor (state )
571+         offset0  =  expr_from_string (indexing .offsets [0 ])
572+         offset1  =  expr_from_string (indexing .offsets [1 ])
573+ 
574+         # First subtile store 
575+         codegen .add_statement (
576+             statement_from_string (
577+                 f"{ desc_name }  ,
578+                 off0 = offset0 ,
579+                 off1 = offset1 ,
580+                 value = acc0 ,
581+             )
582+         )
583+ 
584+         offset1_shifted  =  expr_from_string (
585+             "({offset} + {half})" ,
586+             offset = expr_from_string (indexing .offsets [1 ]),
587+             half = expr_from_string (block_n_half_str ),
588+         )
589+ 
590+         # Emit second subtile store as the expression returned to the caller 
591+         return  expr_from_string (
592+             f"{ desc_name }  ,
593+             off0 = offset0 ,
594+             off1 = offset1_shifted ,
595+             value = acc1 ,
596+         )
597+ 
403598
404599class  StackIndexingStrategy :
405600    """ 
0 commit comments