|  | 
from __future__ import annotations

import functools
from typing import TYPE_CHECKING

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

from .experiment_util import BenchmarkOperator
from .experiment_util import ExperimentConfig

if TYPE_CHECKING:
    import argparse
|  | 16 | +BUILDTIN_SHAPES = [ | 
|  | 17 | +    4093, | 
|  | 18 | +    4096, | 
|  | 19 | +    5000, | 
|  | 20 | +    8192, | 
|  | 21 | +    8193, | 
|  | 22 | +    16384, | 
|  | 23 | +    16380, | 
|  | 24 | +    16387, | 
|  | 25 | +] | 
|  | 26 | +LARGE_K_SHAPES = [2**exp for exp in range(15, 21)] | 
|  | 27 | + | 
|  | 28 | + | 
|  | 29 | +class AllReduceBench(BenchmarkOperator): | 
|  | 30 | +    def gen_configs(self, args: argparse.Namespace) -> list[ExperimentConfig]: | 
|  | 31 | +        all_configs = [] | 
|  | 32 | +        for sz in args.shape: | 
|  | 33 | +            all_configs.append( | 
|  | 34 | +                ExperimentConfig( | 
|  | 35 | +                    shape=(sz,), | 
|  | 36 | +                    dtype=args.dtype, | 
|  | 37 | +                    backends=args.backend, | 
|  | 38 | +                    device=self.device, | 
|  | 39 | +                ) | 
|  | 40 | +            ) | 
|  | 41 | + | 
|  | 42 | +        return all_configs | 
|  | 43 | + | 
|  | 44 | +    def gen_inputs(self, config: ExperimentConfig) -> tuple: | 
|  | 45 | +        input_tensor = symm_mem.empty( | 
|  | 46 | +            config.shape, | 
|  | 47 | +            dtype=config.dtype, | 
|  | 48 | +            device=config.device, | 
|  | 49 | +        ) | 
|  | 50 | +        assert dist.group.WORLD is not None | 
|  | 51 | +        symm_mem.rendezvous(input_tensor, dist.group.WORLD.group_name) | 
|  | 52 | +        input_tensor = input_tensor.normal_() | 
|  | 53 | +        return (input_tensor,) | 
|  | 54 | + | 
|  | 55 | +    def additional_parser_args( | 
|  | 56 | +        self, parser: argparse.ArgumentParser | 
|  | 57 | +    ) -> argparse.ArgumentParser: | 
|  | 58 | +        parser.add_argument( | 
|  | 59 | +            "--shape", | 
|  | 60 | +            type=int, | 
|  | 61 | +            nargs="+", | 
|  | 62 | +            default=BUILDTIN_SHAPES + LARGE_K_SHAPES, | 
|  | 63 | +            help="Tensor lengths", | 
|  | 64 | +        ) | 
|  | 65 | +        return parser | 
|  | 66 | + | 
|  | 67 | +    def __init__(self) -> None: | 
|  | 68 | +        self.op_name = "allreduce" | 
|  | 69 | +        self.baseline = "nccl" | 
|  | 70 | +        super().__init__() | 
|  | 71 | + | 
|  | 72 | +        def nccl_ring(msg: torch.Tensor) -> torch.Tensor: | 
|  | 73 | +            dist.all_reduce(msg) | 
|  | 74 | +            return msg | 
|  | 75 | + | 
|  | 76 | +        assert dist.group.WORLD is not None | 
|  | 77 | + | 
|  | 78 | +        ALLREDUCE_DICT = { | 
|  | 79 | +            "multimem": functools.partial( | 
|  | 80 | +                torch.ops.symm_mem.multimem_all_reduce_, | 
|  | 81 | +                reduce_op="sum", | 
|  | 82 | +                group_name=dist.group.WORLD.group_name, | 
|  | 83 | +            ), | 
|  | 84 | +            "oneshot": functools.partial( | 
|  | 85 | +                torch.ops.symm_mem.one_shot_all_reduce, | 
|  | 86 | +                reduce_op="sum", | 
|  | 87 | +                group_name=dist.group.WORLD.group_name, | 
|  | 88 | +            ), | 
|  | 89 | +            "twoshot": functools.partial( | 
|  | 90 | +                torch.ops.symm_mem.two_shot_all_reduce_, | 
|  | 91 | +                reduce_op="sum", | 
|  | 92 | +                group_name=dist.group.WORLD.group_name, | 
|  | 93 | +            ), | 
|  | 94 | +            "nccl": nccl_ring, | 
|  | 95 | +            "helion_oneshot": ("examples.all_reduce", "helion_one_shot_all_reduce"), | 
|  | 96 | +            "kraken_oneshot": ("kraken.all_reduce", "one_shot_all_reduce"), | 
|  | 97 | +        } | 
|  | 98 | +        self.backend_dict = ALLREDUCE_DICT | 