Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a4db20b
add ltoir test support
abhilash1910 Nov 5, 2025
fb6cfb3
add options for multi-modules
abhilash1910 Nov 25, 2025
7aaed4e
add tests
abhilash1910 Nov 25, 2025
64c7f7d
add bitcode test
abhilash1910 Nov 25, 2025
42ba301
[pre-commit.ci] auto code formatting
pre-commit-ci[bot] Nov 25, 2025
7ca6899
fix format
abhilash1910 Nov 25, 2025
0674ea1
[pre-commit.ci] auto code formatting
pre-commit-ci[bot] Nov 25, 2025
03b1224
refresh
abhilash1910 Nov 26, 2025
033f11c
apply bitcode file from cupy_test helpers
abhilash1910 Dec 1, 2025
6e411ee
use 2 tuples
abhilash1910 Dec 1, 2025
b4c21db
Merge branch 'main' into nvvm_enhance
abhilash1910 Dec 2, 2025
aeb26aa
refresh
abhilash1910 Dec 3, 2025
b3d6d96
format
abhilash1910 Dec 3, 2025
edd6401
[pre-commit.ci] auto code formatting
pre-commit-ci[bot] Dec 3, 2025
d53e00b
Merge branch 'main' into nvvm_enhance
abhilash1910 Dec 7, 2025
8dbbafe
fix from upstream
abhilash1910 Dec 15, 2025
0174bb8
Merge branch 'main' into nvvm_enhance
abhilash1910 Dec 16, 2025
b78f0c3
refresh from upstream
abhilash1910 Dec 17, 2025
99a5593
fix tests
abhilash1910 Dec 17, 2025
783f6e5
take path_finder from PR 447
abhilash1910 Dec 17, 2025
5dbfb2d
add builder files
abhilash1910 Dec 17, 2025
0a9eea9
use python lists/tuples
abhilash1910 Dec 17, 2025
79138c0
libdevice integration
abhilash1910 Dec 18, 2025
25d336c
refresh
abhilash1910 Dec 19, 2025
32c1913
refresh
abhilash1910 Dec 19, 2025
01f03e5
refresh
abhilash1910 Dec 19, 2025
9a5d5fe
use cuda_pathfinder module for libdevice
abhilash1910 Dec 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ from .utils import FunctionNotFoundError, NotSupportedError

from cuda.pathfinder import load_nvidia_dynamic_lib


###############################################################################
# Extern
###############################################################################
Expand Down
1 change: 0 additions & 1 deletion cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ from .utils import FunctionNotFoundError, NotSupportedError

from cuda.pathfinder import load_nvidia_dynamic_lib


###############################################################################
# Extern
###############################################################################
Expand Down
68 changes: 66 additions & 2 deletions cuda_core/cuda/core/experimental/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ def _get_nvvm_module():
_nvvm_module = None
raise e

def _find_libdevice_path():
"""Find libdevice.10.bc for NVVM compilation using cuda.pathfinder."""
from cuda.pathfinder import get_libdevice_path
return get_libdevice_path()

def _process_define_macro_inner(formatted_options, macro):
if isinstance(macro, str):
Expand Down Expand Up @@ -335,6 +339,10 @@ class ProgramOptions:
split_compile: int | None = None
fdevice_syntax_only: bool | None = None
minimal: bool | None = None
# Creating as 2 tuples ((names, source), (names,source))
extra_sources: (
Union[List[Tuple[str, Union[str, bytes, bytearray]]], Tuple[Tuple[str, Union[str, bytes, bytearray]]]] | None
) = None
no_cache: bool | None = None
fdevice_time_trace: str | None = None
device_float128: bool | None = None
Expand All @@ -348,6 +356,7 @@ class ProgramOptions:
pch_messages: bool | None = None
instantiate_templates_in_pch: bool | None = None
numba_debug: bool | None = None # Custom option for Numba debugging
use_libdevice: bool | None = None # Use libdevice

def __post_init__(self):
self._name = self.name.encode()
Expand Down Expand Up @@ -669,26 +678,33 @@ def close(self):
nvvm.destroy_program(self.handle)
self.handle = None

__slots__ = ("__weakref__", "_mnff", "_backend", "_linker", "_options")
__slots__ = ("__weakref__", "_mnff", "_backend", "_linker", "_options", "_module_count")

def __init__(self, code, code_type, options: ProgramOptions = None):
self._mnff = Program._MembersNeededForFinalize(self, None, None)

self._options = options = check_or_create_options(ProgramOptions, options, "Program options")
code_type = code_type.lower()
self._module_count = 0

if code_type == "c++":
assert_type(code, str)
# TODO: support pre-loaded headers & include names
# TODO: allow tuples once NVIDIA/cuda-python#72 is resolved

if options.extra_sources is not None:
raise ValueError("extra_sources is not supported by the NVRTC backend (C++ code_type)")

# TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), options._name, 0, [], []))
self._mnff.backend = "NVRTC"
self._backend = "NVRTC"
self._linker = None

elif code_type == "ptx":
assert_type(code, str)
if options.extra_sources is not None:
raise ValueError("extra_sources is not supported by the PTX backend.")

self._linker = Linker(
ObjectCode._init(code.encode(), code_type), options=self._translate_program_options(options)
)
Expand All @@ -704,6 +720,41 @@ def __init__(self, code, code_type, options: ProgramOptions = None):
self._mnff.handle = nvvm.create_program()
self._mnff.backend = "NVVM"
nvvm.add_module_to_program(self._mnff.handle, code, len(code), options._name.decode())
self._module_count = 1
# Add extra modules if provided
if options.extra_sources is not None:
if not is_sequence(options.extra_sources):
raise TypeError(
"extra_modules must be a sequence of 2-tuples:((name1, source1), (name2, source2), ...)"
)
for i, module in enumerate(options.extra_sources):
if not isinstance(module, tuple) or len(module) != 2:
raise TypeError(
f"Each extra module must be a 2-tuple (name, source)"
f", got {type(module).__name__} at index {i}"
)

module_name, module_source = module

if not isinstance(module_name, str):
raise TypeError(f"Module name at index {i} must be a string,got {type(module_name).__name__}")

if isinstance(module_source, str):
# Textual LLVM IR - encode to UTF-8 bytes
module_source = module_source.encode("utf-8")
elif not isinstance(module_source, (bytes, bytearray)):
raise TypeError(
f"Module source at index {i} must be str (textual LLVM IR), bytes (textual LLVM IR or bitcode), "
f"or bytearray, got {type(module_source).__name__}"
)

if len(module_source) == 0:
raise ValueError(f"Module source for '{module_name}' (index {i}) cannot be empty")

nvvm.add_module_to_program(self._mnff.handle, module_source, len(module_source), module_name)
self._module_count += 1

self._use_libdevice = options.use_libdevice
self._backend = "NVVM"
self._linker = None

Expand Down Expand Up @@ -821,6 +872,19 @@ def compile(self, target_type, name_expressions=(), logs=None):
nvvm = _get_nvvm_module()
with _nvvm_exception_manager(self):
nvvm.verify_program(self._mnff.handle, len(nvvm_options), nvvm_options)
# Invoke libdevice
if getattr(self, '_use_libdevice', False):
libdevice_path = _find_libdevice_path()
if libdevice_path is None:
raise RuntimeError(
"use_libdevice=True but could not find libdevice.10.bc. "
"Ensure CUDA toolkit is installed."
)
with open(libdevice_path, "rb") as f:
libdevice_bc = f.read()
# Use lazy_add_module for libdevice bitcode only following numba-cuda
nvvm.lazy_add_module_to_program(self._mnff.handle, libdevice_bc, len(libdevice_bc), None)

nvvm.compile_program(self._mnff.handle, len(nvvm_options), nvvm_options)

size = nvvm.get_compiled_result_size(self._mnff.handle)
Expand Down
198 changes: 194 additions & 4 deletions cuda_core/tests/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
cuda_driver_version = handle_return(driver.cuDriverGetVersion())
is_culink_backend = _linker._decide_nvjitlink_or_driver()

try:
from cuda_python_test_helpers.nvvm_bitcode import minimal_nvvmir

_test_helpers_available = True
except ImportError:
_test_helpers_available = False


def _is_nvvm_available():
"""Check if NVVM is available."""
Expand All @@ -32,7 +39,7 @@ def _is_nvvm_available():
)

try:
from cuda.core.experimental._utils.cuda_utils import driver, handle_return, nvrtc
from cuda.core.experimental._utils.cuda_utils import driver, handle_return

_cuda_driver_version = handle_return(driver.cuDriverGetVersion())
except Exception:
Expand All @@ -42,7 +49,6 @@ def _is_nvvm_available():
def _get_nvrtc_version_for_tests():
"""
Get NVRTC version.

Returns:
int: Version in format major * 1000 + minor * 100 (e.g., 13200 for CUDA 13.2)
None: If NVRTC is not available
Expand All @@ -54,7 +60,6 @@ def _get_nvrtc_version_for_tests():
except Exception:
return None


_libnvvm_version = None
_libnvvm_version_attempted = False

Expand Down Expand Up @@ -499,6 +504,191 @@ def test_nvvm_program_options(init_cuda, nvvm_ir, options):
program.close()


@nvvm_available
@pytest.mark.parametrize(
"options",
[
ProgramOptions(name="ltoir_test1", arch="sm_90", device_code_optimize=False),
ProgramOptions(name="ltoir_test2", arch="sm_100", link_time_optimization=True),
ProgramOptions(
name="ltoir_test3",
arch="sm_90",
ftz=True,
prec_sqrt=False,
prec_div=False,
fma=True,
device_code_optimize=True,
link_time_optimization=True,
),
],
)
def test_nvvm_program_options_ltoir(init_cuda, nvvm_ir, options):
"""Test NVVM programs for LTOIR with different options"""
program = Program(nvvm_ir, "nvvm", options)
assert program.backend == "NVVM"

ltoir_code = program.compile("ltoir")
assert isinstance(ltoir_code, ObjectCode)
assert ltoir_code.name == options.name
code_content = ltoir_code.code
assert len(code_content) > 0
program.close()


@nvvm_available
def test_nvvm_program_with_single_extra_source(nvvm_ir):
"""Test NVVM program with a single extra source"""
from cuda.core.experimental._program import _get_nvvm_module

nvvm = _get_nvvm_module()
major, minor, debug_major, debug_minor = nvvm.ir_version()
# helper nvvm ir for multiple module loading
helper_nvvmir = f"""target triple = "nvptx64-unknown-cuda"
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"

define i32 @helper_add(i32 %x) {{
entry:
%result = add i32 %x, 1
ret i32 %result
}}

!nvvmir.version = !{{!0}}
!0 = !{{i32 {major}, i32 {minor}, i32 {debug_major}, i32 {debug_minor}}}
""" # noqa: E501

options = ProgramOptions(
name="multi_module_test",
extra_sources=[
("helper", helper_nvvmir),
],
)
program = Program(nvvm_ir, "nvvm", options)

assert program.backend == "NVVM"

ptx_code = program.compile("ptx")
assert isinstance(ptx_code, ObjectCode)
assert ptx_code.name == "multi_module_test"

program.close()


@nvvm_available
def test_nvvm_program_with_multiple_extra_sources():
"""Test NVVM program with multiple extra sources"""
from cuda.core.experimental._program import _get_nvvm_module

nvvm = _get_nvvm_module()
major, minor, debug_major, debug_minor = nvvm.ir_version()

main_nvvm_ir = f"""target triple = "nvptx64-unknown-cuda"
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"

declare i32 @helper_add(i32) nounwind readnone
declare i32 @helper_mul(i32) nounwind readnone

define void @main_kernel(i32* %data) {{
entry:
%tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%ptr = getelementptr inbounds i32, i32* %data, i32 %tid
%val = load i32, i32* %ptr, align 4

%val1 = call i32 @helper_add(i32 %val)
%val2 = call i32 @helper_mul(i32 %val1)

store i32 %val2, i32* %ptr, align 4
ret void
}}

declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() nounwind readnone

!nvvm.annotations = !{{!0}}
!0 = !{{void (i32*)* @main_kernel, !"kernel", i32 1}}

!nvvmir.version = !{{!1}}
!1 = !{{i32 {major}, i32 {minor}, i32 {debug_major}, i32 {debug_minor}}}
""" # noqa: E501

helper1_ir = f"""target triple = "nvptx64-unknown-cuda"
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"

define i32 @helper_add(i32 %x) nounwind readnone {{
entry:
%result = add i32 %x, 1
ret i32 %result
}}

!nvvmir.version = !{{!0}}
!0 = !{{i32 {major}, i32 {minor}, i32 {debug_major}, i32 {debug_minor}}}
""" # noqa: E501

helper2_ir = f"""target triple = "nvptx64-unknown-cuda"
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"

define i32 @helper_mul(i32 %x) nounwind readnone {{
entry:
%result = mul i32 %x, 2
ret i32 %result
}}

!nvvmir.version = !{{!0}}
!0 = !{{i32 {major}, i32 {minor}, i32 {debug_major}, i32 {debug_minor}}}
""" # noqa: E501

options = ProgramOptions(
name="nvvm_multi_helper_test",
extra_sources=[
("helper1", helper1_ir),
("helper2", helper2_ir),
],
)
program = Program(main_nvvm_ir, "nvvm", options)

assert program.backend == "NVVM"
ptx_code = program.compile("ptx")
assert isinstance(ptx_code, ObjectCode)
assert ptx_code.name == "nvvm_multi_helper_test"

ltoir_code = program.compile("ltoir")
assert isinstance(ltoir_code, ObjectCode)
assert ltoir_code.name == "nvvm_multi_helper_test"

program.close()


@nvvm_available
@pytest.mark.skipif(not _test_helpers_available, reason="cuda_python_test_helpers not accessible")
def test_bitcode_format(minimal_nvvmir):
if len(minimal_nvvmir) < 4:
pytest.skip("Bitcode file is not valid or empty")

options = ProgramOptions(name="minimal_nvvmir_bitcode_test", arch="sm_90")
program = Program(minimal_nvvmir, "nvvm", options)

assert program.backend == "NVVM"
ptx_result = program.compile("ptx")
assert isinstance(ptx_result, ObjectCode)
assert ptx_result.name == "minimal_nvvmir_bitcode_test"
assert len(ptx_result.code) > 0
program_lto = Program(minimal_nvvmir, "nvvm", options)
try:
ltoir_result = program_lto.compile("ltoir")
assert isinstance(ltoir_result, ObjectCode)
assert len(ltoir_result.code) > 0
print(f"LTOIR size: {len(ltoir_result.code)} bytes")
except Exception as e:
print(f"LTOIR compilation failed : {e}")
finally:
program.close()


def test_cpp_program_with_extra_sources():
# negative test with NVRTC with multiple sources
code = 'extern "C" __global__ void my_kernel(){}'
helper = 'extern "C" __global__ void helper(){}'
options = ProgramOptions(extra_sources=helper)
with pytest.raises(ValueError, match="extra_sources is not supported by the NVRTC backend"):
Program(code, "c++", options)
def test_program_options_as_bytes_nvrtc():
"""Test ProgramOptions.as_bytes() for NVRTC backend"""
options = ProgramOptions(arch="sm_80", debug=True, lineinfo=True, ftz=True)
Expand Down Expand Up @@ -546,4 +736,4 @@ def test_program_options_as_bytes_nvvm_unsupported_option():
"""Test that unsupported options raise CUDAError for NVVM backend"""
options = ProgramOptions(arch="sm_80", lineinfo=True)
with pytest.raises(CUDAError, match="not supported by NVVM backend"):
options.as_bytes("nvvm")
options.as_bytes("nvvm")
5 changes: 5 additions & 0 deletions cuda_pathfinder/cuda/pathfinder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from cuda.pathfinder._dynamic_libs.load_dl_common import DynamicLibNotFoundError as DynamicLibNotFoundError
from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL as LoadedDL
from cuda.pathfinder._dynamic_libs.load_nvidia_dynamic_lib import load_nvidia_dynamic_lib as load_nvidia_dynamic_lib
from cuda.pathfinder._dynamic_libs.find_libdevice import (
LibdeviceNotFoundError as LibdeviceNotFoundError,
find_libdevice as find_libdevice,
get_libdevice_path as get_libdevice_path,
)
from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import (
SUPPORTED_LIBNAMES as SUPPORTED_NVIDIA_LIBNAMES, # noqa: F401
)
Expand Down
Loading