Commit ff11c14

mengluy0125 authored and meta-codesync[bot] committed
Add simplified se_block kernel (#989)
Summary:
Pull Request resolved: #989

We add a Helion kernel to compute 2 * x * sigmoid(x @ w).

Differential Revision: D84968671
1 parent fc79ea7 commit ff11c14

File tree

3 files changed: +499 -0 lines changed
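The kernels below hand-code the backward pass for out = 2 * x * sigmoid(x @ w), using the formulas grad_x = 2 * grad_out * s + (2 * grad_out * x * s * (1 - s)) @ w.T and grad_w = x.T @ (2 * grad_out * x * s * (1 - s)), where s = sigmoid(x @ w). As a quick sanity check of those formulas, here is a minimal PyTorch-only sketch (illustrative, not part of this commit; the tensor sizes are arbitrary) that compares them against torch.autograd:

import torch

m, n = 64, 32
x = torch.randn(m, n, dtype=torch.float64, requires_grad=True)
w = torch.randn(n, n, dtype=torch.float64, requires_grad=True)

s = torch.sigmoid(x @ w)
out = 2 * x * s
grad_out = torch.randn_like(out)
out.backward(grad_out)

with torch.no_grad():
    # g is the gradient flowing into the matmul x @ w
    g = 2 * grad_out * x * s * (1 - s)
    grad_x_ref = 2 * grad_out * s + g @ w.T
    grad_w_ref = x.T @ g

assert torch.allclose(x.grad, grad_x_ref)
assert torch.allclose(w.grad, grad_w_ref)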

examples/se_block.py

Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
"""
Helion SE Block Example
============================
This example demonstrates a Helion kernel implementation of a simplified SE block,
computing 2 * x * sigmoid(x @ w).
"""

# %%
from __future__ import annotations

import torch
from torch import Tensor

import helion
from helion._testing import run_example
import helion.language as hl

# %%
@helion.kernel(
    # static_shapes=True gives a performance boost for matmuls
    static_shapes=True,
)
def se_block_fwd(x: Tensor, w: Tensor) -> tuple[Tensor, Tensor]:
    """
    Performs 2 * x * sigmoid(x @ w)
    Args:
        x: 2D tensor of shape [m, n].
        w: 2D tensor of shape [n, n].
    Returns:
        out: Resulting matrix of shape [m, n].
        s: sigmoid(x @ w) of shape [m, n].
    """
    m, n = x.size()

    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    s = torch.empty([m, n], dtype=x.dtype, device=x.device)

    for tile_m in hl.tile(m):
        for tile_n in hl.tile(n):
            s[tile_m, tile_n] = torch.sigmoid(x[tile_m, :] @ w[:, tile_n])
            acc = 2.0 * x[tile_m, tile_n] * s[tile_m, tile_n]
            out[tile_m, tile_n] = acc

    return out, s

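# Illustrative usage sketch (editorial, not part of this commit): the forward kernel
# returns both the activation and the sigmoid, so the backward kernels can reuse s
# instead of recomputing x @ w. Shapes and dtypes below mirror check().
#
#   x = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
#   w = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
#   out, s = se_block_fwd(x, w)  # out == 2 * x * sigmoid(x @ w)
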
# %%
@helion.kernel(static_shapes=True)
def se_block_bwd_dx(
    grad_out: Tensor, x: Tensor, w: Tensor, s: Tensor
) -> Tensor:
    """
    Compute gradient for x.
    grad_x = 2 * grad_out * s + (2 * grad_out * x * s * (1 - s)) @ w.T

    Args:
        grad_out: Gradient w.r.t output [m, n]
        x: Input tensor [m, n]
        w: Weight matrix [n, n]
        s: sigmoid(x @ w) from forward pass [m, n]

    Returns:
        grad_x: Gradient w.r.t x [m, n]
    """
    m, n = x.size()

    grad_x = torch.empty([m, n], dtype=torch.float32, device=x.device)

    for tile_m, tile_n in hl.tile([m, n]):
        # 2 * grad_out * s
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        acc += 2.0 * grad_out[tile_m, tile_n] * s[tile_m, tile_n]

        for tile_k in hl.tile(n):
            # 2 * grad_out * x * s * (1-s) for tile_k
            grad_to_w = (
                2.0 * grad_out[tile_m, tile_k].to(torch.float32)
                * x[tile_m, tile_k].to(torch.float32)
                * s[tile_m, tile_k].to(torch.float32)
                * (1.0 - s[tile_m, tile_k].to(torch.float32))
            )
            # grad_to_w @ w.T[tile_k, tile_n] = grad_to_w @ w[tile_n, tile_k].T
            acc += grad_to_w @ w[tile_n, tile_k].to(torch.float32).T

        grad_x[tile_m, tile_n] = acc.to(x.dtype)

    return grad_x

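# Note (editorial, not part of the commit): the backward kernels upcast the
# lower-precision inputs (e.g. the fp16 tensors used in check()) with
# .to(torch.float32) and accumulate into an fp32 hl.zeros buffer, casting back
# with acc.to(x.dtype) only when storing the result:
#
#   acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)   # fp32 accumulator
#   acc += grad_to_w @ w[tile_n, tile_k].to(torch.float32).T
#   grad_x[tile_m, tile_n] = acc.to(x.dtype)                # cast back on store
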
# %%
@helion.kernel(static_shapes=True)
def se_block_bwd_dw(
    grad_out: Tensor, x: Tensor, s: Tensor
) -> Tensor:
    """
    Compute gradient for w.
    grad_w = x.T @ (2 * grad_out * x * s * (1 - s))

    Args:
        grad_out: Gradient w.r.t output [m, n]
        x: Input tensor [m, n]
        s: sigmoid(x @ w) from forward pass [m, n]

    Returns:
        grad_w: Gradient w.r.t w [n, n]
    """
    m, n = x.size()

    grad_w = torch.zeros([n, n], dtype=torch.float32, device=x.device)

    for tile_n1, tile_n2 in hl.tile([n, n]):
        acc_w = hl.zeros([tile_n1, tile_n2], dtype=torch.float32)
        for tile_m in hl.tile(m):
            # 2 * grad_out * x * s * (1-s)
            grad_to_w = (
                2.0 * grad_out[tile_m, tile_n2].to(torch.float32)
                * x[tile_m, tile_n2].to(torch.float32)
                * s[tile_m, tile_n2].to(torch.float32)
                * (1.0 - s[tile_m, tile_n2].to(torch.float32))
            )
            # x[tile_m, tile_n1].T @ grad_to_w[tile_m, tile_n2]
            acc_w += x[tile_m, tile_n1].to(torch.float32).T @ grad_to_w

        grad_w[tile_n1, tile_n2] = acc_w.to(x.dtype)

    return grad_w

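# Note (editorial, not part of the commit): grad_w is tiled over its [n, n] output,
# and m is the reduction dimension. For each (tile_n1, tile_n2) output tile the inner
# hl.tile(m) loop accumulates x[tile_m, tile_n1].T @ grad_to_w into acc_w, which is
# the per-tile form of grad_w = x.T @ (2 * grad_out * x * s * (1 - s)).
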
# %%
# Reference Implementation
# --------------------
def se_block_pytorch(
    x: torch.Tensor, w: torch.Tensor
) -> torch.Tensor:
    """
    PyTorch reference implementation of se_block.

    Args:
        x, w: Input tensors

    Returns:
        tensor of 2 * x * sigmoid(x @ w)
    """
    return 2 * x * torch.sigmoid(x @ w)


# %%
# Autograd Function
# ------------------
class SEBlockFunction(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        ctx: object,
        x: torch.Tensor,
        w: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass for se block."""
        out, s = se_block_fwd(x, w)
        ctx.save_for_backward(x, w, s)  # type: ignore[attr-defined]
        return out

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: object,
        grad_out: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Backward pass for se block."""
        x, w, s = ctx.saved_tensors  # type: ignore[attr-defined]

        grad_x = se_block_bwd_dx(grad_out, x, w, s)
        grad_w = se_block_bwd_dw(grad_out, x, s)

        return grad_x, grad_w


def se_block(
    x: torch.Tensor, w: torch.Tensor
) -> torch.Tensor:
    """
    SE Block with autograd support.

    Args:
        x: Input tensor [m, n]
        w: Weight matrix [n, n]

    Returns:
        Output tensor [m, n]
    """
    return SEBlockFunction.apply(x, w)  # type: ignore[no-any-return]


def check(m: int, n: int) -> None:
    """
    Checks the correctness against PyTorch.
    Args:
        m (int): Number of rows in matrix x.
        n (int): Number of columns in matrix x.
    """
    x = torch.randn([m, n], device="cuda", dtype=torch.float16, requires_grad=True)
    w = torch.randn([n, n], device="cuda", dtype=torch.float16, requires_grad=True)
    for bwd in [True, False]:
        run_example(se_block, se_block_pytorch, (x, w), bwd=bwd)


# %%
def main() -> None:
    """
    Main function to run correctness checks.
    """
    check(1024, 1024)


# %%
if __name__ == "__main__":
    main()
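For completeness, a short usage sketch of the autograd-enabled wrapper (illustrative, not part of the commit; the shapes and dtypes simply follow check()):

x = torch.randn(1024, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
w = torch.randn(1024, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
out = se_block(x, w)    # runs se_block_fwd
out.sum().backward()    # dispatches to se_block_bwd_dx and se_block_bwd_dw
print(x.grad.shape, w.grad.shape)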
