+from __future__ import annotations
+
+import torch
+import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
+import triton
+import triton.language as tl
+
+# Assumed project-local helper; `symm_mem_sync` is a device-side barrier across
+# all ranks' blocks that synchronizes via the symmetric-memory signal pads.
+import ptx_utils
+
+
+@triton.jit
+def one_shot_all_reduce_kernel(
+    buffer_ptr_addrs,
+    signal_pad_ptrs,
+    output_ptr,
+    numel: tl.constexpr,
+    rank: tl.constexpr,
+    world_size: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    # Wait until every rank has finished writing its input buffer before any
+    # block starts reading from peers (acquire ordering for the loads below).
+    ptx_utils.symm_mem_sync(
+        signal_pad_ptrs, None, rank, world_size, hasSubsequenceMemAccess=True
+    )
+
+    pid = tl.program_id(axis=0)
+    buffer_ptr_addrs = buffer_ptr_addrs.to(tl.pointer_type(tl.uint64))
+    output_ptr = output_ptr.to(tl.pointer_type(tl.bfloat16))
+    block_start = pid * BLOCK_SIZE
+
+    # Grid-stride loop: each program handles BLOCK_SIZE bf16 elements per step.
+    while block_start < numel:
+        # With the default config, each thread processes 128 bits (8 bf16 elements).
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < numel
+
+        # Sum this slice of every rank's input buffer.
+        acc = tl.zeros((BLOCK_SIZE,), dtype=tl.bfloat16)
+        for i in range(world_size):
+            buffer_ptr = tl.load(buffer_ptr_addrs + i).to(tl.pointer_type(tl.bfloat16))
+            buffer_ptr = tl.multiple_of(buffer_ptr, 16)  # alignment hint for vectorized loads
+            x = tl.load(buffer_ptr + offsets, mask=mask)
+            acc += x
+        tl.store(output_ptr + offsets, acc, mask=mask)
+        block_start += tl.num_programs(axis=0) * BLOCK_SIZE
+
+    # Barrier again before exiting so that no rank reuses or frees its buffer
+    # while peers may still be reading it (release ordering for the work above).
+    ptx_utils.symm_mem_sync(
+        signal_pad_ptrs, None, rank, world_size, hasPreviousMemAccess=True
+    )
+
+
+def one_shot_all_reduce(tensor: torch.Tensor, **kwargs) -> torch.Tensor:
+    config = {
+        "max_num_blocks": kwargs.get("max_num_blocks", 24),
+        "num_warps": kwargs.get("num_warps", 32),
+        "BLOCK_SIZE": kwargs.get("BLOCK_SIZE", 8192),
+    }
+
+    assert tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now."
+    assert tensor.numel() % 8 == 0, "The number of elements must be 128-bit aligned."
+    assert config["BLOCK_SIZE"] % (config["num_warps"] * 32) == 0, (
+        "BLOCK_SIZE must be a multiple of num_warps * 32"
+    )
+
+    # Cap the grid size; if the tensor needs more than max_num_blocks blocks,
+    # each block strides over multiple chunks inside the kernel.
+    num_blocks = min(
+        triton.cdiv(tensor.numel(), config["BLOCK_SIZE"]), config["max_num_blocks"]
+    )
+
+    # Establish symmetric memory for `tensor` across the default process group;
+    # the handle exposes per-rank buffer and signal-pad device pointers.
+    symm_mem_hdl = symm_mem.rendezvous(tensor, group=dist.group.WORLD)
+    output = torch.empty_like(tensor)
+
+    one_shot_all_reduce_kernel[(num_blocks, 1, 1)](
+        symm_mem_hdl.buffer_ptrs_dev,
+        symm_mem_hdl.signal_pad_ptrs_dev,
+        output,
+        numel=tensor.numel(),
+        rank=symm_mem_hdl.rank,
+        world_size=symm_mem_hdl.world_size,
+        BLOCK_SIZE=config["BLOCK_SIZE"],
+        num_warps=config["num_warps"],
+    )
+
+    return output
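
For reference, here is a minimal, hypothetical launch script (not part of the commit) showing how the host wrapper could be exercised under `torchrun`. It assumes a recent PyTorch build that exposes `symm_mem.empty` for allocating from the symmetric-memory pool; the module name `one_shot_all_reduce` and the size `msg_size` are illustrative only.

```python
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

from one_shot_all_reduce import one_shot_all_reduce  # module name assumed


def main() -> None:
    rank = int(os.environ["RANK"])
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")

    # The input must come from the symmetric-memory allocator so that
    # symm_mem.rendezvous() inside the wrapper can exchange buffer pointers.
    msg_size = 1 << 20  # 1M bf16 elements, satisfies the 128-bit alignment check
    inp = symm_mem.empty(msg_size, dtype=torch.bfloat16, device=device)
    inp.fill_(rank + 1)

    out = one_shot_all_reduce(inp)

    # Cross-check against the regular NCCL all-reduce.
    ref = inp.clone()
    dist.all_reduce(ref)
    torch.testing.assert_close(out, ref)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Run it with, e.g., `torchrun --nproc-per-node=8 launch.py`; passing a tensor that was not allocated through the symmetric-memory pool will make the rendezvous step fail.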