diff --git a/.github/container/cutlass_dsl_jax/pyproject.toml b/.github/container/cutlass_dsl_jax/pyproject.toml index 4bbc0941b..d7ffc523b 100644 --- a/.github/container/cutlass_dsl_jax/pyproject.toml +++ b/.github/container/cutlass_dsl_jax/pyproject.toml @@ -5,7 +5,7 @@ readme = "README.md" requires-python = ">=3.12" dependencies = [ "jax>=0.6.2", - "nvidia-cutlass-dsl>=4.2.1" + "nvidia-cutlass-dsl>=4.3.1" ] dynamic = ["version"] diff --git a/.github/container/cutlass_dsl_jax/src/jax_cutlass/compile.py b/.github/container/cutlass_dsl_jax/src/jax_cutlass/compile.py index 4f904cfa1..5cee36710 100644 --- a/.github/container/cutlass_dsl_jax/src/jax_cutlass/compile.py +++ b/.github/container/cutlass_dsl_jax/src/jax_cutlass/compile.py @@ -16,7 +16,7 @@ import gc import ctypes import inspect -from typing import Any, Callable +from typing import Any, Callable, Optional from dataclasses import dataclass from functools import partial from pathlib import Path @@ -77,6 +77,8 @@ class FunctionSpec: output_layout: tuple[tuple[int, ...]] output_mode: tuple[TensorMode, ...] convert_tensors: bool + compile_options: str + use_static_tensors: bool kwargs: tuple[tuple[str, Any]] def get_compile_args(self): @@ -121,8 +123,14 @@ def jit_wrapper( # split buffer argument into inputs and outputs and return to tree ins, outs = args[: len(spec.in_args)], args[(len(spec.in_args)) :] if cutlass.const_expr(spec.convert_tensors): - ins = [x.get_tensor(a.mode) for x, a in zip(ins, spec.in_args)] - outs = [x.get_tensor(a.mode) for x, a in zip(outs, spec.out_args)] + ins = [ + x.get_tensor(a.mode, spec.use_static_tensors) + for x, a in zip(ins, spec.in_args) + ] + outs = [ + x.get_tensor(a.mode, spec.use_static_tensors) + for x, a in zip(outs, spec.out_args) + ] ins = jax.tree.unflatten(spec.input_tree, ins) outs = jax.tree.unflatten(spec.output_tree, outs) wrapped_fn(stream, *ins, *outs, **dict(spec.kwargs)) @@ -176,6 +184,8 @@ def build_function_spec( output_mode, input_output_aliases, convert_tensors, + compile_options, + use_static_tensors, kwargs, ): # TODO: Improve type checking and validate pytree structures. @@ -233,6 +243,8 @@ def build_function_spec( tuple(output_layout), tuple(output_mode), convert_tensors, + compile_options, + use_static_tensors, tuple((k, kwargs[k]) for k in kwargs), ) @@ -260,7 +272,11 @@ def get_or_compile_kernel(fn, spec, stream): with _compile_lock: start = time.time() try: - compiled_fn = cutlass.cute.compile( + cute_compile = cutlass.cute.compile + if spec.compile_options: + cute_compile = partial(cute_compile, options=spec.compile_options) + + compiled_fn = cute_compile( jit_wrapper, cuda.CUstream(stream), spec.get_compile_args(), @@ -313,15 +329,6 @@ def initialize_cutlass_dsl(): if _CUTLASS_DSL_INITIALIZED: return - # TODO(mgoldfarb-nvidia): There are several runtime libraries that export C++ symbols - # which conflict with jax libraries. Initializing cutlass before jax will cause these - # symbols to incorrectly interpose. Our WAR is to for loading of jaxlib and its - # dependant libraries to ensure all symbols are loaded prior to compiling cutedsl programs. - # This linking issue is planed to be resolved in cute DSL 4.3. 
-    jaxlib_common = Path(jaxlib.__file__).parent / "libjax_common.so"
-    if jaxlib_common.exists():
-        ctypes.CDLL(str(jaxlib_common), mode=ctypes.RTLD_GLOBAL)
-
     kernel = _DummyInitKernel()
     with _compile_lock:
         logger.debug("Initializing cutlass dsl...")
diff --git a/.github/container/cutlass_dsl_jax/src/jax_cutlass/primitive.py b/.github/container/cutlass_dsl_jax/src/jax_cutlass/primitive.py
index a60adcddb..31b6abcdb 100644
--- a/.github/container/cutlass_dsl_jax/src/jax_cutlass/primitive.py
+++ b/.github/container/cutlass_dsl_jax/src/jax_cutlass/primitive.py
@@ -57,6 +57,8 @@ def cutlass_call(
     input_output_aliases={},
     convert_tensors=True,
     allow_cuda_graph=True,
+    compile_options=None,
+    use_static_tensors=False,
     **kwargs,
 ):
     """Creates a callable that invokes a @cute.jit function.
@@ -77,6 +79,11 @@ def cutlass_call(
       convert_tensors: Jax array buffers will be converted to cute.Tensor with static shape and
        layout. If disabled the kernel is instead given a JaxArray pointer directly.
       allow_cuda_graph: If false will prevent XLA from building a cuda graph of for this call.
+      compile_options: Optional compiler arguments to pass into cute.compile.
+      use_static_tensors: If True, tensor shapes and strides are treated as constexpr values by
+        default. This can improve performance through compiler specialization but may not work
+        properly with all kernels. Individual tensors can still be marked static or dynamic via
+        their TensorMode, which overrides this flag.
       kwargs: Optional constexpr parameters to pass into the kernel fn.
 
     Note: This API is experimental and subject to change!
@@ -94,6 +101,8 @@ def cutlass_call(
         input_output_aliases=input_output_aliases,
         convert_tensors=convert_tensors,
         allow_cuda_graph=allow_cuda_graph,
+        compile_options=compile_options,
+        use_static_tensors=use_static_tensors,
         **kwargs,
     )
@@ -127,6 +136,8 @@ def _cutlass_call_impl(
     input_output_aliases,
     convert_tensors,
     allow_cuda_graph,
+    compile_options,
+    use_static_tensors,
     **kwargs,
 ):
     multiple_results = isinstance(output_shape_dtype, Sequence)
@@ -200,7 +211,9 @@ def call_wrapper(*args):
     # information we got as input.
     for idx, (arg, mode) in enumerate(zip(args_flat, input_mode_flat)):
         if mode.mode is not None and len(mode.mode) != len(arg.shape):
-            raise ValueError(f"Input #{idx} has invalid mode.")
+            raise ValueError(
+                f"Input #{idx} has invalid mode {mode.mode} for shape {arg.shape}."
+ ) for idx, (arg, mode) in enumerate( zip(output_shape_dtype_flat, output_mode_flat) ): @@ -220,6 +233,8 @@ def call_wrapper(*args): input_output_aliases=tuple(input_output_aliases.items()), convert_tensors=convert_tensors, allow_cuda_graph=allow_cuda_graph, + compile_options=compile_options, + use_static_tensors=use_static_tensors, **kwargs, ) @@ -247,6 +262,8 @@ def cutlass_call_inner_p_impl( input_output_aliases, convert_tensors, allow_cuda_graph, + compile_options, + use_static_tensors, **kwargs, ): input_output_aliases = dict(input_output_aliases) @@ -266,6 +283,8 @@ def cutlass_call_inner_p_impl( output_mode_flat, input_output_aliases, convert_tensors, + compile_options, + use_static_tensors, kwargs, ) diff --git a/.github/container/cutlass_dsl_jax/src/jax_cutlass/types.py b/.github/container/cutlass_dsl_jax/src/jax_cutlass/types.py index 4ac538db0..4b0f333b1 100644 --- a/.github/container/cutlass_dsl_jax/src/jax_cutlass/types.py +++ b/.github/container/cutlass_dsl_jax/src/jax_cutlass/types.py @@ -22,7 +22,6 @@ from operator import mul from itertools import chain from typing import Annotated -from enum import Enum import cuda.bindings.driver as cuda @@ -71,11 +70,12 @@ class TensorMode: Arguments: mode : Specifies the position of each mode in the tensor (M0, M1, ... MN) """ + mode: tuple[int, ...] | None = field(metadata=dict(static=True), default=None) # Indicates the shape and strides will be defined statically. Enabling may enable # additional optimization. Kernels that do not support static shapes will generate # compile errors if this is enabled so we leave it off by default. - static: bool = field(metadata=dict(static=True), default=False) + static: bool = field(metadata=dict(static=True), default=None) # Overrides the default pointer alignment. Generally this should not be changed # but is left here to provide a hook. ptr_assumed_align: int = field( @@ -243,7 +243,12 @@ def align(self, min_align: int, *, loc=None, ip=None) -> "JaxArray": return JaxArray(self.ptr.align(min_align, loc, ip), self._shape, self._order) def get_layout( - self, mode: tuple[int, ...] | TensorMode = None, *, loc=None, ip=None + self, + mode: tuple[int, ...] | TensorMode = None, + use_static_tensors: bool = False, + *, + loc=None, + ip=None, ) -> cute.Layout: """Create a cute.Layout from this JaxArray. @@ -256,26 +261,34 @@ def get_layout( :type tuple[int,...]: Tuple that is same size as shape. """ if isinstance(mode, (tuple, list)): - mode = TensorMode(mode) + mode = TensorMode(mode, static=use_static_tensors) + + if (mode.static is None and use_static_tensors) or mode.static: + shape = self._shape + else: + shape = [cutlass.as_numeric(m) for m in self._shape] - shape = ( - self._shape if mode.static else [cutlass.as_numeric(m) for m in self._shape] - ) layout = cute.make_ordered_layout(tuple(shape), self._order, loc=loc, ip=ip) if mode is not None and mode.mode is not None: layout = cute.select(layout, mode.mode) return layout def get_tensor( - self, mode: tuple[int, ...] | TensorMode = None, *, loc=None, ip=None + self, + mode: tuple[int, ...] | TensorMode = None, + use_static_tensors: bool = False, + *, + loc=None, + ip=None, ) -> cute.Tensor: """Create a cute.Tensor from this JaxArray. :param mode: Maps the physical shape dimension to logical shape dimensions. If not given the physical layout is used. + :param use_static_tensors: Defaults tensor shape and stride to static if no mode is given. :type tuple[int,...]: Tuple that is same size as shape. 
:see get_layout """ - layout = self.get_layout(mode, loc=loc, ip=ip) + layout = self.get_layout(mode, use_static_tensors, loc=loc, ip=ip) return cute.make_tensor(self.ptr, layout) # Utility methods diff --git a/.github/container/cutlass_dsl_jax/src/jax_cutlass/version.py b/.github/container/cutlass_dsl_jax/src/jax_cutlass/version.py index 24cc4300f..e20c6a7a2 100644 --- a/.github/container/cutlass_dsl_jax/src/jax_cutlass/version.py +++ b/.github/container/cutlass_dsl_jax/src/jax_cutlass/version.py @@ -13,5 +13,5 @@ # limitations under the License. -__version_info__ = (0, 2, 0) +__version_info__ = (0, 3, 0) __version__ = ".".join(str(v) for v in __version_info__) diff --git a/.github/container/cutlass_dsl_jax/tests/blackwell/__init__.py b/.github/container/cutlass_dsl_jax/tests/blackwell/__init__.py new file mode 100644 index 000000000..070b8c0d7 --- /dev/null +++ b/.github/container/cutlass_dsl_jax/tests/blackwell/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/.github/container/cutlass_dsl_jax/tests/blackwell/test_block_scaled_gemm.py b/.github/container/cutlass_dsl_jax/tests/blackwell/test_block_scaled_gemm.py new file mode 100644 index 000000000..59a477855 --- /dev/null +++ b/.github/container/cutlass_dsl_jax/tests/blackwell/test_block_scaled_gemm.py @@ -0,0 +1,217 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
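+#
+# Functional and benchmark coverage for the CuTeDSL Blackwell (SM100) block-scaled
+# dense GEMM example kernel, invoked from JAX via cutlass_call.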
+ +from functools import partial +from collections import defaultdict +from typing import List, Type, Tuple, Union, Optional +import os + +import pytest +import jax +import jax.numpy as jnp + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils + +from jax_cutlass import cutlass_call, jax_to_cutlass_dtype + +from ..tensor import ( + create_a_tensor, + create_b_tensor, + create_cd_tensor, + gemm_a_mode, + gemm_b_mode, + gemm_c_mode, + gemm_c_shape, + gemm_reference_einsum, +) + +from blackwell.dense_blockscaled_gemm_persistent import ( + Sm100BlockScaledPersistentDenseGemmKernel, +) + + +@pytest.mark.parametrize( + "problem_size", + [ + pytest.param((8 * 1024, 8 * 1024, 8 * 1024, 1), id="M8092-N8092-K8092-L1"), + pytest.param((8 * 1024, 4 * 1024, 4 * 1024, 1), id="M8092-N4096-K4096-L1"), + pytest.param((16 * 1024, 16 * 1024, 16 * 1024, 1), id="M16K-N16K-K16-L1"), + ], +) +@pytest.mark.parametrize( + "mma_tile_shape_mn", + [ + pytest.param((128, 128), id="MMA_128x128"), + # pytest.param((256, 128), id="MMA_256x128"), + # pytest.param((256, 256), id="MMA_256x256"), + ], +) +@pytest.mark.parametrize( + "is_2sm, cluster_shape_mn", + [ + # pytest.param(False, (1, 1), id="1SM-1x1"), + pytest.param(False, (2, 1), id="1SM-2x1"), + # pytest.param(False, (2, 2), id="1SM-2x2"), + # pytest.param(False, (4, 1), id="1SM-4x1"), + pytest.param(True, (2, 1), id="2SM-2x1"), + # pytest.param(True, (2, 2), id="2SM-2x2"), + # pytest.param(True, (4, 1), id="2SM-4x1"), + ], +) +@pytest.mark.parametrize( + "ab_dtype, c_dtype, sf_dtype, sf_vec_size", + [ + pytest.param( + "float4_e2m1fn", "float16", "float8_e8m0fnu", 16, id="mxfp4xmxfp4xf16" + ), + pytest.param( + "float4_e2m1fn", "float16", "float8_e4m3fn", 16, id="nvfp4xnvfp4xf16" + ), + ], +) +@pytest.mark.parametrize( + "a_major, b_major, c_major", + [ + # n.b. only k major a/b is supported by this test fixture. + pytest.param("k", "k", "n", id="kkn_major"), + ], +) +@pytest.mark.requires_device("B200") +def test_dense_block_scaled_gemm( + benchmark, + problem_size, + mma_tile_shape_mn, + is_2sm, + cluster_shape_mn, + ab_dtype, + c_dtype, + sf_dtype, + sf_vec_size, + a_major, + b_major, + c_major, +): + def ceil_div(a, b): + return (a + b - 1) // b + + m, n, k, l = problem_size + sf_k = ceil_div(k, sf_vec_size) + + if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement( + jax_to_cutlass_dtype(ab_dtype), + jax_to_cutlass_dtype(sf_dtype), + sf_vec_size, + jax_to_cutlass_dtype(c_dtype), + mma_tile_shape_mn, + cluster_shape_mn, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + pytest.skip( + f"Sm100BlockScaledPersistentDenseGemmKernel does not support test config." 
+ ) + + if not is_2sm and mma_tile_shape_mn[0] not in (64, 128): + pytest.skip(f"Skipping {is_2sm=} {mma_tile_shape_mn=}") + + akey, asfkey, bkey, bsfkey = jax.random.split(jax.random.key(1337), 4) + a = create_a_tensor(l, m, k, a_major, ab_dtype, akey, minval=-1.0, maxval=1.0) + b = create_b_tensor(l, n, k, b_major, ab_dtype, bkey, minval=-2.0, maxval=2.0) + + assert a_major == "k", "a_major must be k" + assert b_major == "k", "b_major must be k" + + # See https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-a-layout-1x + # Scale factors are using .scale_vec::4X / .block16 config to support nvfp4 and mxfp4 + atom_mn = (32, 4) + atom_k = 4 + + sfa = create_a_tensor(l, m, sf_k, a_major, sf_dtype, asfkey, minval=1.0, maxval=3.0) + sfa_ref = sfa + sfa = sfa.reshape( + l, + ceil_div(m, atom_mn[0] * atom_mn[1]), + atom_mn[1], + atom_mn[0], + ceil_div(sf_k, atom_k), + atom_k, + ) + # TODO: See if we can pass this layout mapping from jax primitive (it requires grouping) + sfa = sfa.transpose(0, 1, 4, 3, 2, 5) + + sfb = create_b_tensor(l, n, sf_k, b_major, sf_dtype, bsfkey, minval=1.0, maxval=3.0) + sfb_ref = sfb + sfb = sfb.reshape( + l, + ceil_div(n, atom_mn[0] * atom_mn[1]), + atom_mn[1], + atom_mn[0], + ceil_div(sf_k, atom_k), + atom_k, + ) + sfb = sfb.transpose(0, 1, 4, 3, 2, 5) + + gemm = Sm100BlockScaledPersistentDenseGemmKernel( + sf_vec_size, + mma_tile_shape_mn, + cluster_shape_mn, + ) + + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + + def launch(a, b, sfa, sfb): + call = ( + lambda stream, a, b, sfa, sfb, c, *, max_active_clusters, epilogue_op: gemm( + a, b, sfa, sfb, c, max_active_clusters, stream, epilogue_op + ) + ) + return cutlass_call( + call, + input_mode=(gemm_a_mode(a_major), gemm_b_mode(b_major), None, None), + output_mode=(gemm_c_mode(c_major),), + output_shape_dtype=jax.ShapeDtypeStruct( + gemm_c_shape(l, m, n, c_major), c_dtype + ), + epilogue_op=lambda x: x, + max_active_clusters=max_active_clusters, + )(a, b, sfa, sfb) + + c = launch(a, b, sfa, sfb) + + c_ref = gemm_reference_einsum( + a, + b, + acc_dtype=jnp.float16, + c_dtype=c_dtype, + a_major=a_major, + b_major=b_major, + c_major=c_major, + sf_a=sfa_ref, + sf_b=sfb_ref, + ) + + assert jnp.allclose(c, c_ref) + + with benchmark.runner("blackwell_dense_block_scaled_gemm.txt") as runner: + runner(launch, a, b, sfa, sfb) diff --git a/.github/container/cutlass_dsl_jax/tests/blackwell/test_gemm.py b/.github/container/cutlass_dsl_jax/tests/blackwell/test_gemm.py new file mode 100644 index 000000000..72756c934 --- /dev/null +++ b/.github/container/cutlass_dsl_jax/tests/blackwell/test_gemm.py @@ -0,0 +1,162 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
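+#
+# Functional and benchmark coverage for the CuTeDSL Blackwell (SM100) persistent
+# dense GEMM example kernel, invoked from JAX via cutlass_call.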
+ +from functools import partial +from collections import defaultdict +from typing import List, Type, Tuple, Union, Optional +import os + +import pytest +import jax +import jax.numpy as jnp + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils + +from jax_cutlass import cutlass_call, jax_to_cutlass_dtype, TensorMode as T + +from ..tensor import ( + create_a_tensor, + create_b_tensor, + create_cd_tensor, + gemm_a_mode, + gemm_b_mode, + gemm_c_mode, + gemm_c_shape, + gemm_reference_einsum, +) + +from blackwell.dense_gemm_persistent import PersistentDenseGemmKernel + + +@pytest.mark.parametrize( + "problem_size", + [ + pytest.param((8 * 1024, 8 * 1024, 8 * 1024, 1), id="M8092-N8092-K8092-L1"), + # pytest.param((8 * 1024, 4 * 1024, 4 * 1024, 1), id="M8092-N4096-K4096-L1"), + # pytest.param((16 * 1024, 16 * 1024, 16 * 1024, 1), id="M16K-N16K-K16-L1"), + ], +) +@pytest.mark.parametrize( + "mma_tile_shape_mn", + [ + pytest.param((128, 128), id="MMA_128x128"), + # pytest.param((256, 128), id="MMA_256x128"), + # pytest.param((256, 256), id="MMA_256x256"), + ], +) +@pytest.mark.parametrize( + "is_2sm, cluster_shape_mn", + [ + pytest.param(False, (1, 1), id="1SM-1x1"), + # pytest.param(False, (2, 1), id="1SM-2x1"), + # pytest.param(False, (2, 2), id="1SM-2x2"), + # pytest.param(False, (4, 1), id="1SM-4x1"), + pytest.param(True, (2, 1), id="2SM-2x1"), + # pytest.param(True, (2, 2), id="2SM-2x2"), + # pytest.param(True, (4, 1), id="2SM-4x1"), + ], +) +@pytest.mark.parametrize( + "use_tma_store", + [ + pytest.param(False, id="NTS"), + pytest.param(True, id="TS"), + ], +) +@pytest.mark.parametrize( + "a_dtype, b_dtype, c_dtype, acc_dtype", + [ + pytest.param( + "float16", "float16", "float16", "float32", id="bf16xbf16xbf16xfp32" + ), + pytest.param( + "float8_e4m3fn", "float8_e4m3fn", "float16", "float32", id="fp8xfp8xf16xf32" + ), + ], +) +@pytest.mark.parametrize( + "a_major, b_major, c_major", + [ + pytest.param("k", "k", "n", id="kkn_major"), + # pytest.param("m", "n", "n", id="mnn_major"), + # pytest.param("m", "n", "m", id="mnm_major"), + ], +) +@pytest.mark.requires_device("B200") +def test_dense_gemm( + benchmark, + problem_size, + mma_tile_shape_mn, + is_2sm, + cluster_shape_mn, + use_tma_store, + a_dtype, + b_dtype, + c_dtype, + acc_dtype, + a_major, + b_major, + c_major, +): + if not is_2sm and mma_tile_shape_mn[0] not in (64, 128): + pytest.skip(f"Skipping {is_2sm=} {mma_tile_shape_mn=}") + + m, n, k, l = problem_size + + akey, bkey = jax.random.split(jax.random.key(1337), 2) + a = create_a_tensor(l, m, k, a_major, a_dtype, akey) + b = create_b_tensor(l, n, k, b_major, b_dtype, bkey) + + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + gemm = PersistentDenseGemmKernel( + jax_to_cutlass_dtype(acc_dtype), + is_2sm, + mma_tile_shape_mn, + cluster_shape_mn, + use_tma_store, + ) + call = lambda stream, a, b, c, **kwargs: gemm( + a, b, c, max_active_clusters, stream, **kwargs + ) + + def launch(a, b): + return cutlass_call( + call, + input_mode=(gemm_a_mode(a_major), gemm_b_mode(b_major)), + output_mode=(gemm_c_mode(c_major),), + output_shape_dtype=jax.ShapeDtypeStruct( + gemm_c_shape(l, m, n, c_major), c_dtype + ), + epilogue_op=lambda x: x, + )(a, b) + + c = launch(a, b) + c_ref = gemm_reference_einsum( + a, + b, + acc_dtype=acc_dtype, + c_dtype=c_dtype, + a_major=a_major, + b_major=b_major, + c_major=c_major, + ) + assert jnp.allclose(c, c_ref) + + with 
benchmark.runner("blackwell_dense_gemm.txt") as runner: + runner(launch, a, b) diff --git a/.github/container/cutlass_dsl_jax/tests/blackwell/test_grouped_gemm.py b/.github/container/cutlass_dsl_jax/tests/blackwell/test_grouped_gemm.py new file mode 100644 index 000000000..6bed187c1 --- /dev/null +++ b/.github/container/cutlass_dsl_jax/tests/blackwell/test_grouped_gemm.py @@ -0,0 +1,477 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial, reduce +from collections import defaultdict +import pytest +import jax +import jax.numpy as jnp +import os + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils + +from jax_cutlass import cutlass_call, jax_to_cutlass_dtype, TensorMode as TM + +from ..tensor import ( + create_a_tensor, + create_b_tensor, + create_cd_tensor, + gemm_reference_einsum, + gemm_a_mode, + gemm_b_mode, + gemm_c_mode, +) + +# Import from cutlass examples +from blackwell.grouped_gemm import GroupedGemmKernel + +# Needed for int64 types +jax.config.update("jax_enable_x64", True) + + +class JaxGroupGemmKernel: + """A Jax wrapper around GroupGemmKernel. + + The jax flavor of group gemm takes as input a single unified tensor and runs an aux + kernel to extract the addresses of the groups. This allows the use of the existing + group gemm kernel from cutlass w/o modification. + """ + + def __init__( + self, + a_mode, + b_mode, + c_mode, + group_count, + acc_dtype, + is_2sm, + mma_tile_shape_mn, + cluster_shape_mn, + tensormap_update_mode, + num_tensormap_buffers, + max_active_clusters, + total_num_clusters, + ): + self._gemm = GroupedGemmKernel( + jax_to_cutlass_dtype(acc_dtype), + is_2sm, + mma_tile_shape_mn, + cluster_shape_mn, + tensormap_update_mode, + ) + self._a_mode = a_mode + self._b_mode = b_mode + self._c_mode = c_mode + self._group_count = group_count + self._num_tensormap_buffers = num_tensormap_buffers + self._max_active_clusters = max_active_clusters + self._total_num_clusters = total_num_clusters + + @partial(jax.jit, static_argnums=[0], donate_argnums=[3]) + def __call__( + self, + tensor_a, + tensor_b, + tensor_c, + group_offsets, + problem_sizes_mnkl, + strides_abc, + ): + + # Storage for tensormap in gmem + tensormap = jnp.zeros( + ( + self._num_tensormap_buffers, + GroupedGemmKernel.num_tensormaps, + GroupedGemmKernel.bytes_per_tensormap // 8, + ), + dtype=jnp.int64, + ) + + # Storage for pointer offsets to each tensor. 
+ ptrs_abc = jnp.zeros((group_offsets.shape[0], 3), jnp.int64) + + c, tmap, ptrs = cutlass_call( + fn=self.launch, + output_shape_dtype=(tensor_c, tensormap, ptrs_abc), + input_output_aliases={2: 0, 7: 1, 6: 2}, + group_count=self._group_count, + total_num_clusters=self._total_num_clusters, + max_active_clusters=self._max_active_clusters, + input_mode=( + self._a_mode, + self._b_mode, + self._c_mode, + None, + None, + None, + None, + None, + ), + output_mode=(self._c_mode, None, None), + use_static_tensors=True, + )( + tensor_a, + tensor_b, + tensor_c, + group_offsets, + problem_sizes_mnkl, + strides_abc, + ptrs_abc, + tensormap, + ) + return c + + @cute.jit + def launch( + self, + stream: cuda.CUstream, + initial_a: cute.Tensor, + initial_b: cute.Tensor, + initial_c: cute.Tensor, + group_offsets: cute.Tensor, + problem_shape_mnkl: cute.Tensor, + strides_abc: cute.Tensor, + tensor_address_abc: cute.Tensor, + tensormap_cute_tensor: cute.Tensor, + *, + group_count: cutlass.Constexpr[int], + total_num_clusters: cutlass.Constexpr[int], + max_active_clusters: cutlass.Constexpr[int], + ): + extract_tensor_address_kernel( + group_offsets, initial_a, initial_b, initial_c, tensor_address_abc + ).launch( + stream=stream, grid=[tensor_address_abc.shape[0], 1, 1], block=[1, 1, 1] + ) + + self._gemm( + initial_a, + initial_b, + initial_c, + group_count, + problem_shape_mnkl, + strides_abc, + tensor_address_abc, + total_num_clusters, + tensormap_cute_tensor, + max_active_clusters, + stream, + ) + + @cute.kernel + def extract_tensor_address_kernel( + group_offsets: cute.Tensor, + tensor_a: cute.Tensor, + tensor_b: cute.Tensor, + tensor_c: cute.Tensor, + dst: cute.Tensor, + ): + # mkl, nkl, mnl + bidx, _, _ = cute.arch.block_idx() + + num_groups = group_offsets.shape[0] + group_offset = group_offsets[bidx] + per_expert_size = tensor_b.shape[0] // num_groups + + a_offset = ( + cute.Int64(group_offset) + * tensor_a.stride[0] + * tensor_a.element_type.width + // 8 + ) + a_ptr = tensor_a.iterator.toint() + a_offset + dst[bidx, 0] = a_ptr + + b_offset = ( + cute.Int64(bidx) + * per_expert_size + * tensor_b.stride[0] + * tensor_b.element_type.width + // 8 + ) + b_ptr = tensor_b.iterator.toint() + b_offset + dst[bidx, 1] = b_ptr + + c_offset = ( + cute.Int64(group_offset) + * tensor_c.stride[0] + * tensor_c.element_type.width + // 8 + ) + c_ptr = tensor_c.iterator.toint() + c_offset + dst[bidx, 2] = c_ptr + + +@partial(jax.jit, static_argnums=[0, 1, 3]) +def generate_group_sizes( + expert_count, token_count, key, uniform_group_size=False, round_group_sizes=8 +): + if uniform_group_size: + return jnp.array([token_count // expert_count] * expert_count) + round_group_sizes = float(round_group_sizes) + key1, key2 = jax.random.split(key, 2) + v = jax.random.truncated_normal(key1, -2.0, 2.0, expert_count) + 2.0 + expert_probs = v / jnp.sum(v) + expert_assignment = jax.random.choice( + key2, expert_count, (token_count,), p=expert_probs + ) + group_sizes = jnp.bincount(expert_assignment, length=expert_count) + group_sizes = round_group_sizes * jnp.floor( + group_sizes.astype(jnp.float32) / round_group_sizes + ) + group_sizes = group_sizes.at[0].add(token_count - group_sizes.sum()) + return group_sizes.astype(jnp.int32) + + +@pytest.mark.parametrize( + "uniform_groups", + [pytest.param(True, id="UNIFORM"), pytest.param(False, id="RANDOM")], +) +@pytest.mark.parametrize( + "problem_size", + [ + pytest.param( + (16, 8 * 1024, int(1.5 * 1024), 3 * 1024, 1), id="E16-M8192-N1536-K3072-L1" + ), + pytest.param( + (128, 32 * 
1024, int(1.5 * 1024), 2048, 1), id="E128-M32768-N1536-K2048-L1" + ), + ], +) +@pytest.mark.parametrize( + "tensormap_update_mode", + [ + # pytest.param(utils.TensorMapUpdateMode.GMEM, id="GMEM"), + pytest.param(utils.TensorMapUpdateMode.SMEM, id="SMEM"), + ], +) +@pytest.mark.parametrize( + "mma_tile_shape_mn", + [ + pytest.param((128, 128), id="MMA_128x128"), + # pytest.param((256, 128), id="MMA_256x128"), + # pytest.param((256, 256), id="MMA_256x256"), + ], +) +@pytest.mark.parametrize( + "is_2sm, cluster_shape_mn", + [ + pytest.param(False, (1, 1), id="1SM-1x1"), + # pytest.param(False, (2, 1), id="1SM-2x1"), + # pytest.param(False, (2, 2), id="1SM-2x2"), + # pytest.param(False, (4, 1), id="1SM-4x1"), + pytest.param(True, (2, 1), id="2SM-2x1"), + # pytest.param(True, (2, 2), id="2SM-2x2"), + # pytest.param(True, (4, 1), id="2SM-4x1"), + ], +) +@pytest.mark.parametrize( + "a_dtype, b_dtype, c_dtype, acc_dtype", + [ + pytest.param( + jnp.float16, + jnp.float16, + jnp.float16, + jnp.float32, + id="bf16xbf16xbf16xfp32", + ), + pytest.param( + jnp.float8_e4m3fn, + jnp.float8_e4m3fn, + jnp.float16, + jnp.float32, + id="fp8xfp8xf16xf32", + ), + ], +) +@pytest.mark.parametrize( + "a_major, b_major, c_major", + [ + pytest.param("k", "k", "n", id="kkn_major"), + # pytest.param("k", "n", "n", id="knn_major"), + ], +) +@pytest.mark.requires_device("B200") +def test_grouped_gemm( + benchmark, + problem_size, + uniform_groups, + mma_tile_shape_mn, + cluster_shape_mn, + tensormap_update_mode, + is_2sm, + a_dtype, + b_dtype, + c_dtype, + acc_dtype, + a_major, + b_major, + c_major, +): + key = jax.random.key(1337) + + num_groups, m, n, k, l = problem_size + + # Skip invalid mma tile shape + if not ( + (not is_2sm and mma_tile_shape_mn[0] in [64, 128]) + or (is_2sm and mma_tile_shape_mn[0] in [128, 256]) + ): + raise pytest.skip(f"Skip invalid mma tiler M {mma_tile_shape_mn[0]}") + + if mma_tile_shape_mn[1] not in range(32, 257, 32): + raise pytest.skip(f"Skip invalid mma tiler N {mma_tile_shape_mn[1]}") + + if m % (mma_tile_shape_mn[0] * cluster_shape_mn[0]) != 0: + pytest.skip(f"Problem too small for M tiling.") + + if n % (mma_tile_shape_mn[1] * cluster_shape_mn[1]) != 0: + pytest.skip(f"Problem too small for N tiling.") + + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if is_2sm else 1) != 0: + raise pytest.skip( + f"cluster_shape_m need align with is_2sm config {cluster_shape_mn}" + ) + + tensors_abc = [] + problem_sizes_mnkl = [] + strides_abc = [] + + gkey, key = jax.random.split(key) + group_sizes = generate_group_sizes(num_groups, m, gkey, uniform_groups) + assert group_sizes.sum() == m, "unexpected group sizes" + + # Build separate tensors for each expert. It is expected that the total tokens will + # sum to m. n is uniform across all experts. 
+ for idx in range(num_groups): + sub_m = int(group_sizes[idx]) + akey, bkey, ckey, key = jax.random.split(key, 4) + + tensor_a = create_a_tensor(l, sub_m, k, a_major, a_dtype, akey) + tensor_b = create_b_tensor(l, n, k, b_major, b_dtype, bkey) + tensor_c = create_cd_tensor(l, sub_m, n, c_major, c_dtype, ckey, fill_value=0.0) + tensors_abc.append((tensor_a, tensor_b, tensor_c)) + + stride_mk_a = (k, 1) if a_major == "k" else (1, m) # mkl + stride_nk_b = (k, 1) if b_major == "k" else (1, n * num_groups) # nkl + stride_mn_c = (n, 1) if c_major == "n" else (1, m) # mnl + + strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c]) + problem_sizes_mnkl.append(((sub_m, n, k, l))) + + # layout (num_groups, 3, 2):(6, 2, 1) + strides_abc_tensor = jnp.array(strides_abc, dtype=jnp.int32) + problem_sizes_mnkl_tensor = jnp.array(problem_sizes_mnkl, dtype=jnp.int32) + group_offsets = jnp.cumsum(group_sizes) - group_sizes + + # get number of SMs by querying max active clusters with 1x1 cluster shape + hardware_info = cutlass.utils.HardwareInfo() + num_sms = hardware_info.get_device_multiprocessor_count() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + num_tensormap_buffers = num_sms + + def compute_total_num_clusters(problem_sizes_mnkl, cga_tile_shape_mn): + total_num_clusters = 0 + for m, n, _, _ in problem_sizes_mnkl: + num_clusters_mn = tuple( + (x + y - 1) // y for x, y in zip((m, n), cga_tile_shape_mn) + ) + total_num_clusters += reduce(lambda x, y: x * y, num_clusters_mn) + return total_num_clusters + + def compute_cga_tile_shape(mma_tile_shape_mn, cluster_shape_mn, is_2sm): + cta_tile_shape_mn = list(mma_tile_shape_mn) + if is_2sm: + cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 + return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) + + cga_tile_shape_mn = compute_cga_tile_shape( + mma_tile_shape_mn, cluster_shape_mn, is_2sm + ) + total_num_clusters = compute_total_num_clusters( + problem_sizes_mnkl, cga_tile_shape_mn + ) + + gemm = JaxGroupGemmKernel( + gemm_a_mode(a_major), + gemm_b_mode(b_major), + gemm_c_mode(c_major), + num_groups, + acc_dtype, + is_2sm, + mma_tile_shape_mn, + cluster_shape_mn, + tensormap_update_mode, + num_tensormap_buffers, + max_active_clusters, + total_num_clusters, + ) + + # Create the combined tensors by concatenating along the appropriate axis + am_axis = gemm_a_mode(a_major)[0] # mkl + bn_axis = gemm_b_mode(b_major)[0] # nkl + cm_axis = gemm_c_mode(c_major)[0] # mnl + tensor_a_device = jnp.concatenate([x[0] for x in tensors_abc], axis=am_axis) + tensor_b_device = jnp.concatenate([x[1] for x in tensors_abc], axis=bn_axis) + tensor_c_device = jnp.concatenate([x[2] for x in tensors_abc], axis=cm_axis) + + # Note: this call setup is a bit tricky because we need to extract addresses + # from tensor_c. To do this we donate tensor_c so we can treat it as both an + # input and output ensuring it has a stable allocation. + tensor_c_device = gemm( + tensor_a_device, + tensor_b_device, + tensor_c_device, + group_offsets, + problem_sizes_mnkl_tensor, + strides_abc_tensor, + ) + + c_ref = [] + for idx in range(num_groups): + c_ref.append( + gemm_reference_einsum( + tensors_abc[idx][0], + tensors_abc[idx][1], + acc_dtype=acc_dtype, + c_dtype=c_dtype, + a_major=a_major, + b_major=b_major, + c_major=c_major, + ) + ) + c_ref = jnp.concatenate(c_ref, axis=cm_axis).astype(jnp.float32) + tensor_c_device = tensor_c_device.astype(jnp.float32) + + # Tolerance from cutedsl tests. 
+    assert jnp.allclose(c_ref, tensor_c_device, atol=0.1)
+
+    with benchmark.runner("blackwell_grouped_gemm.txt") as runner:
+        for _ in runner:
+            tensor_c_device = gemm(
+                tensor_a_device,
+                tensor_b_device,
+                tensor_c_device,
+                group_offsets,
+                problem_sizes_mnkl_tensor,
+                strides_abc_tensor,
+            )
diff --git a/.github/container/cutlass_dsl_jax/tests/conftest.py b/.github/container/cutlass_dsl_jax/tests/conftest.py
index cb16eb1aa..aeee5a3c5 100644
--- a/.github/container/cutlass_dsl_jax/tests/conftest.py
+++ b/.github/container/cutlass_dsl_jax/tests/conftest.py
@@ -14,11 +14,18 @@
 import pytest
 import jax
+import sys
+import re
+from unittest.mock import MagicMock, patch
 
 from jax_cutlass import release_compile_cache
 from .benchmark import cupti_profile, BenchmarkCollector
 
 
+def pytest_configure(config):
+    config.addinivalue_line("markers", "requires_device(arg): Specify required device kind.")
+
+
 def pytest_addoption(parser):
     parser.addoption("--benchmark_iters", default=16, action="store", type=int)
     parser.addoption("--benchmark", action="store_true")
@@ -26,6 +33,12 @@ def pytest_addoption(parser):
 
 
 def pytest_sessionstart(session):
+    # Mock torch so that importing the CuteDSL examples does not
+    # break on platforms without torch.
+    mock_modules = ("torch", "torch.nn", "torch.nn.functional")
+    for m in mock_modules:
+        sys.modules.update({m: MagicMock()})
+
     session.stash["collector"] = BenchmarkCollector(
         session.config.option.benchmark, session.config.option.benchmark_iters
     )
@@ -38,6 +51,17 @@ def pytest_sessionfinish(session):
     session.stash["collector"].save_csv()
 
 
+def pytest_runtest_setup(item):
+    requires_device = item.get_closest_marker("requires_device")
+    if requires_device:
+        arg_value = requires_device.args[0] if requires_device.args else ""
+        for d in jax.devices():
+            if not re.search(arg_value, d.device_kind):
+                pytest.skip(
+                    f"Skipping test because device {d} is '{d.device_kind}' but requires '{arg_value}'"
+                )
+
+
 @pytest.fixture
 def benchmark(request):
     collector = request.session.stash["collector"]
diff --git a/.github/container/cutlass_dsl_jax/tests/tensor.py b/.github/container/cutlass_dsl_jax/tests/tensor.py
index 320262f41..3eeaf7be8 100644
--- a/.github/container/cutlass_dsl_jax/tests/tensor.py
+++ b/.github/container/cutlass_dsl_jax/tests/tensor.py
@@ -17,6 +17,67 @@
 import jax.numpy as jnp
 
 
+def reorder_modes(src: str, target: str) -> tuple[int, ...]:
+    """Computes the mode tuple that maps a source dimension order to a target order."""
+    src = tuple(src)
+    target = tuple(target)
+    src_map = {}
+    for idx, s in enumerate(src):
+        src_map[s] = idx
+    return tuple([src_map[d] for d in target])
+
+
+def gemm_a_major(d: str):
+    """Returns order for the A tensor given its major mode."""
+    return {"k": "lmk", "m": "lkm"}[d]
+
+
+def gemm_a_mode(d: str) -> tuple[int, ...]:
+    """Returns mode for the A tensor given its major mode."""
+    return reorder_modes(gemm_a_major(d), "mkl")
+
+
+def gemm_b_major(d: str):
+    """Returns order for the B tensor given its major mode."""
+    return {"k": "lnk", "n": "lkn"}[d]
+
+
+def gemm_b_mode(d: str) -> tuple[int, ...]:
+    """Returns mode for the B tensor given its major mode."""
+    return reorder_modes(gemm_b_major(d), "nkl")
+
+
+def gemm_c_major(d: str):
+    """Returns order for the C tensor given its major mode."""
+    return {"n": "lmn", "m": "lnm"}[d]
+
+
+def gemm_c_mode(d: str) -> tuple[int, ...]:
+    """Returns mode for the C tensor given its major mode."""
+    return reorder_modes(gemm_c_major(d), "mnl")
+
+
+def gemm_a_shape(l, m, k, major) -> tuple[int, ...]:
+    """Returns shape for the A tensor given its major mode."""
+    assert major in ("k", "m")
+    shape = (l, m, k) if
major == "k" else (l, k, m) + return shape + + +def gemm_b_shape(l, n, k, major) -> tuple[int, ...]: + """Returns shape for B tensor given major mode.""" + assert major in ("k", "n") + shape = (l, n, k) if major == "k" else (l, k, n) + return shape + + +def gemm_c_shape(l, m, n, major) -> tuple[int, ...]: + """Returns shape for C tensor given major mode.""" + assert major in ("m", "n") + shape = (l, m, n) if major == "n" else (l, n, m) + return shape + + def create_tensor( shape, dtype, key, *, minval=-2.0, maxval=2.0, fill_value=None, fill_arange=False ): @@ -33,3 +94,155 @@ def create_tensor( ) tensor = tensor.astype(dtype) return tensor + + +def create_a_tensor( + l, + m, + k, + major, + dtype, + key, + minval=-2.0, + maxval=2.0, + fill_value=None, + fill_arange=False, +): + shape = gemm_a_shape(l, m, k, major) + tensor = create_tensor( + shape, + dtype, + key, + minval=minval, + maxval=maxval, + fill_value=fill_value, + fill_arange=fill_arange, + ) + return tensor + + +def create_b_tensor( + l, + n, + k, + major, + dtype, + key, + minval=-2.0, + maxval=2.0, + fill_value=None, + fill_arange=False, +): + shape = gemm_b_shape(l, n, k, major) + tensor = create_tensor( + shape, + dtype, + key, + minval=minval, + maxval=maxval, + fill_value=fill_value, + fill_arange=fill_arange, + ) + return tensor + + +def create_cd_tensor( + l, + m, + n, + major, + dtype, + key, + *, + minval=-2.0, + maxval=2.0, + fill_value=None, + fill_arange=False, +): + shape = gemm_c_shape(l, m, n, major) + tensor = create_tensor( + shape, + dtype, + key, + minval=minval, + maxval=maxval, + fill_value=fill_value, + fill_arange=fill_arange, + ) + return tensor + + +def gemm_reference_einsum( + a, + b, + acc_dtype, + c_dtype, + a_major, + b_major, + c_major, + sf_a=None, + sf_b=None, + precision="highest", +): + a_idx = gemm_a_major(a_major) + b_idx = gemm_b_major(b_major) + c_idx = gemm_c_major(c_major) + spec = f"{a_idx},{b_idx}->{c_idx}" + + # If block scaled pre-scale input at higher precision + # Assumes we only use it for fp8 and smaller. 
+    if sf_a is not None:
+        sf_vec_size = int(a.shape[-1] // sf_a.shape[-1])
+        sf_a = jnp.repeat(sf_a, sf_vec_size, axis=-1)
+        a = a.astype(jnp.float16) * sf_a.astype(jnp.float16)
+
+    if sf_b is not None:
+        sf_vec_size = int(b.shape[-1] // sf_b.shape[-1])
+        sf_b = jnp.repeat(sf_b, sf_vec_size, axis=-1)
+        b = b.astype(jnp.float16) * sf_b.astype(jnp.float16)
+
+    return jax.jit(
+        lambda a, b: jnp.einsum(
+            spec, a, b, preferred_element_type=acc_dtype, precision=precision
+        ).astype(c_dtype)
+    )(a, b)
+
+
+def create_attn_tensors(
+    b, s, hq, hkv, d, dtype, key, *, minval=-2.0, maxval=2.0, fill_value=None
+):
+    # Use a distinct key per tensor so Q, K and V are not identical.
+    qkey, kkey, vkey = jax.random.split(key, 3)
+    return (
+        create_tensor(
+            (b, s, hq, d),
+            dtype,
+            qkey,
+            minval=minval,
+            maxval=maxval,
+            fill_value=fill_value,
+        ),
+        create_tensor(
+            (b, s, hkv, d),
+            dtype,
+            kkey,
+            minval=minval,
+            maxval=maxval,
+            fill_value=fill_value,
+        ),
+        create_tensor(
+            (b, s, hkv, d),
+            dtype,
+            vkey,
+            minval=minval,
+            maxval=maxval,
+            fill_value=fill_value,
+        ),
+    )
+
+
+def attn_ref(q, k, v, is_causal: bool):
+    return jax.jit(
+        lambda q, k, v: jax.nn.dot_product_attention(
+            q, k, v, is_causal=is_causal, implementation="cudnn"
+        )
+    )(q, k, v)
diff --git a/.github/container/cutlass_dsl_jax/tests/test_args.py b/.github/container/cutlass_dsl_jax/tests/test_args.py
index a18a63d56..b0f66fbed 100644
--- a/.github/container/cutlass_dsl_jax/tests/test_args.py
+++ b/.github/container/cutlass_dsl_jax/tests/test_args.py
@@ -42,9 +42,9 @@ def kernel(
         tidx, _, _ = cute.arch.thread_idx()
         bidx, _, _ = cute.arch.block_idx()
 
-        frgA = cute.make_fragment(cute.size(x, mode=[0]), x.element_type)
-        frgB = cute.make_fragment(cute.size(y, mode=[0]), y.element_type)
-        frgC = cute.make_fragment(cute.size(z, mode=[0]), z.element_type)
+        frgA = cute.make_rmem_tensor(cute.size(x, mode=[0]), x.element_type)
+        frgB = cute.make_rmem_tensor(cute.size(y, mode=[0]), y.element_type)
+        frgC = cute.make_rmem_tensor(cute.size(z, mode=[0]), z.element_type)
 
         cute.autovec_copy(x[None, tidx, bidx], frgA)
         cute.autovec_copy(y[None, tidx, bidx], frgB)
@@ -82,8 +82,7 @@ def test(self):
             cutlass_call,
             self.launch,
             output_shape_dtype=jax.ShapeDtypeStruct(shape, dtype),
-            input_mode=(TM(static=True), TM(static=True)),
-            output_mode=TM(static=True),
+            use_static_tensors=True,
         )
         c = call(const_a=1.0, const_b=1.0)(a, b)
         c_ref = self.ref_call(a, b, 1.0, 1.0)
@@ -111,12 +110,14 @@ def kernel(
         bidx, _, _ = cute.arch.block_idx()
 
         for idx in cutlass.range_constexpr(len(b)):
-            frgA = cute.make_fragment(cute.size(a, mode=[0]), a.element_type)
+            frgA = cute.make_rmem_tensor(cute.size(a, mode=[0]), a.element_type)
             cute.autovec_copy(a[None, tidx, bidx], frgA)
-            frgB = cute.make_fragment(
+            frgB = cute.make_rmem_tensor(
                 cute.size(b[int(idx)], mode=[0]), b[idx].element_type
             )
-            frgC = cute.make_fragment(cute.size(c[idx], mode=[0]), c[idx].element_type)
+            frgC = cute.make_rmem_tensor(
+                cute.size(c[idx], mode=[0]), c[idx].element_type
+            )
             cute.autovec_copy(b[idx][None, tidx, bidx], frgB)
             frgC.store(frgA.load() + frgB.load())
             cute.autovec_copy(frgC, c[idx][None, tidx, bidx])
@@ -177,10 +178,14 @@ def kernel(
 
         # Only write to the even lists
         for idx in cutlass.range_constexpr(0, len(b), 2):
-            frgA = cute.make_fragment(cute.size(a, mode=[0]), a.element_type)
+            frgA = cute.make_rmem_tensor(cute.size(a, mode=[0]), a.element_type)
             cute.autovec_copy(a[None, tidx, bidx], frgA)
-            frgB = cute.make_fragment(cute.size(b[idx], mode=[0]), b[idx].element_type)
-            frgC = cute.make_fragment(cute.size(c[idx], mode=[0]),
c[idx].element_type) + frgB = cute.make_rmem_tensor( + cute.size(b[idx], mode=[0]), b[idx].element_type + ) + frgC = cute.make_rmem_tensor( + cute.size(c[idx], mode=[0]), c[idx].element_type + ) cute.autovec_copy(b[idx][None, tidx, bidx], frgB) frgC.store(frgA.load() + frgB.load()) cute.autovec_copy(frgC, c[idx][None, tidx, bidx]) @@ -238,3 +243,137 @@ def test(self): c_ref = self.ref_call(a, b) for ci, ci_ref in zip(c, c_ref): assert jnp.allclose(ci, ci_ref) + + +class TestPartialBoundArgs: + @cute.kernel + def kernel( + self, + x: cute.Tensor, + y: cute.Tensor, + z: cute.Tensor, + const_a: cutlass.Constexpr, + const_b: cutlass.Constexpr, + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + frgA = cute.make_rmem_tensor(cute.size(x, mode=[0]), x.element_type) + frgB = cute.make_rmem_tensor(cute.size(y, mode=[0]), y.element_type) + frgC = cute.make_rmem_tensor(cute.size(z, mode=[0]), z.element_type) + + cute.autovec_copy(x[None, tidx, bidx], frgA) + cute.autovec_copy(y[None, tidx, bidx], frgB) + frgC.store(frgA.load() * const_a + frgB.load() * const_b) + cute.autovec_copy(frgC, z[None, tidx, bidx]) + + @cute.jit + def launch( + self, + stream: cuda.CUstream, + a1: cute.Tensor, + b1: cute.Tensor, + c1: cute.Tensor, + *, + const_a: cutlass.Constexpr[float], + const_b: cutlass.Constexpr[float] + ): + self.kernel(a1, b1, c1, const_a, const_b).launch( + grid=[a1.shape[-1], 1, 1], block=[a1.shape[-2], 1, 1], stream=stream + ) + + @partial(jax.jit, static_argnums=[0, 3, 4]) + def ref_call(self, a, b, const_a, const_b): + return a * const_a + b * const_b + + def test(self): + shape = (4, 16, 16) + dtype = jnp.float32 + a_key, b_key = jax.random.split(jax.random.key(1123), 2) + + a = create_tensor(shape, dtype, a_key) + b = create_tensor(shape, dtype, b_key) + + fn = partial(self.launch, const_a=2.0) + + call = partial( + cutlass_call, + fn, + output_shape_dtype=jax.ShapeDtypeStruct(shape, dtype), + input_mode=(TM(static=True), TM(static=True)), + output_mode=TM(static=True), + const_b=-3.0, + ) + c = call()(a, b) + c_ref = self.ref_call(a, b, 2.0, -3.0) + + assert jnp.allclose(c, c_ref) + + +class TestCompileOptionsPassing: + @cute.kernel + def kernel(self, x: cute.Tensor, z: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + frgA = cute.make_rmem_tensor(cute.size(x, mode=[0]), x.element_type) + frgC = cute.make_rmem_tensor(cute.size(z, mode=[0]), z.element_type) + + cute.autovec_copy(x[None, tidx, bidx], frgA) + frgC.store(frgA.load()) + cute.autovec_copy(frgC, z[None, tidx, bidx]) + + @cute.jit + def launch( + self, + stream: cuda.CUstream, + a1: cute.Tensor, + c1: cute.Tensor, + ): + self.kernel(a1, c1).launch( + grid=[a1.shape[-1], 1, 1], block=[a1.shape[-2], 1, 1], stream=stream + ) + + def test(self): + shape = (4, 16, 16) + dtype = jnp.float32 + a_key = jax.random.key(1123) + a = create_tensor(shape, dtype, a_key) + + call = cutlass_call( + self.launch, + output_shape_dtype=jax.ShapeDtypeStruct(shape, dtype), + input_mode=TM(static=True), + output_mode=TM(static=True), + compile_options="--opt-level=0", + ) + + c = call(a) + assert jnp.allclose(c, a) + + # Combine typed and string + from cutlass.cute import ( + OptLevel, + EnableAssertions, + GenerateLineInfo, + KeepCUBIN, + KeepPTX, + ) + + my_debugging_options = ( + "--opt-level=1", + EnableAssertions, + GenerateLineInfo, + KeepCUBIN, + KeepPTX, + ) + + call = cutlass_call( + self.launch, + output_shape_dtype=jax.ShapeDtypeStruct(shape, dtype), + 
use_static_tensors=True, + compile_options=my_debugging_options, + ) + + c = call(a * 2.0) + assert jnp.allclose(c, a * 2.0) diff --git a/.github/container/cutlass_dsl_jax/tests/test_sharding.py b/.github/container/cutlass_dsl_jax/tests/test_sharding.py index d7547d5c0..e6148dd15 100644 --- a/.github/container/cutlass_dsl_jax/tests/test_sharding.py +++ b/.github/container/cutlass_dsl_jax/tests/test_sharding.py @@ -40,9 +40,9 @@ def kernel( tidx, _, _ = cute.arch.thread_idx() bidx, _, _ = cute.arch.block_idx() - frgA = cute.make_fragment(cute.size(a, mode=[0]), a.element_type) - frgB = cute.make_fragment(cute.size(b, mode=[0]), b.element_type) - frgC = cute.make_fragment(cute.size(c, mode=[0]), c.element_type) + frgA = cute.make_rmem_tensor(cute.size(a, mode=[0]), a.element_type) + frgB = cute.make_rmem_tensor(cute.size(b, mode=[0]), b.element_type) + frgC = cute.make_rmem_tensor(cute.size(c, mode=[0]), c.element_type) cute.autovec_copy(a[None, tidx, bidx], frgA) cute.autovec_copy(b[None, tidx, bidx], frgB) diff --git a/.github/container/cutlass_dsl_jax/tests/test_stream.py b/.github/container/cutlass_dsl_jax/tests/test_stream.py index 14297ed84..72ac51055 100644 --- a/.github/container/cutlass_dsl_jax/tests/test_stream.py +++ b/.github/container/cutlass_dsl_jax/tests/test_stream.py @@ -39,9 +39,9 @@ def kernel( tidx, _, _ = cute.arch.thread_idx() bidx, _, _ = cute.arch.block_idx() - frgA = cute.make_fragment(cute.size(a, mode=[0]), a.element_type) - frgB = cute.make_fragment(cute.size(b, mode=[0]), b.element_type) - frgC = cute.make_fragment(cute.size(c, mode=[0]), c.element_type) + frgA = cute.make_rmem_tensor(cute.size(a, mode=[0]), a.element_type) + frgB = cute.make_rmem_tensor(cute.size(b, mode=[0]), b.element_type) + frgC = cute.make_rmem_tensor(cute.size(c, mode=[0]), c.element_type) cute.autovec_copy(a[None, tidx, bidx], frgA) cute.autovec_copy(b[None, tidx, bidx], frgB) diff --git a/.github/eks-workflow-files/jax-cutlass/scripts/unittest.sh b/.github/eks-workflow-files/jax-cutlass/scripts/unittest.sh index 4655dfde7..cd243b10b 100644 --- a/.github/eks-workflow-files/jax-cutlass/scripts/unittest.sh +++ b/.github/eks-workflow-files/jax-cutlass/scripts/unittest.sh @@ -18,11 +18,17 @@ pip install ${PIP_SRC} + # Clone CUTLASS examples + CUTLASS_ROOT="${SRC_ROOT}/cutlass" + CUTLASS_EXAMPLES_ROOT="${CUTLASS_ROOT}/examples/python/CuTeDSL" + git clone https://github.com/NVIDIA/cutlass.git ${CUTLASS_ROOT} + NGPUS=$(nvidia-smi --list-gpus | wc -l) # Start MPS daemon nvidia-cuda-mps-control -d + export PYTHONPATH=${CUTLASS_EXAMPLES_ROOT} pytest-xdist.sh ${NGPUS} 1 ${LOG_DIR}/pytest-report.jsonl pytest -xsv --log-file=${LOG_DIR}/pytest_log.log --log-file-level=INFO ${PIP_SRC}/tests/ | tee -a ${LOG_DIR}/pytest_stdout_dist.log touch ${LOG_DIR}/done