202 changes: 202 additions & 0 deletions thunder/benchmarks/targets.py
@@ -33,10 +33,12 @@
HFBenchmark,
LinearLoRABenchmark,
DeepSeekSGLangMoEBenchmark,
UserFacingBenchmarkMeta,
thunder_apex_executor,
thunder_apex_nvfuser_executor,
thunder_cudnn_executor,
thunder_cudnn_nvfuser_executor,
thunder_cudnn_layer_norm_executor,
thunder_executor,
thunderfx_executor,
thunder_sdpa_torch_compile_nvfuser_executor,
@@ -1028,3 +1030,203 @@ def test_optim_functional(
args, kwargs = bench.make_batch()

benchmark_for_compute_type(compute_type, benchmark, fn, args, kwargs)


def cutlass_dsl_ex_executor(fn: Callable) -> Callable:
    from thunder.executors.cutlass_dsl_ex import cutlass_dsl_ex

    # Enable TF32 matmuls so every executor in this sweep runs under the same numerics.
    torch.backends.cuda.matmul.allow_tf32 = True
    return thunder.jit(fn, disable_torch_autograd=True, executors=[cutlass_dsl_ex])


def nvfuserex_executor(fn: Callable) -> Callable:
    from thunder.executors.nvfuserex import nvfuserex

    torch.backends.cuda.matmul.allow_tf32 = True
    return thunder.jit(fn, executors=[nvfuserex])


class BaseBenchmarkForQuack(Benchmark, metaclass=UserFacingBenchmarkMeta):
Copilot AI (Nov 6, 2025): This class does not call Benchmark.__init__ during initialization (BaseBenchmarkForQuack.__init__ may be missing a call to the base class __init__).
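A minimal sketch of the fix the comment suggests, assuming Benchmark.__init__ takes no required arguments (its signature is not shown in this diff):

    def __init__(self, shape: tuple[int, int], fn: Callable):
        super().__init__()  # forward to the Benchmark base initializer (assumed zero-arg)
        self.shape = shape
        self._fn = fn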
from thunder.benchmarks import BenchmarkArg

_args = (
BenchmarkArg("shape", description="The shape of the input tensor"),
BenchmarkArg("dtype", description="The dtype of the input tensor"),
BenchmarkArg("fn", description="The function to benchmark"),
)

def __init__(self, shape: tuple[int, int], fn: Callable):
self.shape = shape
self._fn = fn

@property
def description(self) -> str:
return "Benchmark for cross_entropy, softmax, layernorm, and rmsnorm with cutlass_dsl_ex, nvfuserex, and torch_compile_executor"


class CrossEntropyBenchmarkForQuack(BaseBenchmarkForQuack):
def __init__(self, shape: tuple[int, int], dtype: torch.dtype):
super().__init__(shape, torch.nn.functional.cross_entropy)
self.dtype = dtype

    def make_batch(self) -> tuple[list, dict]:
        return [
            torch.randn(self.shape, device="cuda", dtype=self.dtype),
            # Class targets are drawn from [0, 16), which is valid because every
            # benchmarked shape has at least 512 classes (N >= 512).
            torch.randint(0, 16, (self.shape[0],), device="cuda"),
        ], {}

@property
def name(self) -> str:
return f"CrossEntropyBenchmarkForQuack({self.shape})"

def fn(self) -> Callable:
def f(*args):
return self._fn(*args, reduction="none")

return f


class SoftmaxBenchmarkForQuack(BaseBenchmarkForQuack):
def __init__(self, shape: tuple[int, int], dtype: torch.dtype):
super().__init__(shape, torch.nn.functional.softmax)
self.dtype = dtype

def make_batch(self) -> tuple[list, dict]:
return [torch.randn(self.shape, device="cuda", dtype=self.dtype)], {}

@property
def name(self) -> str:
return f"SoftmaxBenchmarkForQuack({self.shape})"

def fn(self) -> Callable:
def f(*args):
return self._fn(*args, dim=-1)

return f


class LayerNormBenchmarkForQuack(BaseBenchmarkForQuack):
def __init__(self, shape: tuple[int, int], dtype: torch.dtype):
import torch.nn as nn

super().__init__(shape, torch.nn.functional.layer_norm)
self.dtype = dtype
self.layer = nn.LayerNorm(self.shape[1]).to(device="cuda", dtype=self.dtype)

def make_batch(self) -> tuple[list, dict]:
return [torch.randn(self.shape, device="cuda", dtype=self.dtype)], {}

@property
def name(self) -> str:
return f"LayerNormBenchmarkForQuack({self.shape})"

def fn(self) -> Callable:
def f(*args):
return self.layer(*args)

return f


class RMSNormBenchmarkForQuack(BaseBenchmarkForQuack):
def __init__(self, shape: tuple[int, int], dtype: torch.dtype):
import torch.nn as nn

super().__init__(shape, torch.nn.functional.rms_norm)
self.dtype = dtype
self.layer = nn.RMSNorm(self.shape[1]).to(device="cuda", dtype=self.dtype)

def make_batch(self) -> tuple[list, dict]:
return [torch.randn(self.shape, device="cuda", dtype=self.dtype)], {}

@property
def name(self) -> str:
return f"RMSNormBenchmarkForQuack({self.shape})"

def fn(self) -> Callable:
def f(*args):
return self.layer(*args)

return f


# Benchmarks for cross_entropy, softmax, layer_norm, and rms_norm with
# cutlass_dsl_ex, nvfuserex, and torch_compile_executor.
# Input shapes (M, N) sweep M in {8192, 32768} and N over powers of two from 512 to 262144.
quack_bench_executors = (
cutlass_dsl_ex_executor,
nvfuserex_executor,
torch_compile_executor,
)
quack_bench_shapes = (
(32768, 512),
(32768, 1024),
(32768, 2048),
(32768, 4096),
(32768, 8192),
(32768, 16384),
(32768, 32768),
(32768, 65536),
(32768, 131072),
(32768, 262144),
(8192, 512),
(8192, 1024),
(8192, 2048),
(8192, 4096),
(8192, 8192),
(8192, 16384),
(8192, 32768),
(8192, 65536),
(8192, 131072),
(8192, 262144),
)
quack_bench_shape_ids = [f"{m}_{n}" for m, n in quack_bench_shapes]
dtypes = (
torch.float32,
torch.bfloat16,
torch.float16,
)
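
A quick back-of-the-envelope on the largest shape in the sweep (assumes a single float32 input tensor per run):

    # (32768, 262144) = 2**15 * 2**18 = 2**33 elements; at 4 bytes each that is
    # 2**35 bytes = 32 GiB for one float32 input, so the top of this sweep
    # assumes a large-memory GPU.
    elems = 32768 * 262144        # 8,589,934,592
    gib_fp32 = elems * 4 / 2**30  # 32.0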


def _run_benchmark_for_quack(
benchmark, executor, benchmark_cls, dtype, shape: tuple[int, int], compute_type: ComputeType
):
bench = benchmark_cls(shape, dtype)
args, kwargs = bench.make_batch()
fn = executor(bench.fn())
benchmark_for_compute_type(compute_type, benchmark, fn, args, kwargs)
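
For orientation, a hedged sketch of what one parametrized case boils down to when run by hand, without the pytest benchmark fixture (assumes a CUDA device with thunder and nvFuser installed):

    # Hand-run equivalent of a single (executor, dtype, shape) combination:
    bench = SoftmaxBenchmarkForQuack((8192, 1024), torch.bfloat16)
    args, kwargs = bench.make_batch()    # one bf16 input tensor on CUDA
    fn = nvfuserex_executor(bench.fn())  # thunder.jit with nvfuserex only
    out = fn(*args, **kwargs)            # jitted softmax over dim=-1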


@pytest.mark.parametrize("executor", quack_bench_executors)
@pytest.mark.parametrize("dtype", dtypes, ids=(str(d) for d in dtypes))
@pytest.mark.parametrize("shape", quack_bench_shapes, ids=quack_bench_shape_ids)
@parametrize_compute_type_only_inference
def test_benchmark_quack_cross_entropy(benchmark, executor, dtype, shape: tuple[int, int], compute_type: ComputeType):
_run_benchmark_for_quack(benchmark, executor, CrossEntropyBenchmarkForQuack, dtype, shape, compute_type)


@pytest.mark.parametrize("executor", quack_bench_executors)
@pytest.mark.parametrize("dtype", dtypes, ids=(str(d) for d in dtypes))
@pytest.mark.parametrize("shape", quack_bench_shapes, ids=quack_bench_shape_ids)
@parametrize_compute_type_only_inference
def test_benchmark_quack_softmax(benchmark, executor, dtype, shape: tuple[int, int], compute_type: ComputeType):
_run_benchmark_for_quack(benchmark, executor, SoftmaxBenchmarkForQuack, dtype, shape, compute_type)


quack_layer_norm_executors = quack_bench_executors
if thunder_cudnn_layer_norm_executor is not None:
quack_layer_norm_executors += (thunder_cudnn_layer_norm_executor,)


@pytest.mark.parametrize("executor", quack_layer_norm_executors)
@pytest.mark.parametrize("dtype", dtypes, ids=(str(d) for d in dtypes))
@pytest.mark.parametrize("shape", quack_bench_shapes, ids=quack_bench_shape_ids)
@parametrize_compute_type_only_inference
def test_benchmark_quack_layer_norm(benchmark, executor, dtype, shape: tuple[int, int], compute_type: ComputeType):
_run_benchmark_for_quack(benchmark, executor, LayerNormBenchmarkForQuack, dtype, shape, compute_type)


@pytest.mark.parametrize("executor", quack_bench_executors)
@pytest.mark.parametrize("dtype", dtypes, ids=(str(d) for d in dtypes))
@pytest.mark.parametrize("shape", quack_bench_shapes, ids=quack_bench_shape_ids)
@parametrize_compute_type_only_inference
def test_benchmark_quack_rms_norm(benchmark, executor, dtype, shape: tuple[int, int], compute_type: ComputeType):
_run_benchmark_for_quack(benchmark, executor, RMSNormBenchmarkForQuack, dtype, shape, compute_type)