|  | 
|  | 1 | +""" | 
|  | 2 | +One-Shot All-Reduce Example | 
|  | 3 | +======================================== | 
|  | 4 | +This example demonstrates how to implement a one-shot pulling all-reduce operation | 
|  | 5 | +using Helion and PyTorch's distributed capabilities. It includes a Helion kernel | 
|  | 6 | +demonstrating how to do cross-device synchronization using symmetric memory signal pads | 
|  | 7 | +and access symmetric memory tensor resident on peer devices. | 
|  | 8 | +""" | 
|  | 9 | + | 
|  | 10 | +# %% | 
|  | 11 | +# Imports | 
|  | 12 | +# ------- | 
|  | 13 | +from __future__ import annotations | 
|  | 14 | + | 
|  | 15 | +import os | 
|  | 16 | + | 
|  | 17 | +import torch | 
|  | 18 | +import torch.distributed as dist | 
|  | 19 | +import torch.distributed._symmetric_memory as symm_mem | 
|  | 20 | +from torch.utils.cpp_extension import load_inline | 
|  | 21 | + | 
|  | 22 | +import helion | 
|  | 23 | +import helion.language as hl | 
|  | 24 | + | 
|  | 25 | +# %% | 
|  | 26 | +# Work around before symm mem natively supports extract dev_ptrs as tensors: from_blob | 
|  | 27 | +from_blob_cpp = """ | 
|  | 28 | +#include <cuda.h> | 
|  | 29 | +#include <cuda_runtime.h> | 
|  | 30 | +#include <iostream> | 
|  | 31 | +
 | 
|  | 32 | +
 | 
|  | 33 | +at::Tensor from_blob(uint64_t data_ptr, c10::IntArrayRef sizes, py::object dtype) { | 
|  | 34 | +
 | 
|  | 35 | +    at::Tensor tensor = at::for_blob((void*)data_ptr, sizes) | 
|  | 36 | +             .deleter([](void *ptr) { | 
|  | 37 | +               ; | 
|  | 38 | +             }) | 
|  | 39 | +             .options(at::device(at::kCUDA).dtype(((THPDtype*)dtype.ptr())->scalar_type)) | 
|  | 40 | +             .make_tensor(); | 
|  | 41 | +
 | 
|  | 42 | +    return tensor; | 
|  | 43 | +} | 
|  | 44 | +""" | 
|  | 45 | + | 
|  | 46 | +cpp_mod = load_inline( | 
|  | 47 | +    "cpp_mod", cpp_sources=from_blob_cpp, with_cuda=True, functions=["from_blob"] | 
|  | 48 | +) | 
|  | 49 | + | 
|  | 50 | + | 
|  | 51 | +def dev_array_to_tensor_short( | 
|  | 52 | +    dev_array_ptr: int, shape: tuple[int], dtype: torch.dtype, device: torch.device | 
|  | 53 | +) -> torch.Tensor: | 
|  | 54 | +    """ | 
|  | 55 | +    Convert a device array pointer to a PyTorch tensor. | 
|  | 56 | +
 | 
|  | 57 | +    This is a workaround function that creates a PyTorch tensor from a raw device pointer | 
|  | 58 | +    using the C++ extension. It's used to interface with symmetric memory device pointers | 
|  | 59 | +    before native support is available. | 
|  | 60 | +
 | 
|  | 61 | +    Args: | 
|  | 62 | +        dev_array_ptr: Raw device pointer as integer | 
|  | 63 | +        shape: Shape of the tensor to create | 
|  | 64 | +        dtype: PyTorch data type for the tensor | 
|  | 65 | +        device: Target device for the tensor | 
|  | 66 | +
 | 
|  | 67 | +    Returns: | 
|  | 68 | +        PyTorch tensor created from the device pointer | 
|  | 69 | +    """ | 
|  | 70 | +    return cpp_mod.from_blob(dev_array_ptr, shape, dtype)  # pyright: ignore[reportAttributeAccessIssue] | 
|  | 71 | + | 
|  | 72 | + | 
|  | 73 | +# %% | 
|  | 74 | +# One Shot All-Reduce Kernel Implementation | 
|  | 75 | +# ---------------------------------------- | 
|  | 76 | +@helion.jit( | 
|  | 77 | +    config=helion.Config( | 
|  | 78 | +        block_sizes=[8192], | 
|  | 79 | +        num_warps=32, | 
|  | 80 | +    ), | 
|  | 81 | +    static_shapes=True, | 
|  | 82 | +) | 
|  | 83 | +def one_shot_all_reduce_kernel( | 
|  | 84 | +    signal_pad_addrs: torch.Tensor, | 
|  | 85 | +    local_signal_pad: torch.Tensor, | 
|  | 86 | +    a_shared_tuple: tuple[torch.Tensor, ...], | 
|  | 87 | +    my_rank: hl.constexpr, | 
|  | 88 | +) -> torch.Tensor: | 
|  | 89 | +    """ | 
|  | 90 | +    Helion JIT-compiled kernel for one-shot all-reduce operation. | 
|  | 91 | +
 | 
|  | 92 | +    This kernel implements a distributed all-reduce using symmetric memory and signal pads | 
|  | 93 | +    for cross-device synchronization. It performs element-wise summation across all devices | 
|  | 94 | +    in the distributed group using tiled computation for memory efficiency. | 
|  | 95 | +
 | 
|  | 96 | +    Args: | 
|  | 97 | +        signal_pad_addrs: Tensor containing addresses of signal pads for all devices | 
|  | 98 | +        local_signal_pad: Local signal pad for synchronization | 
|  | 99 | +        a_shared_tuple: Tuple of shared tensors from all devices in the group | 
|  | 100 | +        my_rank: Current device's rank in the distributed group | 
|  | 101 | +
 | 
|  | 102 | +    Returns: | 
|  | 103 | +        Tensor containing the all-reduced result (sum across all devices) | 
|  | 104 | +    """ | 
|  | 105 | +    _, world_size = local_signal_pad.size() | 
|  | 106 | +    world_size = hl.specialize(world_size) | 
|  | 107 | +    out = torch.empty_like(a_shared_tuple[0]) | 
|  | 108 | +    N = out.size(0) | 
|  | 109 | + | 
|  | 110 | +    for tile_n in hl.tile(N): | 
|  | 111 | +        # Sync all devices through signal_pad to make sure | 
|  | 112 | +        # all previous writes to the shared tensor are visible | 
|  | 113 | +        ptr_tile = signal_pad_addrs[:] | 
|  | 114 | +        stack_signalpad = hl.stacktensor_like(local_signal_pad, ptr_tile) | 
|  | 115 | +        hl.signal( | 
|  | 116 | +            stack_signalpad, | 
|  | 117 | +            [tile_n.id, my_rank], | 
|  | 118 | +            signal=1, | 
|  | 119 | +            wait_for=0, | 
|  | 120 | +            scope="sys", | 
|  | 121 | +            hasPreviousMemAccess=False, | 
|  | 122 | +        ) | 
|  | 123 | + | 
|  | 124 | +        for world in hl.tile(world_size, block_size=world_size): | 
|  | 125 | +            hl.wait( | 
|  | 126 | +                local_signal_pad, | 
|  | 127 | +                [tile_n.id, world], | 
|  | 128 | +                signal=1, | 
|  | 129 | +                update=0, | 
|  | 130 | +                scope="sys", | 
|  | 131 | +            ) | 
|  | 132 | + | 
|  | 133 | +        acc = hl.zeros( | 
|  | 134 | +            [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device | 
|  | 135 | +        ) | 
|  | 136 | + | 
|  | 137 | +        for a in a_shared_tuple: | 
|  | 138 | +            acc += a[tile_n] | 
|  | 139 | + | 
|  | 140 | +        out[tile_n] = acc | 
|  | 141 | + | 
|  | 142 | +        # Sync all devices through signal_pad to make sure our writes to shared | 
|  | 143 | +        # tensor are visible to subsequent kernels. | 
|  | 144 | +        hl.signal( | 
|  | 145 | +            stack_signalpad, [tile_n.id, my_rank], signal=1, wait_for=0, scope="sys" | 
|  | 146 | +        ) | 
|  | 147 | + | 
|  | 148 | +        for world in hl.tile(world_size, block_size=world_size): | 
|  | 149 | +            hl.wait( | 
|  | 150 | +                local_signal_pad, | 
|  | 151 | +                [tile_n.id, world], | 
|  | 152 | +                signal=1, | 
|  | 153 | +                update=0, | 
|  | 154 | +                scope="sys", | 
|  | 155 | +                hasSubsequentMemAccess=False, | 
|  | 156 | +            ) | 
|  | 157 | +    return out | 
|  | 158 | + | 
|  | 159 | + | 
|  | 160 | +# %% | 
|  | 161 | +# Attract tensors from symmetric memory handler | 
|  | 162 | +# ---------------------------------------- | 
|  | 163 | +def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor: | 
|  | 164 | +    """ | 
|  | 165 | +    Prepares symmetric memory tensors for Helion one-shot all-reduce kernel. | 
|  | 166 | +    Tracks shared tensors as tuple of tensors, and/or dev_ptrs tensors. | 
|  | 167 | +
 | 
|  | 168 | +    Args: | 
|  | 169 | +        a_shared: Input tensor to be all-reduced across all devices | 
|  | 170 | +
 | 
|  | 171 | +    Returns: | 
|  | 172 | +        Tensor containing the all-reduced result (sum across all devices) | 
|  | 173 | +    """ | 
|  | 174 | +    assert dist.group.WORLD is not None | 
|  | 175 | + | 
|  | 176 | +    symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD) | 
|  | 177 | + | 
|  | 178 | +    a_shared_tuple = tuple( | 
|  | 179 | +        [ | 
|  | 180 | +            symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype) | 
|  | 181 | +            for i in range(symm_mem_hdl.world_size) | 
|  | 182 | +        ] | 
|  | 183 | +    ) | 
|  | 184 | + | 
|  | 185 | +    local_signal_pad = symm_mem_hdl.get_signal_pad( | 
|  | 186 | +        symm_mem_hdl.rank, dtype=torch.int32 | 
|  | 187 | +    ).view(-1, symm_mem_hdl.world_size) | 
|  | 188 | + | 
|  | 189 | +    signal_pad_addrs = dev_array_to_tensor_short( | 
|  | 190 | +        symm_mem_hdl.signal_pad_ptrs_dev, | 
|  | 191 | +        (symm_mem_hdl.world_size,), | 
|  | 192 | +        dtype=torch.uint64, | 
|  | 193 | +        device=a_shared.device, | 
|  | 194 | +    ) | 
|  | 195 | + | 
|  | 196 | +    return one_shot_all_reduce_kernel( | 
|  | 197 | +        signal_pad_addrs, | 
|  | 198 | +        local_signal_pad, | 
|  | 199 | +        a_shared_tuple, | 
|  | 200 | +        my_rank=symm_mem_hdl.rank, | 
|  | 201 | +    ) | 
|  | 202 | + | 
|  | 203 | + | 
|  | 204 | +# %% | 
|  | 205 | +# Testing Function | 
|  | 206 | +# ---------------------------------------- | 
|  | 207 | +def test(N: int, device: torch.device, dtype: torch.dtype) -> None: | 
|  | 208 | +    """ | 
|  | 209 | +    Test the Helion all-reduce implementation against PyTorch's reference implementation. | 
|  | 210 | +    Args: | 
|  | 211 | +        N: Total number of elements to test (will be divided by world_size per device) | 
|  | 212 | +        device: CUDA device to run the test on | 
|  | 213 | +        dtype: Data type for the test tensors | 
|  | 214 | +    """ | 
|  | 215 | +    dist_group = dist.group.WORLD | 
|  | 216 | +    assert dist_group is not None | 
|  | 217 | + | 
|  | 218 | +    world_size = dist.get_world_size() | 
|  | 219 | +    a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_() | 
|  | 220 | + | 
|  | 221 | +    a_shared_clone = symm_mem.empty( | 
|  | 222 | +        a_shared.shape, | 
|  | 223 | +        dtype=a_shared.dtype, | 
|  | 224 | +        device=a_shared.device, | 
|  | 225 | +    ) | 
|  | 226 | +    symm_mem.rendezvous(a_shared_clone, dist_group.group_name) | 
|  | 227 | +    a_shared_clone.copy_(a_shared) | 
|  | 228 | + | 
|  | 229 | +    a_out = helion_one_shot_all_reduce(a_shared) | 
|  | 230 | + | 
|  | 231 | +    gloden_o = torch.ops.symm_mem.one_shot_all_reduce( | 
|  | 232 | +        a_shared_clone, "sum", dist_group.group_name | 
|  | 233 | +    ) | 
|  | 234 | + | 
|  | 235 | +    torch.testing.assert_close(a_out, gloden_o, rtol=1e-1, atol=1e-1) | 
|  | 236 | + | 
|  | 237 | + | 
|  | 238 | +def main() -> None: | 
|  | 239 | +    """ | 
|  | 240 | +    Main entry point for the all-reduce example. | 
|  | 241 | +
 | 
|  | 242 | +    Sets up the distributed environment, initializes CUDA devices, and runs the | 
|  | 243 | +    all-reduce test, and then clean up. | 
|  | 244 | +    """ | 
|  | 245 | +    rank = int(os.environ["LOCAL_RANK"]) | 
|  | 246 | +    torch.manual_seed(42 + rank) | 
|  | 247 | +    device = torch.device(f"cuda:{rank}") | 
|  | 248 | +    torch.cuda.set_device(device) | 
|  | 249 | +    dist.init_process_group("nccl") | 
|  | 250 | +    test(16384, device, torch.bfloat16) | 
|  | 251 | + | 
|  | 252 | +    dist.destroy_process_group() | 
|  | 253 | + | 
|  | 254 | + | 
|  | 255 | +if __name__ == "__main__": | 
|  | 256 | +    """ | 
|  | 257 | +    Run with: | 
|  | 258 | +    torchrun \ | 
|  | 259 | +    --nnodes 1 --nproc-per-node 8 \ | 
|  | 260 | +    --rdzv-backend c10d --rdzv-endpoint localhost:0 \ | 
|  | 261 | +    --no_python python3 examples/all_reduce.py | 
|  | 262 | +    """ | 
|  | 263 | +    main() | 
0 commit comments