|  | 
| 8 | 8 | 
 | 
| 9 | 9 | import sympy | 
| 10 | 10 | import torch | 
|  | 11 | +from torch._inductor.utils import triton_type | 
| 11 | 12 | 
 | 
| 12 | 13 | from .. import exc | 
| 13 | 14 | from .._compat import get_tensor_descriptor_fn_name | 
|  | 
| 19 | 20 | from .variable_origin import BlockSizeOrigin | 
| 20 | 21 | 
 | 
| 21 | 22 | if TYPE_CHECKING: | 
|  | 23 | +    from collections.abc import Sequence | 
|  | 24 | + | 
| 22 | 25 |     from ..runtime.config import Config | 
| 23 | 26 |     from .device_function import TensorDescriptorArg | 
| 24 | 27 |     from .inductor_lowering import CodegenState | 
| 25 | 28 | 
 | 
|  | 29 | +    SymIntLike = torch.SymInt | int | 
|  | 30 | +    ShapeLike = Sequence[SymIntLike] | 
|  | 31 | + | 
| 26 | 32 | 
 | 
| 27 | 33 | class IndexingStrategy: | 
| 28 | 34 |     def codegen_load( | 
| @@ -289,6 +295,153 @@ def codegen_store( | 
| 289 | 295 |         ) | 
| 290 | 296 | 
 | 
| 291 | 297 | 
 | 
class MulticastIndexingStrategy:
    """
    Generate pointer math for multicasting a load/store to several device
    memory pointers that share the same indexing.

    The offset and mask are computed once for the ``tensor_like`` template
    tensor and then broadcast against every entry of ``dev_ptrs``. In the
    generated expression the ``dev_ptrs`` dimensions come first, followed by
    the subscript dimensions.

    e.g. for a 1D offset tensor and a 1D dev_ptr array, the multicasted offset is:
    multicast_offset = dev_ptrs[:, None] + offset[None, :]
    """

    @staticmethod
    def get_broadcast_str(
        multicast_shape: ShapeLike,
        subscript_shape: ShapeLike,
    ) -> tuple[str, str]:
        """
        Build the subscript suffixes that align dev_ptrs with the tensor offset.

        Args:
            multicast_shape: shape of the dev_ptr tensor.
            subscript_shape: shape of the subscript result for each individual tensor.

        Returns:
            ``(multicast_broadcast, tensor_broadcast)``: the suffix to append to
            the dev_ptrs expression (``:`` for its own dims, ``None`` for the
            subscript dims) and the suffix to append to the per-tensor
            offset/mask expression (the mirror image). Both are empty strings
            when both shapes are empty, since ``x[]`` is not valid syntax.
        """
        multicast_rank = len(multicast_shape)
        subscript_rank = len(subscript_shape)
        if multicast_rank + subscript_rank == 0:
            # Nothing to broadcast; an empty "[]" would make the generated
            # expression unparsable.
            return "", ""

        multicast_keys = [":"] * multicast_rank + ["None"] * subscript_rank
        tensor_keys = ["None"] * multicast_rank + [":"] * subscript_rank
        return f"[{', '.join(multicast_keys)}]", f"[{', '.join(tensor_keys)}]"

    @staticmethod
    def get_mask_expr(
        state: CodegenState,
        indexing: SubscriptIndexing,
        multicast_shape: ShapeLike,
        subscript_shape: ShapeLike,
    ) -> ast.AST | None:
        """
        Combine the dev_ptrs mask and the per-tensor mask into one expression.

        Args:
            state: codegen state providing mask variables and tile-strategy
                string helpers.
            indexing: indexing computed for the template tensor; its mask (if
                any) is broadcast over the dev_ptrs dimensions.
            multicast_shape: shape of the dev_ptr tensor.
            subscript_shape: shape of the subscript result for each individual tensor.

        Returns:
            The ``&``-combination of all applicable masks, or ``None`` when
            neither the dev_ptrs dimensions nor the indexing need masking.
        """
        multicast_broadcast, tensor_broadcast = (
            MulticastIndexingStrategy.get_broadcast_str(
                multicast_shape, subscript_shape
            )
        )

        # One mask term per dev_ptrs dimension whose size maps to a block id
        # that has a mask variable.
        dev_ptr_mask_exprs = []
        for dim, size in enumerate(multicast_shape):
            block_id = CompileEnvironment.current().get_block_id(size)
            if block_id is None:
                continue
            mask_var = state.codegen.mask_var(block_id)
            if mask_var is None:
                continue
            expand = state.tile_strategy.expand_str(multicast_shape, dim)
            dev_ptr_mask_exprs.append(f"({mask_var}{expand})")

        mask_exprs = []
        if dev_ptr_mask_exprs:
            dev_ptr_mask_expr = f"({'&'.join(dev_ptr_mask_exprs)})"
            if len(dev_ptr_mask_exprs) < len(multicast_shape):
                # Not every dev_ptrs dimension contributed a mask, so the
                # combined mask is explicitly broadcast to the full dev_ptrs
                # shape before the subscript dims are appended.
                dev_ptr_mask_expr = f"tl.broadcast_to({dev_ptr_mask_expr}, {state.tile_strategy.shape_str(multicast_shape)})"
            dev_ptr_mask_expr = f"({dev_ptr_mask_expr}){multicast_broadcast}"
            mask_exprs.append(dev_ptr_mask_expr)

        if indexing.has_mask():
            # ``tensor_mask`` is a placeholder name substituted with the
            # indexing mask AST by expr_from_string.
            mask_exprs.append(f"(tensor_mask){tensor_broadcast}")
            return expr_from_string(
                "&".join(mask_exprs), tensor_mask=indexing.mask_expr
            )
        if mask_exprs:
            return expr_from_string("&".join(mask_exprs))
        return None

    @staticmethod
    def _prepare_access(
        state: CodegenState,
        multicast_tensor: tuple[torch.Tensor, torch.Tensor],
        subscript: list[object],
        extra_mask: ast.AST | None,
    ) -> tuple[SubscriptIndexing, ast.AST | None, str, str, str]:
        """
        Compute the pieces shared by codegen_load and codegen_store.

        Returns:
            ``(indexing, mask_expr, multicast_broadcast, tensor_broadcast, dtype)``
            where ``mask_expr`` may be ``None`` and ``dtype`` is the string
            returned by ``triton_type`` for the template tensor's dtype.
        """
        tensor_like, dev_ptrs = multicast_tensor
        indexing = SubscriptIndexing.create(state, tensor_like, subscript, extra_mask)
        subscript_shape = SubscriptIndexing.compute_shape(tensor_like, subscript)
        multicast_shape = [*dev_ptrs.size()]

        mask_expr = MulticastIndexingStrategy.get_mask_expr(
            state, indexing, multicast_shape, subscript_shape
        )
        multicast_broadcast, tensor_broadcast = (
            MulticastIndexingStrategy.get_broadcast_str(
                multicast_shape, subscript_shape
            )
        )
        dtype = triton_type(tensor_like.dtype)
        return indexing, mask_expr, multicast_broadcast, tensor_broadcast, dtype

    @staticmethod
    def codegen_load(
        state: CodegenState,
        multicast_tensor: tuple[torch.Tensor, torch.Tensor],
        dev_ptrs_ast: ast.AST,
        subscript: list[object],
        extra_mask: ast.AST | None,
    ) -> ast.AST:
        """
        Emit a ``tl.load`` that reads the same subscript from every dev_ptr.

        Args:
            state: current codegen state.
            multicast_tensor: ``(tensor_like, dev_ptrs)``; ``tensor_like``
                supplies the indexing and dtype, ``dev_ptrs`` supplies the
                multicast shape.
            dev_ptrs_ast: AST for the dev_ptrs value; it is cast to a pointer
                of ``tensor_like``'s dtype in the generated code.
            subscript: index applied to each individual tensor.
            extra_mask: optional additional mask forwarded to
                ``SubscriptIndexing.create``.

        Returns:
            AST of the ``tl.load`` call. ``other=0`` is passed only when a
            mask is present.
        """
        indexing, mask_expr, multicast_broadcast, tensor_broadcast, dtype = (
            MulticastIndexingStrategy._prepare_access(
                state, multicast_tensor, subscript, extra_mask
            )
        )

        extra = ", other=0"
        if mask_expr is None:
            mask_expr = expr_from_string("None")
            extra = ""

        return expr_from_string(
            f"tl.load((base.to(tl.pointer_type({dtype}))){multicast_broadcast} + (offset){tensor_broadcast}, mask{extra})",
            base=dev_ptrs_ast,
            offset=indexing.index_expr,
            mask=mask_expr,
        )

    @staticmethod
    def codegen_store(
        state: CodegenState,
        multicast_tensor: tuple[torch.Tensor, torch.Tensor],
        dev_ptrs_ast: ast.AST,
        subscript: list[object],
        value: ast.AST,
        extra_mask: ast.AST | None,
    ) -> ast.AST:
        """
        Emit a ``tl.store`` that writes ``value`` at the same subscript of every dev_ptr.

        Args:
            state: current codegen state.
            multicast_tensor: ``(tensor_like, dev_ptrs)``; ``tensor_like``
                supplies the indexing and dtype, ``dev_ptrs`` supplies the
                multicast shape.
            dev_ptrs_ast: AST for the dev_ptrs value; it is cast to a pointer
                of ``tensor_like``'s dtype in the generated code.
            subscript: index applied to each individual tensor.
            value: AST of the value to store.
            extra_mask: optional additional mask forwarded to
                ``SubscriptIndexing.create``.

        Returns:
            AST of the ``tl.store`` call; the mask argument is ``None`` when
            no masking is required.
        """
        indexing, mask_expr, multicast_broadcast, tensor_broadcast, dtype = (
            MulticastIndexingStrategy._prepare_access(
                state, multicast_tensor, subscript, extra_mask
            )
        )

        if mask_expr is None:
            mask_expr = expr_from_string("None")

        return expr_from_string(
            f"tl.store(base.to(tl.pointer_type({dtype})){multicast_broadcast} + (offset){tensor_broadcast}, value, mask)",
            base=dev_ptrs_ast,
            value=value,
            offset=indexing.index_expr,
            mask=mask_expr,
        )
|  | 443 | + | 
|  | 444 | + | 
| 292 | 445 | class SubscriptIndexing(NamedTuple): | 
| 293 | 446 |     index_expr: ast.AST | 
| 294 | 447 |     mask_expr: ast.AST | 
|  | 
0 commit comments