29 commits
3dc5c96
add simulated nvfp4 gemv example.
vickiw973 Oct 8, 2025
9b5aae3
modify the code
vickiw973 Oct 9, 2025
5726cbc
fix function failure for fp4 simulated gemv
vickiw973 Oct 10, 2025
232ab54
rename the folder
vickiw973 Oct 10, 2025
eb7cf9e
remove useless files
vickiw973 Oct 10, 2025
9a1d6c9
improve testing time.
vickiw973 Oct 10, 2025
2eaef11
remove useless file.
vickiw973 Oct 10, 2025
174ffdf
simplify nvfp4 gemv code
vickiw973 Oct 15, 2025
4d3bd27
add nvfp4 gemm code.
vickiw973 Oct 15, 2025
8410579
fix typo in comments.
vickiw973 Oct 16, 2025
82e0912
add dual gemm example
vickiw973 Oct 16, 2025
3492972
add group nvfp4 example
vickiw973 Oct 16, 2025
cf64255
move scale factor reorder operation to host.
vickiw973 Oct 21, 2025
d229c90
move scale factor initialization function to reference.
vickiw973 Oct 21, 2025
a9e20d4
simplify code
vickiw973 Oct 21, 2025
20c0bb0
remove useless files.
vickiw973 Oct 21, 2025
634a0b3
move some costs to host.
vickiw973 Oct 21, 2025
3ad7680
improve speed of light analysis.
vickiw973 Oct 23, 2025
0d7d037
improve comments.
vickiw973 Oct 23, 2025
1b76272
WIP: work on integrating with the platform
S1ro1 Nov 4, 2025
f0a784d
add local eval file.
vickiw973 Nov 4, 2025
724160c
tight the tolerance value
vickiw973 Nov 5, 2025
c4aecfe
optimize data convert function in reference.
vickiw973 Nov 6, 2025
e26912c
add more explanation about why we need a seperate compile_func.
vickiw973 Nov 6, 2025
73175a4
use cute tensor to do accumulation operation.
vickiw973 Nov 8, 2025
d6bbb97
clean codes.
vickiw973 Nov 8, 2025
0bb660a
improve comments.
vickiw973 Nov 8, 2025
329369c
improve comments.
vickiw973 Nov 8, 2025
7a8d0cc
fix compilation error.
vickiw973 Nov 8, 2025
489 changes: 489 additions & 0 deletions problems/nvidia/eval.py

Large diffs are not rendered by default.

194 changes: 194 additions & 0 deletions problems/nvidia/nvfp4_dual_gemm/reference.py
@@ -0,0 +1,194 @@
import torch
from task import input_t, output_t
from utils import make_match_reference

# Scaling factor vector size
sf_vec_size = 16

# Helper function for ceiling division
def ceil_div(a, b):
return (a + b - 1) // b

# Helper function to convert scale factor tensor to blocked format
def to_blocked(input_matrix):
rows, cols = input_matrix.shape

# rows and cols must be multiples of 128 and 4, respectively
n_row_blocks = ceil_div(rows, 128)
n_col_blocks = ceil_div(cols, 4)

padded = input_matrix
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

return rearranged.flatten()

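# Illustrative only (hypothetical shapes, not used by the reference below): for a
# 256 x 8 scale-factor matrix, to_blocked returns a flat tensor of 256 * 8 elements
# laid out tile by tile, where each 128 x 4 tile of the input occupies 512
# consecutive output elements (permuted within the tile into 32 x 16 chunks):
#
#   sf = torch.arange(256 * 8, dtype=torch.float32).reshape(256, 8)
#   blocked = to_blocked(sf)
#   assert blocked.shape == (256 * 8,)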

def ref_kernel(
data: input_t,
) -> output_t:
"""
PyTorch reference implementation of NVFP4 block-scaled dual GEMM with silu activation,
C = silu(A @ B1) * (A @ B2).
"""
a_ref, b1_ref, b2_ref, sfa_ref_cpu, sfb1_ref_cpu, sfb2_ref_cpu, _, _, _, c_ref = data

# Get dimensions from MxNxL layout
m, n, l = c_ref.shape

# Call torch._scaled_mm to compute each GEMM result
ref1 = torch.empty(
(l, m, n),
dtype=torch.float32,
device="cuda",
).permute(1, 2, 0)
ref2 = torch.empty(
(l, m, n),
dtype=torch.float32,
device="cuda",
).permute(1, 2, 0)
for l_idx in range(l):
# Convert the scale factor tensor to blocked format
scale_a = to_blocked(sfa_ref_cpu[:, :, l_idx])
scale_b1 = to_blocked(sfb1_ref_cpu[:, :, l_idx])
scale_b2 = to_blocked(sfb2_ref_cpu[:, :, l_idx])
# (m, k) @ (n, k).T -> (m, n)
res1 = torch._scaled_mm(
a_ref[:, :, l_idx],
b1_ref[:, :, l_idx].transpose(0, 1),
scale_a.cuda(),
scale_b1.cuda(),
bias=None,
out_dtype=torch.float32,
)
ref1[:, :, l_idx] = res1

res2 = torch._scaled_mm(
a_ref[:, :, l_idx],
b2_ref[:, :, l_idx].transpose(0, 1),
scale_a.cuda(),
scale_b2.cuda(),
bias=None,
out_dtype=torch.float32,
)
ref2[:, :, l_idx] = res2
# Do silu on the first GEMM result and multiply with the second GEMM result
c_ref = (torch.nn.functional.silu(ref1) * ref2).to(torch.float16)
return c_ref

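# Conceptually, ignoring the NVFP4 block scaling, the epilogue above is equivalent
# to the following dense sketch (illustrative only; x, w1, w2 are hypothetical
# fp16 tensors, not part of this reference):
#
#   y1 = x @ w1.transpose(0, 1)                        # first GEMM
#   y2 = x @ w2.transpose(0, 1)                        # second GEMM
#   out = (torch.nn.functional.silu(y1) * y2).half()   # SiLU-gated product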

def generate_input(
m: int,
n: int,
k: int,
l: int,
seed: int,
):
"""
Generate input tensors for NVFP4 block-scaled dual GEMM with silu activation,
C = silu(A @ B1) * (A @ B2).

Args:
m: Number of rows in matrix A
n: Number of rows in matrices B1 and B2 (columns of the output C)
k: Number of columns in A, B1, and B2 (the reduction dimension)
l: Batch size
seed: Random seed for reproducibility

Returns:
Tuple of (a, b1, b2, scale_a, scale_b1, scale_b2,
scale_a_permuted, scale_b1_permuted, scale_b2_permuted, c) where:
a: [m, k // 2, l] - Input matrix in torch.float4_e2m1fn_x2 data type (two FP4 values packed per element)
b1: [n, k // 2, l] - Input matrix in torch.float4_e2m1fn_x2 data type
b2: [n, k // 2, l] - Input matrix in torch.float4_e2m1fn_x2 data type
scale_a: [m, k // sf_vec_size, l] - Input scale factors in torch.float8_e4m3fn data type, on CPU
scale_b1: [n, k // sf_vec_size, l] - Input scale factors in torch.float8_e4m3fn data type, on CPU
scale_b2: [n, k // sf_vec_size, l] - Input scale factors in torch.float8_e4m3fn data type, on CPU
scale_a_permuted: [32, 4, rest_m, 4, rest_k, l] - Reordered scale factors in torch.float8_e4m3fn data type
scale_b1_permuted: [32, 4, rest_n, 4, rest_k, l] - Reordered scale factors in torch.float8_e4m3fn data type
scale_b2_permuted: [32, 4, rest_n, 4, rest_k, l] - Reordered scale factors in torch.float8_e4m3fn data type
c: [m, n, l] - Output matrix in torch.float16 data type
"""
torch.manual_seed(seed)

# Generate uint8 tensors, then view them as the float4_e2m1fn_x2 data type
a_ref = torch.randint(
0, 2, (l, m, k // 2), dtype=torch.uint8, device="cuda"
).permute(1, 2, 0)
b1_ref = torch.randint(
0, 2, (l, n, k // 2), dtype=torch.uint8, device="cuda"
).permute(1, 2, 0)
b2_ref = torch.randint(
0, 2, (l, n, k // 2), dtype=torch.uint8, device="cuda"
).permute(1, 2, 0)
a_ref = a_ref.view(torch.float4_e2m1fn_x2)
b1_ref = b1_ref.view(torch.float4_e2m1fn_x2)
b2_ref = b2_ref.view(torch.float4_e2m1fn_x2)

# Create float16 output tensor
c_ref = torch.randn((l, m, n), dtype=torch.float16, device="cuda").permute(
1, 2, 0
)

# Helper function to prepare the scale factor tensors for both the reference
# kernel and the custom kernel. The custom data layout is documented at:
# https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout
def create_scale_factor_tensors(l, mn, sf_k):
# Create the reference scale factor tensor with shape (mn, sf_k, l); it is moved to CPU before returning.
ref_shape = (l, mn, sf_k)
ref_permute_order = (1, 2, 0)
# Initialize with a random int8 tensor, then convert to float8_e4m3fn
ref_f8_random_int = torch.randint(1, 3, ref_shape, dtype=torch.int8, device='cuda')
ref_f8_torch_tensor = ref_f8_random_int.to(dtype=torch.float8_e4m3fn)
# permute to match ref_permute_order
ref_f8_torch_tensor_permuted = ref_f8_torch_tensor.permute(*ref_permute_order)

atom_m = (32, 4)
atom_k = 4
mma_shape = (
l, # batch size
ceil_div(mn, atom_m[0] * atom_m[1]),
ceil_div(sf_k, atom_k),
atom_m[0],
atom_m[1],
atom_k,
)

# Reorder the scale factor tensor to the (32, 4, rest_m, 4, rest_k, l) layout,
# which is required by the custom CuTe kernel
mma_permute_order = (3, 4, 1, 5, 2, 0)
# Generate a random int8 tensor, then convert to float8_e4m3fn
rand_int_tensor = torch.randint(0, 2, mma_shape, dtype=torch.int8, device='cuda')
reordered_f8_torch_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn)
# Permute according to mma_permute_order
reordered_f8_torch_tensor = reordered_f8_torch_tensor.permute(*mma_permute_order)

# GPU-side vectorized reordering (replaces slow CPU nested loops)
# Create index grids for all dimensions
i_idx = torch.arange(mn, device='cuda')
j_idx = torch.arange(sf_k, device='cuda')
b_idx = torch.arange(l, device='cuda')

# Create meshgrid for all combinations of (i, j, b)
i_grid, j_grid, b_grid = torch.meshgrid(i_idx, j_idx, b_idx, indexing='ij')

# Calculate target indices in vectorized manner
mm = i_grid // (atom_m[0] * atom_m[1])
mm32 = i_grid % atom_m[0]
mm4 = (i_grid % 128) // atom_m[0]
kk = j_grid // atom_k
kk4 = j_grid % atom_k

# Perform the reordering with advanced indexing (all on GPU)
reordered_f8_torch_tensor[mm32, mm4, mm, kk4, kk, b_grid] = ref_f8_torch_tensor_permuted[i_grid, j_grid, b_grid]
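# Worked example of the index mapping above (illustrative, assuming mn > 130 and
# sf_k > 5): the logical element (i=130, j=5, b=0) of the (mn, sf_k, l) tensor
# lands at
#   mm32 = 130 % 32 = 2, mm4 = (130 % 128) // 32 = 0, mm = 130 // 128 = 1,
#   kk4  = 5 % 4 = 1,    kk  = 5 // 4 = 1,
# i.e. at position (2, 0, 1, 1, 1, 0) of the reordered
# (32, 4, rest_m, 4, rest_k, l) tensor.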

return ref_f8_torch_tensor_permuted.cpu(), reordered_f8_torch_tensor

sf_k = ceil_div(k, sf_vec_size)
sfa_ref_cpu, sfa_ref_permuted = create_scale_factor_tensors(l, m, sf_k)
sfb1_ref_cpu, sfb1_ref_permuted = create_scale_factor_tensors(l, n, sf_k)
sfb2_ref_cpu, sfb2_ref_permuted = create_scale_factor_tensors(l, n, sf_k)

return (a_ref, b1_ref, b2_ref, sfa_ref_cpu, sfb1_ref_cpu, sfb2_ref_cpu, sfa_ref_permuted, sfb1_ref_permuted, sfb2_ref_permuted, c_ref)


check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03)
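

# Illustrative smoke test, not part of the evaluated reference. It assumes a GPU
# and a PyTorch build where torch._scaled_mm accepts torch.float4_e2m1fn_x2
# inputs; the problem sizes below are hypothetical.
if __name__ == "__main__":
    data = generate_input(m=256, n=256, k=256, l=1, seed=0)
    out = ref_kernel(data)
    # Expected: torch.Size([256, 256, 1]) torch.float16
    print(out.shape, out.dtype)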