@@ -26,6 +26,7 @@
     from ..runtime.config import Config
     from .device_function import TensorDescriptorArg
     from .inductor_lowering import CodegenState
+    from .tile_dispatch import TileStrategyDispatch
 
     SymIntLike = torch.SymInt | int
     ShapeLike = Sequence[SymIntLike]
@@ -61,6 +62,70 @@ def _normalize_negative_index(
     return f"({state.codegen.device_function.user_sympy_expr(sympy_expr)})"
 
 
+def _append_remaining_dimensions(
+    input_size: collections.deque,
+    output_size: list[int | torch.SymInt],
+    env: CompileEnvironment,
+) -> None:
+    """Append remaining dimensions from input to output for partial indexing.
+
+    Args:
+        input_size: Deque of remaining input dimensions
+        output_size: List to append output dimensions to
+        env: The compile environment
+    """
+    while input_size:
+        size = input_size.popleft()
+        if size != 1:
+            rdim = env.allocate_reduction_dimension(size)
+            output_size.append(rdim.var)
+        else:
+            output_size.append(1)
+
+
+def _handle_remaining_index_dimensions(
+    index_values: list[str],
+    mask_values: dict[str, None],
+    output_size: list[int | torch.SymInt],
+    output_idx: int,
+    fake_value: torch.Tensor,
+    state: CodegenState,
+    tile_strategy: TileStrategyDispatch,
+    env: CompileEnvironment,
+    dtype: str,
+) -> int:
+    """Handle remaining dimensions for partial indexing in SubscriptIndexing.create.
+
+    Args:
+        index_values: List to append index expressions to
+        mask_values: Dict to add mask expressions to
+        output_size: The output shape
+        output_idx: Current output index
+        fake_value: The tensor being indexed
+        state: The codegen state
+        tile_strategy: The tile strategy
+        env: The compile environment
+        dtype: The triton index type
+
+    Returns:
+        Updated output_idx
+    """
+    while len(index_values) < fake_value.ndim:
+        expand = tile_strategy.expand_str(output_size, output_idx)
+        size = fake_value.size(len(index_values))
+        if size != 1:
+            rdim = env.allocate_reduction_dimension(size)
+            block_idx = rdim.block_id
+            index_var = state.codegen.index_var(block_idx)
+            index_values.append(f"({index_var}){expand}")
+            if mask := state.codegen.mask_var(block_idx):
+                mask_values.setdefault(f"({mask}){expand}")
+        else:
+            index_values.append(f"tl.zeros([1], {dtype}){expand}")
+        output_idx += 1
+    return output_idx
+
+
 class IndexingStrategy:
     def codegen_load(
         self,
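
The two helpers above are the heart of the change: dimensions that a subscript leaves unconsumed are promoted into the output instead of failing an assertion. Below is a minimal standalone sketch of the shape rule, with the `CompileEnvironment` dependency stubbed out; the real `_append_remaining_dimensions` appends `rdim.var` (a symbolic reduction size) rather than the raw int.

```python
import collections

def append_remaining_dimensions(input_size, output_size):
    # Toy stand-in for _append_remaining_dimensions: leftover dims of the
    # partially indexed tensor flow through to the output shape. Size-1 dims
    # pass through as 1; other sizes would become reduction dimensions.
    while input_size:
        size = input_size.popleft()
        output_size.append(size if size != 1 else 1)

# x of shape (B, 1, 64) indexed as x[tile_b]: one index, three dims.
output = ["tile_b"]                # produced by the explicit tile index
append_remaining_dimensions(collections.deque([1, 64]), output)
print(output)                      # ['tile_b', 1, 64]
```
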
@@ -132,6 +197,32 @@ def codegen_store(
     ) -> ast.AST:
         indexing = SubscriptIndexing.create(state, fake_tensor, subscript, extra_mask)
         name = state.device_function.tensor_arg(fake_tensor).name
+
+        # Check if value is a tensor load (Name node with id matching a tensor arg)
+        if isinstance(value, ast.Name) and hasattr(
+            state.device_function, "_tensor_args"
+        ):
+            # Check if this name corresponds to a tensor argument
+            tensor = None
+            for t, tensor_arg in state.device_function._tensor_args.items():
+                if tensor_arg.name == value.id:
+                    tensor = t
+                    break
+
+            if tensor is not None:
+                # Get the shape of the slice we're storing to
+                output_shape = SubscriptIndexing.compute_shape(fake_tensor, subscript)
+                if len(output_shape) == 1 and tensor.ndim == 1:
+                    # Load the entire 1D tensor
+                    value_indexing = SubscriptIndexing.create(
+                        state, tensor, [slice(None)], None
+                    )
+                    value = expr_from_string(
+                        f"tl.load({value.id} + offset, mask)",
+                        offset=value_indexing.index_expr,
+                        mask=value_indexing.mask_expr,
+                    )
+
         return expr_from_string(
             f"tl.store({name} + offset, value, mask)",
             value=value,
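
This hunk rewrites a bare tensor name on the right-hand side of a 1D store into an explicit masked load before the store is emitted. The sketch below is a hypothetical mirror of Helion's `expr_from_string` template splicing (the real helper lives in Helion's codegen; all variable names here are illustrative), showing how the inner load nests into the outer store:

```python
import ast

def expr_from_string(template: str, **subs: ast.expr) -> ast.expr:
    # Parse the template expression, then splice pre-built AST nodes
    # in wherever a bare name matches a keyword argument.
    tree = ast.parse(template, mode="eval").body

    class Splice(ast.NodeTransformer):
        def visit_Name(self, node: ast.Name) -> ast.expr:
            return subs.get(node.id, node)

    return ast.fix_missing_locations(Splice().visit(tree))

# The inner load replaces the bare name, then feeds the outer store:
load = expr_from_string(
    "tl.load(src + offset, mask)",
    offset=ast.Name("src_index", ast.Load()),
    mask=ast.Name("src_mask", ast.Load()),
)
store = expr_from_string(
    "tl.store(dst + offset, value, mask)",
    offset=ast.Name("dst_index", ast.Load()),
    value=load,
    mask=ast.Name("dst_mask", ast.Load()),
)
print(ast.unparse(store))
# tl.store(dst + dst_index, tl.load(src + src_index, src_mask), dst_mask)
```
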
@@ -503,7 +594,9 @@ def compute_shape(
     ) -> list[int | torch.SymInt]:
         assert isinstance(tensor, torch.Tensor)
         assert isinstance(index, (list, tuple)), index
-        input_size = collections.deque(tensor.size())
+        input_size: collections.deque[int | torch.SymInt] = collections.deque(
+            tensor.size()
+        )
         output_size = []
         env = CompileEnvironment.current()
         for i, k in enumerate(index):
@@ -547,7 +640,8 @@ def compute_shape(
                 output_size.extend(k.size())
             else:
                 raise exc.InvalidIndexingType(k)
-        assert len(input_size) == 0, "invalid subscript"
+        # For partial indexing, append remaining dimensions to output
+        _append_remaining_dimensions(input_size, output_size, env)
         return output_size
 
     @staticmethod
@@ -675,6 +769,20 @@ def create(
                         )
             else:
                 raise exc.InvalidIndexingType(type(k))
+
+        # Handle remaining dimensions for partial indexing
+        output_idx = _handle_remaining_index_dimensions(
+            index_values,
+            mask_values,
+            output_size,
+            output_idx,
+            fake_value,
+            state,
+            tile_strategy,
+            env,
+            dtype,
+        )
+
         assert len(output_size) == output_idx
         assert len(index_values) == fake_value.ndim
         index_expr = []
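
At codegen time, each leftover dimension contributes one more index expression and, if masked, one mask expression, broadcast into position via `expand_str`. A toy rendering of what `_handle_remaining_index_dimensions` appends for a 2D tensor indexed by a single tile; the `indices_1`/`mask_1` names are hypothetical, patterned after the `index_var`/`mask_var` codegen variables:

```python
# output_size = [block_m, 64]: the tile dim plus one leftover dim of size 64.
index_values = ["(indices_0)[:, None]"]     # from the explicit tile index
mask_values = {"(mask_0)[:, None]": None}

expand = "[None, :]"                        # expand_str(output_size, 1)
index_values.append(f"(indices_1){expand}")
mask_values.setdefault(f"(mask_1){expand}")

print(index_values)       # ['(indices_0)[:, None]', '(indices_1)[None, :]']
print(list(mask_values))  # ['(mask_0)[:, None]', '(mask_1)[None, :]']
```
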
@@ -800,7 +908,9 @@ def is_supported(
         if extra_mask is not None:
             # TODO(jansel): support block_ptr with extra_mask
             return False
-        input_sizes = collections.deque(fake_tensor.size())
+        input_sizes: collections.deque[int | torch.SymInt] = collections.deque(
+            fake_tensor.size()
+        )
         for n, k in enumerate(index):
             if k is None:
                 input_size = 1
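
Taken together, the change lets kernels with shorter-than-ndim subscripts compile instead of tripping the old `invalid subscript` assertion. A sketch of the kind of kernel this enables, assuming Helion's `@helion.kernel`/`hl.tile` front end; whether a reduction over the leftover dimension composes like this in a given version is an assumption:

```python
import torch
import helion
import helion.language as hl

@helion.kernel
def row_sum(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty(m, dtype=x.dtype, device=x.device)
    for tile_m in hl.tile(m):
        # Partial index: x is 2D but only dim 0 is subscripted; the leftover
        # dim of size n becomes a reduction dimension and is summed away.
        out[tile_m] = x[tile_m].sum(-1)
    return out
```
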