2 changes: 1 addition & 1 deletion .github/container/cutlass_dsl_jax/pyproject.toml
@@ -5,7 +5,7 @@ readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"jax>=0.6.2",
"nvidia-cutlass-dsl>=4.2.1"
"nvidia-cutlass-dsl>=4.3.1"
]
dynamic = ["version"]

33 changes: 20 additions & 13 deletions .github/container/cutlass_dsl_jax/src/jax_cutlass/compile.py
@@ -16,7 +16,7 @@
import gc
import ctypes
import inspect
from typing import Any, Callable
from typing import Any, Callable, Optional
from dataclasses import dataclass
from functools import partial
from pathlib import Path
@@ -77,6 +77,8 @@ class FunctionSpec:
output_layout: tuple[tuple[int, ...]]
output_mode: tuple[TensorMode, ...]
convert_tensors: bool
compile_options: str | None
use_static_tensors: bool
kwargs: tuple[tuple[str, Any]]

def get_compile_args(self):
@@ -121,8 +123,14 @@ def jit_wrapper(
# split buffer argument into inputs and outputs and return to tree
ins, outs = args[: len(spec.in_args)], args[(len(spec.in_args)) :]
if cutlass.const_expr(spec.convert_tensors):
ins = [x.get_tensor(a.mode) for x, a in zip(ins, spec.in_args)]
outs = [x.get_tensor(a.mode) for x, a in zip(outs, spec.out_args)]
ins = [
x.get_tensor(a.mode, spec.use_static_tensors)
for x, a in zip(ins, spec.in_args)
]
outs = [
x.get_tensor(a.mode, spec.use_static_tensors)
for x, a in zip(outs, spec.out_args)
]
ins = jax.tree.unflatten(spec.input_tree, ins)
outs = jax.tree.unflatten(spec.output_tree, outs)
wrapped_fn(stream, *ins, *outs, **dict(spec.kwargs))
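As an aside on the wrapper above: the flat buffer list handed over by XLA is split at len(spec.in_args) and then rebuilt into the caller's pytrees with jax.tree.unflatten. A minimal, self-contained sketch of that round trip, using only jax.tree utilities (the inputs/outputs values are illustrative, not from this PR):

import jax

inputs = {"a": 1, "b": (2, 3)}   # example input pytree
outputs = [4, 5]                 # example output pytree

in_flat, in_tree = jax.tree.flatten(inputs)
out_flat, out_tree = jax.tree.flatten(outputs)
args = tuple(in_flat) + tuple(out_flat)  # flat buffer list, inputs first

# Split at the input count, then restore the original tree structures.
ins, outs = args[: len(in_flat)], args[len(in_flat):]
assert jax.tree.unflatten(in_tree, ins) == inputs
assert jax.tree.unflatten(out_tree, outs) == outputs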
@@ -176,6 +184,8 @@ def build_function_spec(
output_mode,
input_output_aliases,
convert_tensors,
compile_options,
use_static_tensors,
kwargs,
):
# TODO: Improve type checking and validate pytree structures.
@@ -233,6 +243,8 @@
tuple(output_layout),
tuple(output_mode),
convert_tensors,
compile_options,
use_static_tensors,
tuple((k, kwargs[k]) for k in kwargs),
)

@@ -260,7 +272,11 @@ def get_or_compile_kernel(fn, spec, stream):
with _compile_lock:
start = time.time()
try:
compiled_fn = cutlass.cute.compile(
cute_compile = cutlass.cute.compile
if spec.compile_options:
cute_compile = partial(cute_compile, options=spec.compile_options)

compiled_fn = cute_compile(
jit_wrapper,
cuda.CUstream(stream),
spec.get_compile_args(),
@@ -313,15 +329,6 @@ def initialize_cutlass_dsl():
if _CUTLASS_DSL_INITIALIZED:
return

# TODO(mgoldfarb-nvidia): There are several runtime libraries that export C++ symbols
# which conflict with jax libraries. Initializing cutlass before jax will cause these
# symbols to incorrectly interpose. Our WAR is to force loading of jaxlib and its
# dependent libraries to ensure all symbols are loaded prior to compiling cutedsl programs.
# This linking issue is planned to be resolved in cute DSL 4.3.
jaxlib_common = Path(jaxlib.__file__).parent / "libjax_common.so"
if jaxlib_common.exists():
ctypes.CDLL(str(jaxlib_common), mode=ctypes.RTLD_GLOBAL)

kernel = _DummyInitKernel()
with _compile_lock:
logger.debug("Initializing cutlass dsl...")
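For reference, the change to get_or_compile_kernel above pre-binds the options keyword onto the compile entry point with functools.partial only when compile options are present, so the call site stays unchanged either way. A minimal sketch of that pattern, using a stand-in compile function rather than the real cutlass.cute.compile (the option string is hypothetical):

from functools import partial

def fake_compile(fn, *args, options=None):
    # Stand-in for cutlass.cute.compile; only here to show the partial() pattern.
    print(f"compiling {fn.__name__} with options={options!r}")
    return fn

def get_or_compile(fn, compile_options=None):
    compile_fn = fake_compile
    if compile_options:
        # Bind the options once; every later call picks them up implicitly.
        compile_fn = partial(compile_fn, options=compile_options)
    return compile_fn(fn)

def kernel():
    pass

get_or_compile(kernel)                  # options=None
get_or_compile(kernel, "--some-flag")   # hypothetical option string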
21 changes: 20 additions & 1 deletion .github/container/cutlass_dsl_jax/src/jax_cutlass/primitive.py
@@ -57,6 +57,8 @@ def cutlass_call(
input_output_aliases={},
convert_tensors=True,
allow_cuda_graph=True,
compile_options=None,
use_static_tensors=False,
**kwargs,
):
"""Creates a callable that invokes a @cute.jit function.
@@ -77,6 +79,11 @@ def cutlass_call(
convert_tensors: Jax array buffers will be converted to cute.Tensor with static shape and
layout. If disabled the kernel is instead given a JaxArray pointer directly.
allow_cuda_graph: If false, prevents XLA from building a cuda graph for this call.
compile_options: Optional compiler arguments to pass into cute.compile.
use_static_tensors: If True, tensor shapes and strides are treated as constexpr values by
default. This can improve performance through compiler specialization but may not work
properly with all kernels. Individual tensors may be marked static or dynamic via their
TensorMode, which overrides this flag.
kwargs: Optional constexpr parameters to pass into the kernel fn.

Note: This API is experimental and subject to change!
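A hedged usage sketch of the two new parameters documented above. It assumes cutlass_call is importable from the jax_cutlass package, that output_shape_dtype and the constexpr kwarg shown are accepted as keywords, and that the @cute.jit kernel stub matches the (stream, *inputs, *outputs, **constexpr kwargs) calling convention used by the wrapper; the kernel body and the option string are purely illustrative:

import jax
import jax.numpy as jnp
import cutlass
import cutlass.cute as cute

from jax_cutlass import cutlass_call  # assumed import path for this package


@cute.jit
def scale_kernel(stream, x, out, scale: cutlass.Constexpr):
    ...  # kernel body elided; called as (stream, *inputs, *outputs, **constexpr kwargs)


call = cutlass_call(
    scale_kernel,
    output_shape_dtype=jax.ShapeDtypeStruct((128, 128), jnp.float32),
    compile_options="--some-compiler-flag",  # hypothetical option string forwarded to cute.compile
    use_static_tensors=True,                 # specialize on shapes/strides by default
    scale=2.0,                               # constexpr kwarg forwarded to the kernel
)
out = jax.jit(call)(jnp.ones((128, 128), jnp.float32))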
@@ -94,6 +101,8 @@ def cutlass_call(
input_output_aliases=input_output_aliases,
convert_tensors=convert_tensors,
allow_cuda_graph=allow_cuda_graph,
compile_options=compile_options,
use_static_tensors=use_static_tensors,
**kwargs,
)

@@ -127,6 +136,8 @@ def _cutlass_call_impl(
input_output_aliases,
convert_tensors,
allow_cuda_graph,
compile_options,
use_static_tensors,
**kwargs,
):
multiple_results = isinstance(output_shape_dtype, Sequence)
@@ -200,7 +211,9 @@ def call_wrapper(*args):
# information we got as input.
for idx, (arg, mode) in enumerate(zip(args_flat, input_mode_flat)):
if mode.mode is not None and len(mode.mode) != len(arg.shape):
raise ValueError(f"Input #{idx} has invalid mode.")
raise ValueError(
f"Input #{idx} has invalid mode {mode.mode} for shape {arg.shape}."
)
for idx, (arg, mode) in enumerate(
zip(output_shape_dtype_flat, output_mode_flat)
):
@@ -220,6 +233,8 @@ def call_wrapper(*args):
input_output_aliases=tuple(input_output_aliases.items()),
convert_tensors=convert_tensors,
allow_cuda_graph=allow_cuda_graph,
compile_options=compile_options,
use_static_tensors=use_static_tensors,
**kwargs,
)

@@ -247,6 +262,8 @@ def cutlass_call_inner_p_impl(
input_output_aliases,
convert_tensors,
allow_cuda_graph,
compile_options,
use_static_tensors,
**kwargs,
):
input_output_aliases = dict(input_output_aliases)
@@ -266,6 +283,8 @@ def cutlass_call_inner_p_impl(
output_mode_flat,
input_output_aliases,
convert_tensors,
compile_options,
use_static_tensors,
kwargs,
)

31 changes: 22 additions & 9 deletions .github/container/cutlass_dsl_jax/src/jax_cutlass/types.py
@@ -22,7 +22,6 @@
from operator import mul
from itertools import chain
from typing import Annotated
from enum import Enum

import cuda.bindings.driver as cuda

@@ -71,11 +70,12 @@ class TensorMode:
Arguments:
mode : Specifies the position of each mode in the tensor (M0, M1, ... MN)
"""

mode: tuple[int, ...] | None = field(metadata=dict(static=True), default=None)
# Indicates the shape and strides will be defined statically. Enabling this may allow
# additional optimization. Kernels that do not support static shapes will generate
# compile errors if this is enabled, so we leave it off by default.
static: bool = field(metadata=dict(static=True), default=False)
static: bool | None = field(metadata=dict(static=True), default=None)
# Overrides the default pointer alignment. Generally this should not be changed
# but is left here to provide a hook.
ptr_assumed_align: int = field(
Expand Down Expand Up @@ -243,7 +243,12 @@ def align(self, min_align: int, *, loc=None, ip=None) -> "JaxArray":
return JaxArray(self.ptr.align(min_align, loc, ip), self._shape, self._order)

def get_layout(
self, mode: tuple[int, ...] | TensorMode = None, *, loc=None, ip=None
self,
mode: tuple[int, ...] | TensorMode = None,
use_static_tensors: bool = False,
*,
loc=None,
ip=None,
) -> cute.Layout:
"""Create a cute.Layout from this JaxArray.

@@ -256,26 +261,34 @@ def get_layout(
:type tuple[int,...]: Tuple that is the same size as shape.
"""
if isinstance(mode, (tuple, list)):
mode = TensorMode(mode)
mode = TensorMode(mode, static=use_static_tensors)

if (mode.static is None and use_static_tensors) or mode.static:
shape = self._shape
else:
shape = [cutlass.as_numeric(m) for m in self._shape]

shape = (
self._shape if mode.static else [cutlass.as_numeric(m) for m in self._shape]
)
layout = cute.make_ordered_layout(tuple(shape), self._order, loc=loc, ip=ip)
if mode is not None and mode.mode is not None:
layout = cute.select(layout, mode.mode)
return layout

def get_tensor(
self, mode: tuple[int, ...] | TensorMode = None, *, loc=None, ip=None
self,
mode: tuple[int, ...] | TensorMode = None,
use_static_tensors: bool = False,
*,
loc=None,
ip=None,
) -> cute.Tensor:
"""Create a cute.Tensor from this JaxArray.

:param mode: Maps the physical shape dimension to logical shape dimensions. If not given, the physical layout is used.
:param use_static_tensors: Defaults the tensor shape and stride to static when the mode does not specify otherwise.
:type tuple[int,...]: Tuple that is the same size as shape.
:see get_layout
"""
layout = self.get_layout(mode, loc=loc, ip=ip)
layout = self.get_layout(mode, use_static_tensors, loc=loc, ip=ip)
return cute.make_tensor(self.ptr, layout)

# Utility methods
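The new get_layout logic above decides whether a tensor's shape is treated as static by combining the per-tensor TensorMode.static flag with the call-level use_static_tensors default. A plain-Python sketch of just that precedence rule (no cutlass imports; the function name is illustrative):

def resolve_static(mode_static: bool | None, use_static_tensors: bool) -> bool:
    # Mirrors `(mode.static is None and use_static_tensors) or mode.static`:
    # an explicit True/False on the tensor wins; None defers to the call-level flag.
    return use_static_tensors if mode_static is None else bool(mode_static)

assert resolve_static(None, False) is False  # nothing requested -> dynamic shape
assert resolve_static(None, True) is True    # call-level default applies
assert resolve_static(False, True) is False  # per-tensor override beats the default
assert resolve_static(True, False) is True   # per-tensor static stays static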
@@ -13,5 +13,5 @@
# limitations under the License.


__version_info__ = (0, 2, 0)
__version_info__ = (0, 3, 0)
__version__ = ".".join(str(v) for v in __version_info__)
13 changes: 13 additions & 0 deletions .github/container/cutlass_dsl_jax/tests/blackwell/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.