4 changes: 4 additions & 0 deletions thunder/executors/torchex.py
@@ -1649,6 +1649,8 @@ def max_pool_with_indices_backward_meta(
nll_loss = _register_torch_operation("nll_loss", module=torch.nn.functional)
pad = _register_torch_operation("pad", module=torch.nn.functional)
scaled_dot_product_attention = _register_torch_operation("scaled_dot_product_attention", module=torch.nn.functional)
if hasattr(torch.nn.functional, "scaled_mm"):
    scaled_mm = _register_torch_operation("scaled_mm", module=torch.nn.functional)
softmax = _register_torch_operation("softmax", like=ltorch._softmax)


@@ -1975,6 +1977,8 @@ def adaptive_avg_pool2d_bwd_wrapper(
pad_prim_impl = ex.register_operator("torch_pad_prim_impl", meta=prims.pad.meta, fn=_pad_prim_impl)
_register_implementation(prims.pad, pad_prim_impl, checker=_always_executable)
_register_implementation(ltorch._softmax, checker=_always_executable, execution_transform=_softmax_transform)
if hasattr(torch.nn.functional, "scaled_mm"):
    _register_implementation(ltorch.scaled_mm, scaled_mm, checker=_always_executable)
_register_implementation(ltorch.scaled_dot_product_attention, scaled_dot_product_attention, checker=_always_executable)


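With these registrations in place, thunder.jit can trace a function that calls torch.nn.functional.scaled_mm and hand it to the torch executor. A minimal sketch of that flow, assuming a PyTorch build that exposes scaled_mm, ScalingType, and SwizzleType in torch.nn.functional and a CUDA device with fp8 support (the name fp8_mm is illustrative; the tests added below exercise the same path):

import torch
from torch.nn.functional import ScalingType, SwizzleType

import thunder

def fp8_mm(a, b, scale_a, scale_b):
    # Tensor-wise fp8 matmul; scale_a/scale_b are the decode scales for a and b.
    return torch.nn.functional.scaled_mm(
        a, b,
        scale_a, ScalingType.TensorWise,
        scale_b, ScalingType.TensorWise,
        swizzle_a=SwizzleType.NO_SWIZZLE,
        swizzle_b=SwizzleType.NO_SWIZZLE,
        output_dtype=torch.bfloat16,
    )

jfp8_mm = thunder.jit(fp8_mm)  # scaled_mm now resolves to the registration above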
300 changes: 300 additions & 0 deletions thunder/tests/test_ops.py
@@ -1,11 +1,15 @@
from collections.abc import Callable
import os
import functools

import numpy as np
import pytest
import torch
from torch.testing import assert_close

if hasattr(torch.nn.functional, "scaled_mm"):
    from torch.nn.functional import ScalingType, SwizzleType

import thunder
import thunder.core.devices as devices
import thunder.core.dtypes as dtypes
@@ -419,6 +423,302 @@ def fn(a):
    assert_close(b, b_ref)


def _cuda_version_tuple() -> tuple[int, int] | None:
    # Parse torch.version.cuda, e.g. "12.4" -> (12, 4); CPU-only builds and unparsable
    # version strings yield None.
    if torch.version.cuda is None:
        return None
    parts = torch.version.cuda.split(".")
    try:
        major = int(parts[0])
        minor = int(parts[1]) if len(parts) > 1 else 0
        return major, minor
    except ValueError:
        return None


def _require_scaled_mm(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if not hasattr(torch.nn.functional, "scaled_mm"):
            pytest.skip("torch.nn.functional.scaled_mm is not available in this PyTorch build")
        return fn(*args, **kwargs)

    return wrapper


def _ensure_fp8_tensorwise(device: torch.device) -> None:
    if torch.cuda.get_device_capability(device) < (8, 9):
        pytest.skip("scaled_mm tensor-wise support requires SM89 or newer")


def _require_fp8_tensorwise(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        device = torch.device("cuda")
        _ensure_fp8_tensorwise(device)
        return fn(*args, **kwargs)

    return wrapper


def _require_fp8_rowwise(device: torch.device) -> None:
    _ensure_fp8_tensorwise(device)
    if torch.cuda.get_device_capability(device) < (9, 0):
        pytest.skip("row-wise scaled_mm requires SM90 or newer")
    cuda_version = _cuda_version_tuple()
    if cuda_version is not None and cuda_version < (12, 9):
        pytest.skip("row-wise scaled_mm requires CUDA 12.9 or newer")


def _require_fp8_blockwise(device: torch.device) -> None:
    _require_fp8_rowwise(device)


# Adapted from https://github.com/pytorch/pytorch/blob/b4403bfc62ca97eec554cdf815baab1fe93057d9/test/test_scaled_matmul_cuda.py#L645-L659
@requiresCUDA
@_require_fp8_tensorwise
@_require_scaled_mm
def test_scaled_mm_tensorwise_matches_torch():
    device = torch.device("cuda")

    def reference_fn(mat_a, mat_b, scale_a, scale_b):
        return torch.nn.functional.scaled_mm(
            mat_a,
            mat_b,
            scale_a,
            ScalingType.TensorWise,
            scale_b,
            ScalingType.TensorWise,
            swizzle_a=SwizzleType.NO_SWIZZLE,
            swizzle_b=SwizzleType.NO_SWIZZLE,
            output_dtype=torch.bfloat16,
        )

    M, K, N = 16, 32, 16
    mat_a = torch.randn(M, K, device=device, dtype=torch.float32)
    mat_b = torch.randn(K, N, device=device, dtype=torch.float32)
    mat_a_lp = mat_a.to(torch.float8_e4m3fn)
    mat_b_lp = mat_b.to(torch.float8_e4m3fn)
    scale_a = torch.tensor(1.0, device=device, dtype=torch.float32)
    scale_b = torch.tensor(1.0, device=device, dtype=torch.float32)

    try:
        expected = reference_fn(mat_a_lp, mat_b_lp, scale_a, scale_b)
    except (NotImplementedError, RuntimeError) as exc:
        pytest.skip(str(exc))

    jf = thunder.jit(reference_fn)
    result = jf(mat_a_lp, mat_b_lp, scale_a, scale_b)
    assert_close(result, expected)


# Adapted from https://github.com/pytorch/pytorch/blob/b4403bfc62ca97eec554cdf815baab1fe93057d9/test/test_scaled_matmul_cuda.py#L862-L910
@requiresCUDA
@_require_fp8_tensorwise
@_require_scaled_mm
def test_scaled_mm_matches_scaled_data():
    device = torch.device("cuda")

    def quantize_to_fp8(tensor):
        dtype = torch.float8_e4m3fn
        max_val = torch.finfo(dtype).max
        amax = tensor.abs().max()
        encode = (max_val / torch.clamp(amax, min=1e-12)).to(torch.float32)
        quant = torch.clamp(tensor * encode, min=-max_val, max=max_val).to(dtype)
        decode = encode.reciprocal()
        return quant, decode, encode

    def scaled_mm_fp8(mat_a, mat_b, scale_a, scale_b, *, out_dtype):
        return torch.nn.functional.scaled_mm(
            mat_a,
            mat_b,
            scale_a,
            ScalingType.TensorWise,
            scale_b,
            ScalingType.TensorWise,
            swizzle_a=SwizzleType.NO_SWIZZLE,
            swizzle_b=SwizzleType.NO_SWIZZLE,
            output_dtype=out_dtype,
        )

    M, K, N = 32, 64, 32
    mat_a = torch.randn(M, K, device=device, dtype=torch.float32)
    mat_b_base = torch.randn(N, K, device=device, dtype=torch.float32)

    mat_a_lp, decode_a, encode_a = quantize_to_fp8(mat_a)
    mat_b_lp_pre, decode_b, encode_b = quantize_to_fp8(mat_b_base)
    # To use cublaslt, the second matrix needs to be column-major.
    mat_b_lp = mat_b_lp_pre.t()

    try:
        reference = scaled_mm_fp8(mat_a_lp, mat_b_lp, decode_a, decode_b, out_dtype=torch.float32)
    except (NotImplementedError, RuntimeError) as exc:
        pytest.skip(str(exc))

    jf = thunder.jit(lambda a, b, sa, sb: scaled_mm_fp8(a, b, sa, sb, out_dtype=torch.float32))
    thunder_out = jf(mat_a_lp, mat_b_lp, decode_a, decode_b)

    assert_close(thunder_out, reference)

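# Illustrative note on the tensor-wise recipe above: quantize_to_fp8 picks `encode` so that the
# tensor's absolute maximum maps to the fp8 e4m3 maximum (448), and returns the matching
# `decode = encode.reciprocal()`. Since quant.float() * decode approximately recovers the input,
# feeding the decode scales to scaled_mm makes the fp8 product approximate mat_a @ mat_b_base.t()
# in the requested output dtype.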

@requiresCUDA
@_require_scaled_mm
def test_scaled_mm_rowwise_matches_torch():
    device = torch.device("cuda")
    _require_fp8_rowwise(device)

    def reference_fn(mat_a, mat_b, scale_a, scale_b):
        return torch.nn.functional.scaled_mm(
            mat_a,
            mat_b,
            scale_a,
            ScalingType.RowWise,
            scale_b,
            ScalingType.RowWise,
            swizzle_a=SwizzleType.NO_SWIZZLE,
            swizzle_b=SwizzleType.NO_SWIZZLE,
            output_dtype=torch.bfloat16,
        )

    M, K, N = 16, 32, 16
    mat_a = torch.randn(M, K, device=device, dtype=torch.float32)
    mat_b_base = torch.randn(N, K, device=device, dtype=torch.float32)
    mat_a_lp = mat_a.to(torch.float8_e4m3fn)
    # To use cublaslt, the second matrix needs to be column-major.
    mat_b_lp = mat_b_base.to(torch.float8_e4m3fn).t()
    scale_a = torch.ones((M, 1), device=device, dtype=torch.float32)
    scale_b = torch.ones((1, N), device=device, dtype=torch.float32)

    try:
        expected = reference_fn(mat_a_lp, mat_b_lp, scale_a, scale_b)
    except (NotImplementedError, RuntimeError) as exc:
        pytest.skip(str(exc))

    jf = thunder.jit(reference_fn)
    result = jf(mat_a_lp, mat_b_lp, scale_a, scale_b)
    assert_close(result, expected)


@requiresCUDA
@_require_scaled_mm
def test_scaled_mm_rowwise_matches_scaled_data():
    device = torch.device("cuda")
    _require_fp8_rowwise(device)

    dtype_fp8 = torch.float8_e4m3fn
    max_val = torch.finfo(dtype_fp8).max

    def rowwise_quantize(tensor, *, dim):
        amax = tensor.abs().amax(dim=dim, keepdim=True)
        encode = (max_val / torch.clamp(amax, min=1e-12)).to(torch.float32)
        quant = torch.clamp(tensor * encode, min=-max_val, max=max_val).to(dtype_fp8)
        decode = encode.reciprocal()
        return quant, decode, encode

    def scaled_mm_rowwise(mat_a, mat_b, scale_a, scale_b, *, out_dtype):
        return torch.nn.functional.scaled_mm(
            mat_a,
            mat_b,
            scale_a,
            ScalingType.RowWise,
            scale_b,
            ScalingType.RowWise,
            swizzle_a=SwizzleType.NO_SWIZZLE,
            swizzle_b=SwizzleType.NO_SWIZZLE,
            output_dtype=out_dtype,
        )

    M, K, N = 32, 64, 32
    mat_a = torch.randn(M, K, device=device, dtype=torch.bfloat16)
    mat_b = torch.randn(K, N, device=device, dtype=torch.bfloat16)

    mat_a_lp, decode_a, encode_a = rowwise_quantize(mat_a.to(torch.float32), dim=1)
    mat_b_lp, decode_b, encode_b = rowwise_quantize(mat_b.to(torch.float32), dim=0)

    try:
        reference = scaled_mm_rowwise(mat_a_lp, mat_b_lp, decode_a, decode_b, out_dtype=torch.bfloat16)
    except (NotImplementedError, RuntimeError) as exc:
        pytest.skip(str(exc))

    jf = thunder.jit(lambda a, b, sa, sb: scaled_mm_rowwise(a, b, sa, sb, out_dtype=torch.bfloat16))
    thunder_out = jf(mat_a_lp, mat_b_lp, decode_a, decode_b)

    reference_f32 = reference.to(torch.float32)
    thunder_out_f32 = thunder_out.to(torch.float32)

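    # fp8 e4m3 carries only 3 mantissa bits, so each quantized element is accurate to a few
    # percent at best; the comparison below therefore uses loose absolute and relative tolerances.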
    assert_close(thunder_out_f32, reference_f32, atol=3e-2, rtol=3e-2)


def _blockwise_quantize(tensor: torch.Tensor, block_rows: int, block_cols: int) -> tuple[torch.Tensor, torch.Tensor]:
    dtype_fp8 = torch.float8_e4m3fn
    max_val = torch.finfo(dtype_fp8).max

    M, K = tensor.shape
    assert M % block_rows == 0 and K % block_cols == 0

    reshaped = tensor.reshape(M // block_rows, block_rows, K // block_cols, block_cols)
    amax = reshaped.abs().amax(dim=(1, 3), keepdim=True)
    encode = (max_val / torch.clamp(amax, min=1e-12)).to(torch.float32)
    quant = torch.clamp(reshaped * encode, min=-max_val, max=max_val).to(dtype_fp8)

    return quant.reshape(M, K), encode.reshape(M // block_rows, K // block_cols).to(tensor.device)

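# Shape sketch for the helper above (derived from the code, shown for orientation): for a
# 256x256 input, block_rows=1 with block_cols=128 yields an encode tensor of shape (256, 2),
# while 128x128 blocks yield (2, 2). The test below hands encode.reciprocal() (the decode scale)
# to scaled_mm and transposes the rhs scales to line up with the transposed mat_b.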

@requiresCUDA
@_require_scaled_mm
@pytest.mark.parametrize("output_dtype", [torch.bfloat16])
@pytest.mark.parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)])
def test_scaled_mm_blockwise_matches_torch(output_dtype, lhs_block, rhs_block):
    device = torch.device("cuda")
    _require_fp8_blockwise(device)

    M, K, N = 256, 256, 256
    mat_a = torch.randn(M, K, device=device, dtype=output_dtype).pow(3)
    mat_b_rows = torch.randn(N, K, device=device, dtype=output_dtype).pow(3)

    mat_a_lp, encode_a = _blockwise_quantize(mat_a.to(torch.float32), lhs_block, 128)
    mat_b_lp_rows, encode_b = _blockwise_quantize(mat_b_rows.to(torch.float32), rhs_block, 128)
    mat_b_lp = mat_b_lp_rows.t().contiguous()

    scale_a = encode_a.reciprocal().contiguous()
    scale_b = encode_b.reciprocal().t().contiguous()

    recipe_map = {
        1: ScalingType.BlockWise1x128,
        128: ScalingType.BlockWise128x128,
    }

    try:
        expected = torch.nn.functional.scaled_mm(
            mat_a_lp,
            mat_b_lp,
            scale_a,
            recipe_map[lhs_block],
            scale_b,
            recipe_map[rhs_block],
            swizzle_a=SwizzleType.NO_SWIZZLE,
            swizzle_b=SwizzleType.NO_SWIZZLE,
            output_dtype=output_dtype,
        )
    except (RuntimeError, NotImplementedError, ValueError) as exc:
        pytest.skip(str(exc))

    fn = thunder.jit(
        lambda a, b, sa, sb: torch.nn.functional.scaled_mm(
            a,
            b,
            sa,
            recipe_map[lhs_block],
            sb,
            recipe_map[rhs_block],
            swizzle_a=SwizzleType.NO_SWIZZLE,
            swizzle_b=SwizzleType.NO_SWIZZLE,
            output_dtype=output_dtype,
        )
    )
    thunder_out = fn(mat_a_lp, mat_b_lp, scale_a, scale_b)
    assert_close(thunder_out, expected)


# https://github.com/Lightning-AI/lightning-thunder/issues/1857
def test_max_with_int():
    def f(x, ids):