-
Notifications
You must be signed in to change notification settings - Fork 5
Open
Description
@triton.jit
def _conv2d_forward(input_ptr, weight_ptr, output_ptr, B, C, H, W, F, K,
                    stride, padding, dilation, H_out, W_out, input_batch_stride,
                    input_channel_stride, input_height_stride, input_width_stride,
                    weight_out_channel_stride, weight_kernel_stride, output_batch_stride,
                    output_channel_stride, output_height_stride, output_width_stride, M,
                    K_total, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                    BLOCK_SIZE_K: tl.constexpr):
    """Forward 2D convolution computed as an implicit GEMM.

    Each program instance computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of the
    output:

    * Row index ``rm`` enumerates output positions. It decomposes as
      ``batch_idx = rm // (H_out * W_out)``, with ``(h_out, w_out)`` taken
      within that output plane.
    * Column index ``rn`` enumerates output channels (valid while ``rn < F``).
    * The reduction index ``r`` runs over ``K_total``. It decomposes as
      ``c = r // (K * K)``, ``kh = (r % (K * K)) // K`` and ``kw = r % K``.

    Grid: program axis 0 tiles the M dimension and axis 1 tiles the F
    dimension. The launcher presumably uses
    ``(cdiv(M, BLOCK_SIZE_M), cdiv(F, BLOCK_SIZE_N))``; confirm this against
    the caller.

    Args:
        input_ptr: input activations, addressed as (batch, channel, row, col)
            through the four ``input_*_stride`` values.
        weight_ptr: filters, addressed as a 2D (F, K_total) matrix through
            ``weight_out_channel_stride`` and ``weight_kernel_stride``.
            The flattened reduction axis must be ordered (c, kh, kw) to match
            the decomposition above. A contiguous (F, C, K, K) tensor would
            satisfy this; that is an assumption, so confirm with the caller.
            The weight element type is also used as the ``tl.dot`` operand
            dtype.
        output_ptr: output, addressed as (batch, channel, row, col) through the
            four ``output_*_stride`` values.
        B, C: accepted but never read. Batch and channel indices are derived
            from ``M`` and ``K_total`` instead.
        H, W: input spatial size. Used to zero out taps that fall in the
            padding region.
        F: number of output channels.
        K: kernel size. The kernel is square (one ``K`` for both axes).
        stride, padding, dilation: scalar hyper-parameters shared by both
            spatial axes.
        H_out, W_out: output spatial size.
        M: number of output positions. Assumed to equal ``B * H_out * W_out``.
        K_total: reduction length. Assumed to equal ``C * K * K``.
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: tile sizes. ``tl.dot``
            typically needs each to be >= 16; the exact limit is
            backend-dependent.

    Accumulation is done in float32. ``tl.store`` converts the result to the
    output pointer's element type.

    NOTE(review): offset arithmetic runs in int32 for typical stride values.
    Tensors with more than 2**31 elements may need int64 offsets; confirm this
    before using the kernel on very large inputs.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Tile rows: flattened (batch, h_out, w_out) output positions.
    rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    # Tile columns: output channels.
    rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    m_mask = rm < M
    n_mask = rn < F

    # Split each row index into its batch index and output pixel coordinates.
    hw_out = H_out * W_out
    batch_idx = rm // hw_out
    spatial_idx = rm % hw_out
    h_out = spatial_idx // W_out
    w_out = spatial_idx % W_out

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K_total, BLOCK_SIZE_K)):
        rk = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
        k_mask = rk < K_total

        # Split the reduction index into (input channel, kernel row, kernel col).
        c = rk // (K * K)
        kh = rk % (K * K) // K
        kw = rk % K

        # Input pixel read by each (output position, kernel tap) pair.
        h_in = h_out[:, None] * stride + kh[None, :] * dilation - padding
        w_in = w_out[:, None] * stride + kw[None, :] * dilation - padding
        # Taps outside the input (the padding region) are masked off and read as 0.
        in_bounds = (h_in >= 0) & (h_in < H) & (w_in >= 0) & (w_in < W)
        input_mask = m_mask[:, None] & k_mask[None, :] & in_bounds
        input_offsets = (batch_idx[:, None] * input_batch_stride
                         + c[None, :] * input_channel_stride
                         + h_in * input_height_stride
                         + w_in * input_width_stride)
        input_block = tl.load(input_ptr + input_offsets, mask=input_mask,
                              other=0.0)

        # Weight tile of shape (BLOCK_SIZE_K, BLOCK_SIZE_N) from the (F, K_total) view.
        weight_offsets = (rn[None, :] * weight_out_channel_stride
                          + rk[:, None] * weight_kernel_stride)
        weight_block = tl.load(weight_ptr + weight_offsets,
                               mask=k_mask[:, None] & n_mask[None, :],
                               other=0.0)

        # tl.dot needs matching operand dtypes. Use the weight's own element
        # type instead of forcing float16: a hard fp16 cast truncates
        # float32/bfloat16 data and overflows values with |x| > 65504.
        # For fp16 tensors this is exactly the previous behaviour.
        acc += tl.dot(input_block.to(weight_ptr.dtype.element_ty),
                      weight_block)

    output_offsets = (batch_idx[:, None] * output_batch_stride
                      + rn[None, :] * output_channel_stride
                      + h_out[:, None] * output_height_stride
                      + w_out[:, None] * output_width_stride)
    tl.store(output_ptr + output_offsets, acc,
             mask=m_mask[:, None] & n_mask[None, :])
Describe the bug
A clear and concise description of what the bug is.
To Reproduce
Steps to reproduce the behavior:
- Go to '...'
- Click on '....'
- Scroll down to '....'
- See error
Expected behavior
A clear and concise description of what you expected to happen.
Screenshots
If applicable, add screenshots to help explain your problem.
Desktop (please complete the following information):
- OS: [e.g. iOS]
- Browser: [e.g. Chrome, Safari]
- Version: [e.g. 22]
Smartphone (please complete the following information):
- Device: [e.g. iPhone 6]
- OS: [e.g. iOS 8.1]
- Browser: [e.g. stock browser, Safari]
- Version: [e.g. 22]
Additional context
Add any other context about the problem here.
Reactions are currently unavailable
Metadata
Assignees
Labels
No labels