Closed
30 commits
7f34e00  Implement `thunder.executors.custom_op_ex._override_custom_op_forward` (crcrpar, Aug 28, 2025)
60f393f  add custom_op to expected all-executors (crcrpar, Aug 28, 2025)
203fc8c  add doc (crcrpar, Aug 28, 2025)
1f45978  not working right now (crcrpar, Aug 29, 2025)
454efb1  add custom_op_ex to docs (crcrpar, Aug 30, 2025)
89d7517  fix bsym name check (crcrpar, Sep 2, 2025)
72edd3e  nvfuser custom op def for overriding (crcrpar, Sep 2, 2025)
69aab26  define 2nd custom module (crcrpar, Sep 3, 2025)
94c71c5  nvfuser translator for custom op (crcrpar, Sep 8, 2025)
b7efacf  parametrize `disable_torch_autograd` (crcrpar, Sep 8, 2025)
cce6fc3  cast result to a dtype (crcrpar, Sep 8, 2025)
fb40e78  fix signature of custom_op mul (crcrpar, Sep 8, 2025)
d1b9bed  Register custom_op as `is_prim=True` (crcrpar, Sep 9, 2025)
dfe9c1a  add `_register_nvfuser_translator` to `__all__` (crcrpar, Sep 9, 2025)
724c855  docstring (crcrpar, Sep 9, 2025)
612215f  remove unused `_is_custom_op_symbol` (crcrpar, Sep 9, 2025)
7841212  expose `_register_nvfuser_translator` (crcrpar, Sep 9, 2025)
ff5f59a  minor update (crcrpar, Sep 9, 2025)
607e412  backward extrace (crcrpar, Sep 9, 2025)
7d5906b  fixture to de-register custom op (crcrpar, Sep 26, 2025)
94e5459  deregister `custom_op` in fixture (crcrpar, Sep 26, 2025)
a284a4f  remove `_override_custom_op_forward` in favor of `_register_nvfuser_t… (crcrpar, Sep 29, 2025)
7f73d0a  remove `custom_op_ex` from docs (crcrpar, Sep 29, 2025)
9843581  check nvfuer availability in deregistration util func (crcrpar, Oct 9, 2025)
59b6fe7  test nvfuser 5230 (crcrpar, Oct 12, 2025)
09fcd0f  Implement NVFP4 custom operations registration and quantization optio… (crcrpar, Oct 12, 2025)
ebc0de6  Refactor NVFP4 custom operations registration and quantization logic.… (crcrpar, Oct 12, 2025)
dfba25a  small fixes on the model (jjsjann123, Oct 22, 2025)
281dda3  Merge remote-tracking branch 'origin/crpa/try-nvfuer5230' into jj/try… (jjsjann123, Nov 12, 2025)
e865583  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 12, 2025)
5 changes: 5 additions & 0 deletions thunder/benchmarks/benchmark_inference.py
@@ -764,6 +764,11 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="[Experimental] Quantize nn.Linear to NVFP4. Note: nvfuser has not yet implemented nvfp4_matmul translator",
)
parser.add_argument(
"--quantize-linear",
action="store_true",
help="[Experimental] Quantize nn.Linear to NVFP4. Note: nvfuser has not yet implemented nvfp4_matmul translator",
)
parser.add_argument(
"--enable-nv-linear",
action="store_true",
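For orientation, a hypothetical invocation of the benchmark with the newly added flag might look like the following (the script path comes from the diff header above; any other required arguments are omitted and not implied by this diff):

    python thunder/benchmarks/benchmark_inference.py --quantize-linear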
27 changes: 14 additions & 13 deletions thunder/benchmarks/layers_for_inference_benchmark.py
@@ -256,7 +256,7 @@ def nvfuser_f16a_nvfp4weight_scaled_grouped_mm(
problem_sizes: torch.Tensor,
) -> torch.Tensor:
hp_weight = torch.empty(
(fp4_weight.size(0), fp4_weight.size(1), fp4_weight.size(2) * 2),
(fp4_weight.size(0), fp4_weight.size(1) * 2, fp4_weight.size(2)),
device=activation.device,
dtype=activation.dtype,
)
@@ -277,15 +277,15 @@ def _(
blockscale_offsets: torch.Tensor,
problem_sizes: torch.Tensor,
) -> torch.Tensor:
# fp4_weight shape: (groups, in_features, out_features // 2)
# fp4_weight shape: (groups, in_features // 2, out_features)
# Validate that activation has at least 1 dimension
if activation.ndim == 0:
raise ValueError(f"Expected activation to have at least 1 dimension, got {activation.ndim}")

# After unpacking: (groups, in_features, out_features)
# Output shape should match activation.shape[:-1] + (out_features,)
# This handles both 2D (tokens, hidden) and 3D (batch, seq_len, hidden) inputs
out_features = fp4_weight.size(2) * 2 # Unpack from FP4
out_features = fp4_weight.size(2)
output_shape = activation.shape[:-1] + (out_features,)
return torch.empty(output_shape, device=activation.device, dtype=activation.dtype)
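As a minimal sketch of the shape bookkeeping this hunk changes (sizes invented for illustration, with uint8 standing in for the packed FP4 storage dtype): the weight now packs two FP4 values per element along the in_features axis, so out_features is read directly from the last dimension instead of being recovered with a `* 2` factor.

    import torch

    groups, in_features, out_features = 4, 128, 256
    # uint8 is only a stand-in here for the packed FP4 dtype used by the PR.
    fp4_weight = torch.empty((groups, in_features // 2, out_features), dtype=torch.uint8)
    activation = torch.randn(8, 32, in_features, dtype=torch.bfloat16)

    out = fp4_weight.size(2)                        # 256; no "* 2" unpacking factor
    output_shape = activation.shape[:-1] + (out,)   # torch.Size([8, 32, 256])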

@@ -398,31 +398,32 @@ def quantize_grouped_linear_weight_to_nvfp4(
"""Quantize grouped linear's weight to nvfp4

Args:
weight: Parameter of `GroupedLinear` of [g, n, k]
weight: Parameter of `GroupedLinear` of [g, k, n]

Returns:
fp4_weight: [g, n, k // 2]
fp4_weight: [g, k // 2, n]
scale_factors: [g, n, k // 16]
global_scales: [g]
"""
assert weight.ndim == 3, "Weight must be a 3D tensor"

device: torch.device = weight.device
g, n, k = weight.size()
g, k, n = weight.size()

with device:
fp4_weight = torch.empty((g, n, k // 2), dtype=torch.float4_e2m1fn_x2)
global_scales = torch.empty((g,), dtype=torch.float32)
scale_factors = torch.empty((g, n, k // 16), dtype=torch.float8_e4m3fn)

weight = weight.transpose(-1, -2).contiguous()
for i in range(g):
cur_weight = weight[i]
global_scales[i] = cur_weight.abs().amax()
cur_fp4_weight, cur_scale_factors = pytorch_nvfp4_quantize(cur_weight, global_scales[i])
fp4_weight[i] = cur_fp4_weight
scale_factors[i] = linear_to_swizzled_128_4(cur_scale_factors)

return fp4_weight, scale_factors, global_scales
return fp4_weight.transpose(-1, -2), scale_factors, global_scales


class NVFP4InferenceGroupedLinear(nn.Module):
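To make the transpose round-trip in the quantization helper concrete, here is a small worked shape example (numbers invented for illustration; it only restates the shapes documented in the hunk above):

    import torch

    g, k, n = 2, 64, 16              # groups, in_features, out_features
    weight = torch.randn(g, k, n)    # GroupedLinear parameter in the new [g, k, n] layout

    # quantize_grouped_linear_weight_to_nvfp4(weight) allocates fp4_weight as
    # (g, n, k // 2) = (2, 16, 32), quantizes group by group on the transposed
    # (n, k) view of the weight, and returns the packed tensor transposed back:
    #   fp4_weight.transpose(-1, -2)  -> (g, k // 2, n)  = (2, 32, 16)
    #   scale_factors                 -> (g, n, k // 16) = (2, 16, 4)
    #   global_scales                 -> (g,)            = (2,)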
@@ -451,8 +452,8 @@ def compute_auxiliary_tensors(
problem_sizes = torch.stack(
[
tokens_per_group,
torch.full_like(tokens_per_group, hidden_states.size(1)),
torch.full_like(tokens_per_group, out_features),
torch.full_like(tokens_per_group, hidden_states.size(1)),
],
dim=1,
)
@@ -463,7 +464,7 @@ def compute_auxiliary_tensors(
torch.zeros(1, dtype=torch.int32, device=tokens_per_group.device),
torch.cumsum(rounded_tokens, 0, dtype=torch.int32),
]
)
)[0:-1]
return blockscale_offsets, problem_sizes
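A small illustration of the two auxiliary tensors after this change may help (numbers are invented, and the rounding applied to `rounded_tokens` is an assumption, since it is computed outside the visible hunk):

    import torch

    tokens_per_group = torch.tensor([5, 3], dtype=torch.int32)
    hidden, out_features = 128, 256

    # problem_sizes rows are now (tokens, out_features, hidden) per group:
    problem_sizes = torch.stack(
        [
            tokens_per_group,
            torch.full_like(tokens_per_group, out_features),
            torch.full_like(tokens_per_group, hidden),
        ],
        dim=1,
    )
    # tensor([[  5, 256, 128],
    #         [  3, 256, 128]], dtype=torch.int32)

    # blockscale_offsets: prefix the cumulative rounded token counts with 0 and
    # drop the trailing total via [0:-1].
    rounded_tokens = torch.tensor([128, 128], dtype=torch.int32)  # assumed block-aligned counts
    blockscale_offsets = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32),
            torch.cumsum(rounded_tokens, 0, dtype=torch.int32),
        ]
    )[0:-1]
    # tensor([  0, 128], dtype=torch.int32)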

# TODO: Update this accordingly to the progress of nvfp4 kernel implementation.
@@ -476,14 +477,14 @@ def forward(
) -> torch.Tensor:
if blockscale_offsets is None or problem_sizes is None:
# Compute them if not provided (backward compatibility)
out_features = self.fp4_weight.size(2) * 2
out_features = self.fp4_weight.size(2)
blockscale_offsets, problem_sizes = self.compute_auxiliary_tensors(hidden_states, offsets, out_features)
return torch.ops.nvf_cutlass.f16a_nvfp4weight_scaled_grouped_mm(
hidden_states,
self.fp4_weight,
self.weight_scaling_factor,
self.weight_global_scale,
offsets,
offsets[:-1],
blockscale_offsets,
problem_sizes,
)
@@ -529,7 +530,7 @@ def __init__(

def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
# Compute auxiliary tensors once for all three operations
intermediate_features = self.gate_proj.fp4_weight.size(2) * 2
intermediate_features = self.gate_proj.fp4_weight.size(2)
blockscale_offsets_gate, problem_sizes_gate = NVFP4InferenceGroupedLinear.compute_auxiliary_tensors(
hidden_states, offsets, intermediate_features
)
@@ -540,7 +541,7 @@ def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.T
intermediate = torch.nn.functional.silu(gate_out) * up_out

# For down_proj, we need different problem_sizes (different output features)
hidden_features = self.down_proj.fp4_weight.size(2) * 2
hidden_features = self.down_proj.fp4_weight.size(2)
blockscale_offsets_down, problem_sizes_down = NVFP4InferenceGroupedLinear.compute_auxiliary_tensors(
intermediate, offsets, hidden_features
)
1 change: 1 addition & 0 deletions thunder/torch/custom_op.py
@@ -607,3 +607,4 @@ def mul_translator(a, b, c=None, *, fd, lc_to_nv_map):
from thunder.executors.torchex import _always_executable

register_supported(symbol, translator_for_nvfuser, checker or _always_executable)
# register_supported(symbol.id, translator_for_nvfuser, checker or _always_executable)
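For readers unfamiliar with the registration above: the hunk header shows the translator signature this file uses (the op's arguments plus keyword-only `fd` and `lc_to_nv_map`). A rough, hypothetical sketch of such a translator body follows; the `fd.ops.mul` call and the plain dictionary lookups are assumptions based on typical nvFuser FusionDefinition usage, not something this diff shows.

    def mul_translator(a, b, c=None, *, fd, lc_to_nv_map):
        # Map Thunder proxies to their nvFuser counterparts; the exact lookup
        # helper used inside thunder may differ from plain dict indexing.
        nv_a = lc_to_nv_map[a]
        nv_b = lc_to_nv_map[b]
        # Emit the multiply on the nvFuser FusionDefinition and return its result.
        return fd.ops.mul(nv_a, nv_b)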