@@ -102,6 +102,25 @@ def codegen_store(
102102    ) ->  ast .AST :
103103        indexing  =  SubscriptIndexing .create (state , fake_tensor , subscript , extra_mask )
104104        name  =  state .device_function .tensor_arg (fake_tensor ).name 
105+         
106+         # Check if value is a tensor load (Name node with id matching a tensor arg) 
107+         if  isinstance (value , ast .Name ) and  hasattr (state .device_function , '_tensor_args' ):
108+             # Check if this name corresponds to a tensor argument 
109+             for  tensor , tensor_arg  in  state .device_function ._tensor_args .items ():
110+                 if  tensor_arg .name  ==  value .id :
111+                     # This is a tensor value, we need to load from it 
112+                     # Get the shape of the slice we're storing to 
113+                     output_shape  =  SubscriptIndexing .compute_shape (fake_tensor , subscript )
114+                     if  len (output_shape ) ==  1  and  tensor .ndim  ==  1 :
115+                         # Load the entire 1D tensor 
116+                         value_indexing  =  SubscriptIndexing .create (state , tensor , [slice (None )], None )
117+                         value  =  expr_from_string (
118+                             f"tl.load({ value .id }  ,
119+                             offset = value_indexing .index_expr ,
120+                             mask = value_indexing .mask_expr ,
121+                         )
122+                     break 
123+         
105124        return  expr_from_string (
106125            f"tl.store({ name }  ,
107126            value = value ,
@@ -511,7 +530,14 @@ def compute_shape(
511530                output_size .extend (k .size ())
512531            else :
513532                raise  exc .InvalidIndexingType (k )
514-         assert  len (input_size ) ==  0 , "invalid subscript" 
533+         # For partial indexing, append remaining dimensions to output 
534+         while  input_size :
535+             size  =  input_size .popleft ()
536+             if  size  !=  1 :
537+                 rdim  =  env .allocate_reduction_dimension (size )
538+                 output_size .append (rdim .var )
539+             else :
540+                 output_size .append (1 )
515541        return  output_size 
516542
517543    @staticmethod  
@@ -648,6 +674,22 @@ def create(
648674                        )
649675            else :
650676                raise  exc .InvalidIndexingType (type (k ))
677+         
678+         # Handle remaining dimensions for partial indexing 
679+         while  len (index_values ) <  fake_value .ndim :
680+             expand  =  tile_strategy .expand_str (output_size , output_idx )
681+             size  =  fake_value .size (len (index_values ))
682+             if  size  !=  1 :
683+                 rdim  =  env .allocate_reduction_dimension (size )
684+                 block_idx  =  rdim .block_id 
685+                 index_var  =  state .codegen .index_var (block_idx )
686+                 index_values .append (f"({ index_var } { expand }  )
687+                 if  mask  :=  state .codegen .mask_var (block_idx ):
688+                     mask_values .setdefault (f"({ mask } { expand }  )
689+             else :
690+                 index_values .append (f"tl.zeros([1], { dtype } { expand }  )
691+             output_idx  +=  1 
692+             
651693        assert  len (output_size ) ==  output_idx 
652694        assert  len (index_values ) ==  fake_value .ndim 
653695        index_expr  =  []
0 commit comments