|
| 1 | +""" |
| 2 | +Helion SE Block Example |
| 3 | +============================ |
| 4 | +This example demonstrates a Helion kernel implementation of SE Block. |
| 5 | +""" |
| 6 | + |
| 7 | +# %% |
| 8 | +from __future__ import annotations |
| 9 | + |
| 10 | +import torch |
| 11 | +from torch import Tensor |
| 12 | + |
| 13 | +import helion |
| 14 | +from helion._testing import DEVICE |
| 15 | +from helion._testing import run_example |
| 16 | +import helion.language as hl |
| 17 | + |
| 18 | + |
| 19 | +# %% |
| 20 | +@helion.kernel( |
| 21 | + # static_shapes=True gives a performance boost for matmuls |
| 22 | + static_shapes=True, |
| 23 | +) |
| 24 | +def se_block_fwd(x: Tensor, w: Tensor) -> tuple[Tensor, Tensor]: |
| 25 | + """ |
| 26 | + Performs 2 * x * sigmoid(x @ w) |
| 27 | + Args: |
| 28 | + x: 2D tensor of shape [m, n]. |
| 29 | + w: 2D tensor of shape [n, n]. |
| 30 | + Returns: |
| 31 | + out: Resulting matrix of shape [m, n]. |
| 32 | + s: sigmoid(x @ w) of shape [m, n]. |
| 33 | + """ |
| 34 | + m, n = x.size() |
| 35 | + |
| 36 | + out = torch.empty([m, n], dtype=x.dtype, device=x.device) |
| 37 | + s = torch.empty([m, n], dtype=x.dtype, device=x.device) |
| 38 | + |
| 39 | + for tile_m in hl.tile(m): |
| 40 | + for tile_n in hl.tile(n): |
| 41 | + s[tile_m, tile_n] = torch.sigmoid(x[tile_m, :] @ w[:, tile_n]) |
| 42 | + acc = 2.0 * x[tile_m, tile_n] * s[tile_m, tile_n] |
| 43 | + out[tile_m, tile_n] = acc |
| 44 | + |
| 45 | + return out, s |
| 46 | + |
| 47 | + |
| 48 | +# %% |
| 49 | +@helion.kernel(static_shapes=True) |
| 50 | +def se_block_bwd_dx(grad_out: Tensor, x: Tensor, w: Tensor, s: Tensor) -> Tensor: |
| 51 | + """ |
| 52 | + Compute gradient for x. |
| 53 | + grad_x = 2 * grad_out * s + (2 * grad_out * x * s * (1 - s)) @ w.T |
| 54 | +
|
| 55 | + Args: |
| 56 | + grad_out: Gradient w.r.t output [m, n] |
| 57 | + x: Input tensor [m, n] |
| 58 | + w: Weight matrix [n, n] |
| 59 | + s: sigmoid(x @ w) from forward pass [m, n] |
| 60 | +
|
| 61 | + Returns: |
| 62 | + grad_x: Gradient w.r.t x [m, n] |
| 63 | + """ |
| 64 | + m, n = x.size() |
| 65 | + |
| 66 | + grad_x = torch.empty([m, n], dtype=torch.float32, device=x.device) |
| 67 | + |
| 68 | + for tile_m, tile_n in hl.tile([m, n]): |
| 69 | + # 2 * grad_out * s |
| 70 | + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) |
| 71 | + acc += 2.0 * grad_out[tile_m, tile_n] * s[tile_m, tile_n] |
| 72 | + |
| 73 | + for tile_k in hl.tile(n): |
| 74 | + # 2 * grad_out * x * s * (1-s) for tile_k |
| 75 | + grad_to_w = ( |
| 76 | + 2.0 |
| 77 | + * grad_out[tile_m, tile_k].to(torch.float32) |
| 78 | + * x[tile_m, tile_k].to(torch.float32) |
| 79 | + * s[tile_m, tile_k].to(torch.float32) |
| 80 | + * (1.0 - s[tile_m, tile_k].to(torch.float32)) |
| 81 | + ) |
| 82 | + # grad_to_w @ w.T[tile_k, tile_n] = grad_to_w @ w[tile_n, tile_k].T |
| 83 | + acc += grad_to_w @ w[tile_n, tile_k].to(torch.float32).T |
| 84 | + |
| 85 | + grad_x[tile_m, tile_n] = acc.to(x.dtype) |
| 86 | + |
| 87 | + return grad_x |
| 88 | + |
| 89 | + |
| 90 | +# %% |
| 91 | +@helion.kernel(static_shapes=True) |
| 92 | +def se_block_bwd_dw(grad_out: Tensor, x: Tensor, s: Tensor) -> Tensor: |
| 93 | + """ |
| 94 | + Compute gradient for w. |
| 95 | + grad_w = x.T @ (2 * grad_out * x * s * (1 - s)) |
| 96 | +
|
| 97 | + Args: |
| 98 | + grad_out: Gradient w.r.t output [m, n] |
| 99 | + x: Input tensor [m, n] |
| 100 | + s: sigmoid(x @ w) from forward pass [m, n] |
| 101 | +
|
| 102 | + Returns: |
| 103 | + grad_w: Gradient w.r.t w [n, n] |
| 104 | + """ |
| 105 | + m, n = x.size() |
| 106 | + |
| 107 | + grad_w = torch.zeros([n, n], dtype=torch.float32, device=x.device) |
| 108 | + |
| 109 | + for tile_n1, tile_n2 in hl.tile([n, n]): |
| 110 | + acc_w = hl.zeros([tile_n1, tile_n2], dtype=torch.float32) |
| 111 | + for tile_m in hl.tile(m): |
| 112 | + # 2 * grad_out * x * s * (1-s) |
| 113 | + grad_to_w = ( |
| 114 | + 2.0 |
| 115 | + * grad_out[tile_m, tile_n2].to(torch.float32) |
| 116 | + * x[tile_m, tile_n2].to(torch.float32) |
| 117 | + * s[tile_m, tile_n2].to(torch.float32) |
| 118 | + * (1.0 - s[tile_m, tile_n2].to(torch.float32)) |
| 119 | + ) |
| 120 | + # x[tile_m, tile_n1].T @ grad_to_w[tile_m, tile_n2] |
| 121 | + acc_w += x[tile_m, tile_n1].to(torch.float32).T @ grad_to_w |
| 122 | + |
| 123 | + grad_w[tile_n1, tile_n2] = acc_w.to(x.dtype) |
| 124 | + |
| 125 | + return grad_w |
| 126 | + |
| 127 | + |
| 128 | +# %% |
| 129 | +# Reference Implementation |
| 130 | +# -------------------- |
| 131 | +def se_block_pytorch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: |
| 132 | + """ |
| 133 | + PyTorch reference implementation se_block. |
| 134 | +
|
| 135 | + Args: |
| 136 | + x, w: Input tensors |
| 137 | +
|
| 138 | + Returns: |
| 139 | + tensor of 2 * x * sigmoid(x @ w) |
| 140 | + """ |
| 141 | + return 2 * x * torch.sigmoid(x @ w) |
| 142 | + |
| 143 | + |
| 144 | +# %% |
| 145 | +# Autograd Function |
| 146 | +# ------------------ |
| 147 | +class SEBlockFunction(torch.autograd.Function): |
| 148 | + @staticmethod |
| 149 | + def forward( # type: ignore[override] |
| 150 | + ctx: object, |
| 151 | + x: torch.Tensor, |
| 152 | + w: torch.Tensor, |
| 153 | + ) -> torch.Tensor: |
| 154 | + """Forward pass for se block.""" |
| 155 | + out, s = se_block_fwd(x, w) |
| 156 | + ctx.save_for_backward(x, w, s) # type: ignore[attr-defined] |
| 157 | + return out |
| 158 | + |
| 159 | + @staticmethod |
| 160 | + def backward( # type: ignore[override] |
| 161 | + ctx: object, |
| 162 | + grad_out: torch.Tensor, |
| 163 | + ) -> tuple[torch.Tensor, torch.Tensor]: |
| 164 | + """Backward pass for se block.""" |
| 165 | + x, w, s = ctx.saved_tensors # type: ignore[attr-defined] |
| 166 | + |
| 167 | + grad_x = se_block_bwd_dx(grad_out, x, w, s) |
| 168 | + grad_w = se_block_bwd_dw(grad_out, x, s) |
| 169 | + |
| 170 | + return grad_x, grad_w |
| 171 | + |
| 172 | + |
| 173 | +def se_block(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: |
| 174 | + """ |
| 175 | + SE Block with autograd support. |
| 176 | +
|
| 177 | + Args: |
| 178 | + x: Input tensor [m, n] |
| 179 | + w: Weight matrix [n, n] |
| 180 | +
|
| 181 | + Returns: |
| 182 | + Output tensor [m, n] |
| 183 | + """ |
| 184 | + return SEBlockFunction.apply(x, w) # type: ignore[no-any-return] |
| 185 | + |
| 186 | + |
| 187 | +def check(m: int, n: int) -> None: |
| 188 | + """ |
| 189 | + Checks the correctness against PyTorch. |
| 190 | + Args: |
| 191 | + m (int): Number of rows in matrix x. |
| 192 | + n (int): Number of columns in matrix x. |
| 193 | + """ |
| 194 | + x = torch.randn([m, n], device=DEVICE, dtype=torch.float16, requires_grad=True) |
| 195 | + w = torch.randn([n, n], device=DEVICE, dtype=torch.float16, requires_grad=True) |
| 196 | + for bwd in [True, False]: |
| 197 | + run_example(se_block, se_block_pytorch, (x, w), bwd=bwd) |
| 198 | + |
| 199 | + |
| 200 | +# %% |
| 201 | +def main() -> None: |
| 202 | + """ |
| 203 | + Main function to run correctness checks. |
| 204 | + """ |
| 205 | + check(1024, 1024) |
| 206 | + |
| 207 | + |
| 208 | +# %% |
| 209 | +if __name__ == "__main__": |
| 210 | + main() |
0 commit comments