diff --git a/cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx b/cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx index 8c96b6d640..8025189b34 100644 --- a/cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx +++ b/cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx @@ -11,7 +11,6 @@ from .utils import FunctionNotFoundError, NotSupportedError from cuda.pathfinder import load_nvidia_dynamic_lib - ############################################################################### # Extern ############################################################################### diff --git a/cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx b/cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx index 408a2cb592..5411e157c7 100644 --- a/cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx +++ b/cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx @@ -11,7 +11,6 @@ from .utils import FunctionNotFoundError, NotSupportedError from cuda.pathfinder import load_nvidia_dynamic_lib - ############################################################################### # Extern ############################################################################### diff --git a/cuda_core/cuda/core/experimental/_program.py b/cuda_core/cuda/core/experimental/_program.py index f3ad9af644..fedafbc4fb 100644 --- a/cuda_core/cuda/core/experimental/_program.py +++ b/cuda_core/cuda/core/experimental/_program.py @@ -102,6 +102,10 @@ def _get_nvvm_module(): _nvvm_module = None raise e +def _find_libdevice_path(): + """Find libdevice.10.bc for NVVM compilation using cuda.pathfinder.""" + from cuda.pathfinder import get_libdevice_path + return get_libdevice_path() def _process_define_macro_inner(formatted_options, macro): if isinstance(macro, str): @@ -335,6 +339,10 @@ class ProgramOptions: split_compile: int | None = None fdevice_syntax_only: bool | None = None minimal: bool | None = None + # Creating as 2 tuples ((names, source), (names,source)) + extra_sources: ( + Union[List[Tuple[str, Union[str, bytes, bytearray]]], Tuple[Tuple[str, Union[str, bytes, bytearray]]]] | None + ) = None no_cache: bool | None = None fdevice_time_trace: str | None = None device_float128: bool | None = None @@ -348,6 +356,7 @@ class ProgramOptions: pch_messages: bool | None = None instantiate_templates_in_pch: bool | None = None numba_debug: bool | None = None # Custom option for Numba debugging + use_libdevice: bool | None = None # Use libdevice def __post_init__(self): self._name = self.name.encode() @@ -669,19 +678,23 @@ def close(self): nvvm.destroy_program(self.handle) self.handle = None - __slots__ = ("__weakref__", "_mnff", "_backend", "_linker", "_options") + __slots__ = ("__weakref__", "_mnff", "_backend", "_linker", "_options", "_module_count") def __init__(self, code, code_type, options: ProgramOptions = None): self._mnff = Program._MembersNeededForFinalize(self, None, None) self._options = options = check_or_create_options(ProgramOptions, options, "Program options") code_type = code_type.lower() + self._module_count = 0 if code_type == "c++": assert_type(code, str) # TODO: support pre-loaded headers & include names - # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved + if options.extra_sources is not None: + raise ValueError("extra_sources is not supported by the NVRTC backend (C++ code_type)") + + # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved self._mnff.handle = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), options._name, 0, [], [])) self._mnff.backend = "NVRTC" self._backend = "NVRTC" @@ -689,6 +702,9 @@ def __init__(self, code, code_type, options: ProgramOptions = None): elif code_type == "ptx": assert_type(code, str) + if options.extra_sources is not None: + raise ValueError("extra_sources is not supported by the PTX backend.") + self._linker = Linker( ObjectCode._init(code.encode(), code_type), options=self._translate_program_options(options) ) @@ -704,6 +720,41 @@ def __init__(self, code, code_type, options: ProgramOptions = None): self._mnff.handle = nvvm.create_program() self._mnff.backend = "NVVM" nvvm.add_module_to_program(self._mnff.handle, code, len(code), options._name.decode()) + self._module_count = 1 + # Add extra modules if provided + if options.extra_sources is not None: + if not is_sequence(options.extra_sources): + raise TypeError( + "extra_modules must be a sequence of 2-tuples:((name1, source1), (name2, source2), ...)" + ) + for i, module in enumerate(options.extra_sources): + if not isinstance(module, tuple) or len(module) != 2: + raise TypeError( + f"Each extra module must be a 2-tuple (name, source)" + f", got {type(module).__name__} at index {i}" + ) + + module_name, module_source = module + + if not isinstance(module_name, str): + raise TypeError(f"Module name at index {i} must be a string,got {type(module_name).__name__}") + + if isinstance(module_source, str): + # Textual LLVM IR - encode to UTF-8 bytes + module_source = module_source.encode("utf-8") + elif not isinstance(module_source, (bytes, bytearray)): + raise TypeError( + f"Module source at index {i} must be str (textual LLVM IR), bytes (textual LLVM IR or bitcode), " + f"or bytearray, got {type(module_source).__name__}" + ) + + if len(module_source) == 0: + raise ValueError(f"Module source for '{module_name}' (index {i}) cannot be empty") + + nvvm.add_module_to_program(self._mnff.handle, module_source, len(module_source), module_name) + self._module_count += 1 + + self._use_libdevice = options.use_libdevice self._backend = "NVVM" self._linker = None @@ -821,6 +872,19 @@ def compile(self, target_type, name_expressions=(), logs=None): nvvm = _get_nvvm_module() with _nvvm_exception_manager(self): nvvm.verify_program(self._mnff.handle, len(nvvm_options), nvvm_options) + # Invoke libdevice + if getattr(self, '_use_libdevice', False): + libdevice_path = _find_libdevice_path() + if libdevice_path is None: + raise RuntimeError( + "use_libdevice=True but could not find libdevice.10.bc. " + "Ensure CUDA toolkit is installed." + ) + with open(libdevice_path, "rb") as f: + libdevice_bc = f.read() + # Use lazy_add_module for libdevice bitcode only following numba-cuda + nvvm.lazy_add_module_to_program(self._mnff.handle, libdevice_bc, len(libdevice_bc), None) + nvvm.compile_program(self._mnff.handle, len(nvvm_options), nvvm_options) size = nvvm.get_compiled_result_size(self._mnff.handle) diff --git a/cuda_core/tests/test_program.py b/cuda_core/tests/test_program.py index f432e3f88d..00d0709870 100644 --- a/cuda_core/tests/test_program.py +++ b/cuda_core/tests/test_program.py @@ -15,6 +15,13 @@ cuda_driver_version = handle_return(driver.cuDriverGetVersion()) is_culink_backend = _linker._decide_nvjitlink_or_driver() +try: + from cuda_python_test_helpers.nvvm_bitcode import minimal_nvvmir + + _test_helpers_available = True +except ImportError: + _test_helpers_available = False + def _is_nvvm_available(): """Check if NVVM is available.""" @@ -32,7 +39,7 @@ def _is_nvvm_available(): ) try: - from cuda.core.experimental._utils.cuda_utils import driver, handle_return, nvrtc + from cuda.core.experimental._utils.cuda_utils import driver, handle_return _cuda_driver_version = handle_return(driver.cuDriverGetVersion()) except Exception: @@ -42,7 +49,6 @@ def _is_nvvm_available(): def _get_nvrtc_version_for_tests(): """ Get NVRTC version. - Returns: int: Version in format major * 1000 + minor * 100 (e.g., 13200 for CUDA 13.2) None: If NVRTC is not available @@ -54,7 +60,6 @@ def _get_nvrtc_version_for_tests(): except Exception: return None - _libnvvm_version = None _libnvvm_version_attempted = False @@ -499,6 +504,191 @@ def test_nvvm_program_options(init_cuda, nvvm_ir, options): program.close() +@nvvm_available +@pytest.mark.parametrize( + "options", + [ + ProgramOptions(name="ltoir_test1", arch="sm_90", device_code_optimize=False), + ProgramOptions(name="ltoir_test2", arch="sm_100", link_time_optimization=True), + ProgramOptions( + name="ltoir_test3", + arch="sm_90", + ftz=True, + prec_sqrt=False, + prec_div=False, + fma=True, + device_code_optimize=True, + link_time_optimization=True, + ), + ], +) +def test_nvvm_program_options_ltoir(init_cuda, nvvm_ir, options): + """Test NVVM programs for LTOIR with different options""" + program = Program(nvvm_ir, "nvvm", options) + assert program.backend == "NVVM" + + ltoir_code = program.compile("ltoir") + assert isinstance(ltoir_code, ObjectCode) + assert ltoir_code.name == options.name + code_content = ltoir_code.code + assert len(code_content) > 0 + program.close() + + +@nvvm_available +def test_nvvm_program_with_single_extra_source(nvvm_ir): + """Test NVVM program with a single extra source""" + from cuda.core.experimental._program import _get_nvvm_module + + nvvm = _get_nvvm_module() + major, minor, debug_major, debug_minor = nvvm.ir_version() + # helper nvvm ir for multiple module loading + helper_nvvmir = f"""target triple = "nvptx64-unknown-cuda" +target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64" + +define i32 @helper_add(i32 %x) {{ +entry: + %result = add i32 %x, 1 + ret i32 %result +}} + +!nvvmir.version = !{{!0}} +!0 = !{{i32 {major}, i32 {minor}, i32 {debug_major}, i32 {debug_minor}}} +""" # noqa: E501 + + options = ProgramOptions( + name="multi_module_test", + extra_sources=[ + ("helper", helper_nvvmir), + ], + ) + program = Program(nvvm_ir, "nvvm", options) + + assert program.backend == "NVVM" + + ptx_code = program.compile("ptx") + assert isinstance(ptx_code, ObjectCode) + assert ptx_code.name == "multi_module_test" + + program.close() + + +@nvvm_available +def test_nvvm_program_with_multiple_extra_sources(): + """Test NVVM program with multiple extra sources""" + from cuda.core.experimental._program import _get_nvvm_module + + nvvm = _get_nvvm_module() + major, minor, debug_major, debug_minor = nvvm.ir_version() + + main_nvvm_ir = f"""target triple = "nvptx64-unknown-cuda" +target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64" + +declare i32 @helper_add(i32) nounwind readnone +declare i32 @helper_mul(i32) nounwind readnone + +define void @main_kernel(i32* %data) {{ +entry: + %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() + %ptr = getelementptr inbounds i32, i32* %data, i32 %tid + %val = load i32, i32* %ptr, align 4 + + %val1 = call i32 @helper_add(i32 %val) + %val2 = call i32 @helper_mul(i32 %val1) + + store i32 %val2, i32* %ptr, align 4 + ret void +}} + +declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() nounwind readnone + +!nvvm.annotations = !{{!0}} +!0 = !{{void (i32*)* @main_kernel, !"kernel", i32 1}} + +!nvvmir.version = !{{!1}} +!1 = !{{i32 {major}, i32 {minor}, i32 {debug_major}, i32 {debug_minor}}} +""" # noqa: E501 + + helper1_ir = f"""target triple = "nvptx64-unknown-cuda" +target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64" + +define i32 @helper_add(i32 %x) nounwind readnone {{ +entry: + %result = add i32 %x, 1 + ret i32 %result +}} + +!nvvmir.version = !{{!0}} +!0 = !{{i32 {major}, i32 {minor}, i32 {debug_major}, i32 {debug_minor}}} +""" # noqa: E501 + + helper2_ir = f"""target triple = "nvptx64-unknown-cuda" +target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64" + +define i32 @helper_mul(i32 %x) nounwind readnone {{ +entry: + %result = mul i32 %x, 2 + ret i32 %result +}} + +!nvvmir.version = !{{!0}} +!0 = !{{i32 {major}, i32 {minor}, i32 {debug_major}, i32 {debug_minor}}} +""" # noqa: E501 + + options = ProgramOptions( + name="nvvm_multi_helper_test", + extra_sources=[ + ("helper1", helper1_ir), + ("helper2", helper2_ir), + ], + ) + program = Program(main_nvvm_ir, "nvvm", options) + + assert program.backend == "NVVM" + ptx_code = program.compile("ptx") + assert isinstance(ptx_code, ObjectCode) + assert ptx_code.name == "nvvm_multi_helper_test" + + ltoir_code = program.compile("ltoir") + assert isinstance(ltoir_code, ObjectCode) + assert ltoir_code.name == "nvvm_multi_helper_test" + + program.close() + + +@nvvm_available +@pytest.mark.skipif(not _test_helpers_available, reason="cuda_python_test_helpers not accessible") +def test_bitcode_format(minimal_nvvmir): + if len(minimal_nvvmir) < 4: + pytest.skip("Bitcode file is not valid or empty") + + options = ProgramOptions(name="minimal_nvvmir_bitcode_test", arch="sm_90") + program = Program(minimal_nvvmir, "nvvm", options) + + assert program.backend == "NVVM" + ptx_result = program.compile("ptx") + assert isinstance(ptx_result, ObjectCode) + assert ptx_result.name == "minimal_nvvmir_bitcode_test" + assert len(ptx_result.code) > 0 + program_lto = Program(minimal_nvvmir, "nvvm", options) + try: + ltoir_result = program_lto.compile("ltoir") + assert isinstance(ltoir_result, ObjectCode) + assert len(ltoir_result.code) > 0 + print(f"LTOIR size: {len(ltoir_result.code)} bytes") + except Exception as e: + print(f"LTOIR compilation failed : {e}") + finally: + program.close() + + +def test_cpp_program_with_extra_sources(): + # negative test with NVRTC with multiple sources + code = 'extern "C" __global__ void my_kernel(){}' + helper = 'extern "C" __global__ void helper(){}' + options = ProgramOptions(extra_sources=helper) + with pytest.raises(ValueError, match="extra_sources is not supported by the NVRTC backend"): + Program(code, "c++", options) def test_program_options_as_bytes_nvrtc(): """Test ProgramOptions.as_bytes() for NVRTC backend""" options = ProgramOptions(arch="sm_80", debug=True, lineinfo=True, ftz=True) @@ -546,4 +736,4 @@ def test_program_options_as_bytes_nvvm_unsupported_option(): """Test that unsupported options raise CUDAError for NVVM backend""" options = ProgramOptions(arch="sm_80", lineinfo=True) with pytest.raises(CUDAError, match="not supported by NVVM backend"): - options.as_bytes("nvvm") + options.as_bytes("nvvm") \ No newline at end of file diff --git a/cuda_pathfinder/cuda/pathfinder/__init__.py b/cuda_pathfinder/cuda/pathfinder/__init__.py index 143c4b45cc..8820917318 100644 --- a/cuda_pathfinder/cuda/pathfinder/__init__.py +++ b/cuda_pathfinder/cuda/pathfinder/__init__.py @@ -6,6 +6,11 @@ from cuda.pathfinder._dynamic_libs.load_dl_common import DynamicLibNotFoundError as DynamicLibNotFoundError from cuda.pathfinder._dynamic_libs.load_dl_common import LoadedDL as LoadedDL from cuda.pathfinder._dynamic_libs.load_nvidia_dynamic_lib import load_nvidia_dynamic_lib as load_nvidia_dynamic_lib +from cuda.pathfinder._dynamic_libs.find_libdevice import ( + LibdeviceNotFoundError as LibdeviceNotFoundError, + find_libdevice as find_libdevice, + get_libdevice_path as get_libdevice_path, +) from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import ( SUPPORTED_LIBNAMES as SUPPORTED_NVIDIA_LIBNAMES, # noqa: F401 ) diff --git a/cuda_pathfinder/cuda/pathfinder/_dynamic_libs/find_libdevice.py b/cuda_pathfinder/cuda/pathfinder/_dynamic_libs/find_libdevice.py new file mode 100644 index 0000000000..f1c0763765 --- /dev/null +++ b/cuda_pathfinder/cuda/pathfinder/_dynamic_libs/find_libdevice.py @@ -0,0 +1,146 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +import functools +import glob +import os + +from cuda.pathfinder._dynamic_libs.supported_nvidia_libs import IS_WINDOWS +from cuda.pathfinder._utils.find_sub_dirs import find_sub_dirs_all_sitepackages + +# Site-package paths for libdevice (following SITE_PACKAGES_LIBDIRS pattern) +SITE_PACKAGES_LIBDEVICE_DIRS = ( + "nvidia/cuda_nvvm/nvvm/libdevice", # CTK 13+ + "nvidia/cuda_nvcc/nvvm/libdevice", # CTK <13 +) + + +class LibdeviceNotFoundError(Exception): + pass + + +def _get_cuda_home_or_path() -> str | None: + return os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") + + +def _no_such_file_in_dir( + dir_path: str, filename: str, error_messages: list[str], attachments: list[str] +) -> None: + error_messages.append(f"No such file: {os.path.join(dir_path, filename)}") + if os.path.isdir(dir_path): + attachments.append(f' listdir("{dir_path}"):') + for node in sorted(os.listdir(dir_path)): + attachments.append(f" {node}") + else: + attachments.append(f' Directory does not exist: "{dir_path}"') + + +class _FindLibdevice: + FILENAME = "libdevice.10.bc" + REL_PATH = os.path.join("nvvm", "libdevice") + + def __init__(self): + self.error_messages: list[str] = [] + self.attachments: list[str] = [] + self.abs_path: str | None = None + + def try_site_packages(self) -> str | None: + for rel_dir in SITE_PACKAGES_LIBDEVICE_DIRS: + sub_dir = tuple(rel_dir.split("/")) + for abs_dir in find_sub_dirs_all_sitepackages(sub_dir): + file_path = os.path.join(abs_dir, self.FILENAME) + if os.path.isfile(file_path): + return file_path + return None + + def try_with_conda_prefix(self) -> str | None: + conda_prefix = os.environ.get("CONDA_PREFIX") + if not conda_prefix: + return None + + anchor = os.path.join(conda_prefix, "Library") if IS_WINDOWS else conda_prefix + file_path = os.path.join(anchor, self.REL_PATH, self.FILENAME) + if os.path.isfile(file_path): + return file_path + return None + + def try_with_cuda_home(self) -> str | None: + cuda_home = _get_cuda_home_or_path() + if cuda_home is None: + self.error_messages.append("CUDA_HOME/CUDA_PATH not set") + return None + + file_path = os.path.join(cuda_home, self.REL_PATH, self.FILENAME) + if os.path.isfile(file_path): + return file_path + + _no_such_file_in_dir( + os.path.join(cuda_home, self.REL_PATH), + self.FILENAME, + self.error_messages, + self.attachments, + ) + return None + + def try_common_paths(self) -> str | None: + if IS_WINDOWS: + bases = [r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA", r"C:\CUDA"] + else: + bases = ["/usr/local/cuda", "/opt/cuda"] + + for base in bases: + # Direct path + file_path = os.path.join(base, self.REL_PATH, self.FILENAME) + if os.path.isfile(file_path): + return file_path + # Versioned paths (e.g., /usr/local/cuda-13.0) + for versioned in sorted(glob.glob(base + "*"), reverse=True): + if os.path.isdir(versioned): + file_path = os.path.join(versioned, self.REL_PATH, self.FILENAME) + if os.path.isfile(file_path): + return file_path + return None + + def raise_not_found_error(self) -> None: + err = ", ".join(self.error_messages) if self.error_messages else "No search paths available" + att = "\n".join(self.attachments) if self.attachments else "" + raise LibdeviceNotFoundError( + f'Failure finding "{self.FILENAME}": {err}\n{att}' + ) + + +@functools.cache +def find_libdevice() -> str: + """Find the path to libdevice.10.bc. + Raises: + LibdeviceNotFoundError: If libdevice.10.bc cannot be found + """ + finder = _FindLibdevice() + + abs_path = finder.try_site_packages() + if abs_path is None: + abs_path = finder.try_with_conda_prefix() + if abs_path is None: + abs_path = finder.try_with_cuda_home() + if abs_path is None: + abs_path = finder.try_common_paths() + + if abs_path is None: + finder.raise_not_found_error() + + return abs_path + + +def get_libdevice_path() -> str | None: + """Get the path to libdevice.10.bc, or None if not found.""" + finder = _FindLibdevice() + + abs_path = finder.try_site_packages() + if abs_path is None: + abs_path = finder.try_with_conda_prefix() + if abs_path is None: + abs_path = finder.try_with_cuda_home() + if abs_path is None: + abs_path = finder.try_common_paths() + + return abs_path + diff --git a/cuda_python_test_helpers/cuda_python_test_helpers/nvvm_bitcode.py b/cuda_python_test_helpers/cuda_python_test_helpers/nvvm_bitcode.py new file mode 100644 index 0000000000..e42dbff085 --- /dev/null +++ b/cuda_python_test_helpers/cuda_python_test_helpers/nvvm_bitcode.py @@ -0,0 +1,135 @@ +import binascii + +import pytest +from cuda.bindings import nvvm + +MINIMAL_NVVMIR_TXT_TEMPLATE = b"""\ +target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64" + +target triple = "nvptx64-nvidia-cuda" + +define void @kernel() { +entry: + ret void +} + +!nvvm.annotations = !{!0} +!0 = !{void ()* @kernel, !"kernel", i32 1} + +!nvvmir.version = !{!1} +!1 = !{i32 %d, i32 0, i32 %d, i32 0} +""" # noqa: E501 + +MINIMAL_NVVMIR_BITCODE_STATIC = { + (1, 3): # (major, debug_major) + "4243c0de3514000005000000620c30244a59be669dfbb4bf0b51804c01000000210c00007f010000" + "0b02210002000000160000000781239141c80449061032399201840c250508191e048b62800c4502" + "42920b42641032143808184b0a3232884870c421234412878c1041920264c808b1142043468820c9" + "01323284182a282a90317cb05c9120c3c8000000892000000b0000003222c80820624600212b2498" + "0c212524980c19270c85a4906032645c20246382a01801300128030173046000132677b00778a007" + "7cb0033a680377b0877420877408873618877a208770d8e012e5d006f0a0077640077a600774a007" + "7640076d900e71a00778a00778d006e980077a80077a80076d900e7160077a100776a0077160076d" + "900e7320077a300772a0077320076d900e7640077a600774a0077640076d900e71200778a0077120" + "0778a00771200778d006e6300772a0077320077a300772d006e6600774a0077640077a600774d006" + "f6100776a0077160077a100776d006f6300772a0077320077a300772d006f6600774a0077640077a" + "600774d006f610077280077a10077280077a10077280076de00e7160077a300772a0077640071a21" + "4c0e11de9c2e4fbbcfbe211560040000000000000000000000000620b141a0e86000004016080000" + "06000000321e980c19114c908c092647c6044362098c009401000000b1180000ac0000003308801c" + "c4e11c6614013d88433884c38c4280077978077398710ce6000fed100ef4800e330c421ec2c11dce" + "a11c6630053d88433884831bcc033dc8433d8c033dcc788c7470077b08077948877070077a700376" + "788770208719cc110eec900ee1300f6e300fe3f00ef0500e3310c41dde211cd8211dc2611e663089" + "3bbc833bd04339b4033cbc833c84033bccf0147660077b6807376887726807378087709087706007" + "76280776f8057678877780875f08877118877298877998812ceef00eeee00ef5c00eec300362c8a1" + "1ce4a11ccca11ce4a11cdc611cca211cc4811dca6106d6904339c84339984339c84339b8c3389443" + "3888033b94c32fbc833cfc823bd4033bb0c30cc7698770588772708374680778608774188774a087" + "19ce530fee000ff2500ee4900ee3400fe1200eec500e3320281ddcc11ec2411ed2211cdc811edce0" + "1ce4e11dea011e66185138b0433a9c833bcc50247660077b68073760877778077898514cf4900ff0" + "500e331e6a1eca611ce8211ddec11d7e011ee4a11ccc211df0610654858338ccc33bb0433dd04339" + "fcc23ce4433b88c33bb0c38cc50a877998877718877408077a28077298815ce3100eecc00ee5500e" + "f33023c1d2411ee4e117d8e11dde011e6648193bb0833db4831b84c3388c4339ccc33cb8c139c8c3" + "3bd4033ccc48b471080776600771088771588719dbc60eec600fede006f0200fe5300fe5200ff650" + "0e6e100ee3300ee5300ff3e006e9e00ee4500ef83023e2ec611cc2811dd8e117ec211de6211dc421" + "1dd8211de8211f66209d3bbc433db80339948339cc58bc7070077778077a08077a488777708719cb" + "e70eef300fe1e00ee9400fe9a00fe530c3010373a8077718875f988770708774a08774d087729881" + "844139e0c338b0433d904339cc40c4a01dcaa11de0411edec11c662463300ee1c00eec300fe9400f" + "e5000000792000001d000000721e482043880c19097232482023818c9191d144a01028643c313242" + "8e9021a318100a00060000006b65726e656c0000230802308240042308843082400c330c4230cc40" + "0c4441c84860821272b3b36b730973737ba30ba34b7b739b1b2528d271b3b36b4b9373b12b939b4b" + "7b731b2530000000a9180000250000000b0a7228877780077a587098433db8c338b04339d0c382e6" + "1cc6a10de8411ec2c11de6211de8211ddec11d1634e3600ee7500fe1200fe4400fe1200fe7500ef4" + "b08081077928877060077678877108077a28077258709cc338b4013ba4833d94c3026b1cd8211cdc" + "e11cdc201ce4611cdc201ce8811ec2611cd0a11cc8611cc2811dd861c1010ff4200fe1500ff4800e" + "00000000d11000000600000007cc3ca4833b9c033b94033da0833c94433890c30100000061200000" + "06000000130481860301000002000000075010cd14610000000000007120000003000000320e1022" + "8400fb020000000000000000650c00001f000000120394f000000000030000000600000006000000" + "4c000000010000005800000000000000580000000100000070000000000000000c00000013000000" + "1f000000080000000600000000000000700000000000000000000000010000000000000000000000" + "060000000000000006000000ffffffff00240000000000005d0c00000d0000001203946700000000" + "6b65726e656c31352e302e376e7670747836342d6e76696469612d637564613c737472696e673e00" + "00000000", + (2, 3): # (major, debug_major) + "4243c0de3514000005000000620c30244a59be669dfbb4bf0b51804c01000000210c000080010000" + "0b02210002000000160000000781239141c80449061032399201840c250508191e048b62800c4502" + "42920b42641032143808184b0a3232884870c421234412878c1041920264c808b1142043468820c9" + "01323284182a282a90317cb05c9120c3c8000000892000000b0000003222c80820624600212b2498" + "0c212524980c19270c85a4906032645c20246382a01801300128030173046000132677b00778a007" + "7cb0033a680377b0877420877408873618877a208770d8e012e5d006f0a0077640077a600774a007" + "7640076d900e71a00778a00778d006e980077a80077a80076d900e7160077a100776a0077160076d" + "900e7320077a300772a0077320076d900e7640077a600774a0077640076d900e71200778a0077120" + "0778a00771200778d006e6300772a0077320077a300772d006e6600774a0077640077a600774d006" + "f6100776a0077160077a100776d006f6300772a0077320077a300772d006f6600774a0077640077a" + "600774d006f610077280077a10077280077a10077280076de00e7160077a300772a0077640071a21" + "4c0e11de9c2e4fbbcfbe211560040000000000000000000000000620b141a0286100004016080000" + "06000000321e980c19114c908c092647c60443620914c10840190000b1180000ac0000003308801c" + "c4e11c6614013d88433884c38c4280077978077398710ce6000fed100ef4800e330c421ec2c11dce" + "a11c6630053d88433884831bcc033dc8433d8c033dcc788c7470077b08077948877070077a700376" + "788770208719cc110eec900ee1300f6e300fe3f00ef0500e3310c41dde211cd8211dc2611e663089" + "3bbc833bd04339b4033cbc833c84033bccf0147660077b6807376887726807378087709087706007" + "76280776f8057678877780875f08877118877298877998812ceef00eeee00ef5c00eec300362c8a1" + "1ce4a11ccca11ce4a11cdc611cca211cc4811dca6106d6904339c84339984339c84339b8c3389443" + "3888033b94c32fbc833cfc823bd4033bb0c30cc7698770588772708374680778608774188774a087" + "19ce530fee000ff2500ee4900ee3400fe1200eec500e3320281ddcc11ec2411ed2211cdc811edce0" + "1ce4e11dea011e66185138b0433a9c833bcc50247660077b68073760877778077898514cf4900ff0" + "500e331e6a1eca611ce8211ddec11d7e011ee4a11ccc211df0610654858338ccc33bb0433dd04339" + "fcc23ce4433b88c33bb0c38cc50a877998877718877408077a28077298815ce3100eecc00ee5500e" + "f33023c1d2411ee4e117d8e11dde011e6648193bb0833db4831b84c3388c4339ccc33cb8c139c8c3" + "3bd4033ccc48b471080776600771088771588719dbc60eec600fede006f0200fe5300fe5200ff650" + "0e6e100ee3300ee5300ff3e006e9e00ee4500ef83023e2ec611cc2811dd8e117ec211de6211dc421" + "1dd8211de8211f66209d3bbc433db80339948339cc58bc7070077778077a08077a488777708719cb" + "e70eef300fe1e00ee9400fe9a00fe530c3010373a8077718875f988770708774a08774d087729881" + "844139e0c338b0433d904339cc40c4a01dcaa11de0411edec11c662463300ee1c00eec300fe9400f" + "e5000000792000001e000000721e482043880c19097232482023818c9191d144a01028643c313242" + "8e9021a318100a00060000006b65726e656c0000230802308240042308843082400c23080431c320" + "04c30c045118858c04262821373bbb36973037b737ba30bab437b7b95102231d373bbbb6343917bb" + "32b9b9b437b7518203000000a9180000250000000b0a7228877780077a587098433db8c338b04339" + "d0c382e61cc6a10de8411ec2c11de6211de8211ddec11d1634e3600ee7500fe1200fe4400fe1200f" + "e7500ef4b08081077928877060077678877108077a28077258709cc338b4013ba4833d94c3026b1c" + "d8211cdce11cdc201ce4611cdc201ce8811ec2611cd0a11cc8611cc2811dd861c1010ff4200fe150" + "0ff4800e00000000d11000000600000007cc3ca4833b9c033b94033da0833c94433890c301000000" + "6120000006000000130481860301000002000000075010cd14610000000000007120000003000000" + "320e10228400fc020000000000000000650c00001f000000120394f0000000000300000006000000" + "060000004c000000010000005800000000000000580000000100000070000000000000000c000000" + "130000001f0000000800000006000000000000007000000000000000000000000100000000000000" + "00000000060000000000000006000000ffffffff00240000000000005d0c00000d00000012039467" + "000000006b65726e656c31352e302e376e7670747836342d6e76696469612d637564613c73747269" + "6e673e0000000000", +} + + +@pytest.fixture(params=("txt", "bitcode_static")) +def minimal_nvvmir(request): + major, minor, debug_major, debug_minor = nvvm.ir_version() + + if request.param == "txt": + return MINIMAL_NVVMIR_TXT_TEMPLATE % (major, debug_major) + + bitcode_static_binascii = MINIMAL_NVVMIR_BITCODE_STATIC.get((major, debug_major)) + if bitcode_static_binascii: + return binascii.unhexlify(bitcode_static_binascii) + raise RuntimeError( + "Static bitcode for NVVM IR version " + f"{major}.{debug_major} is not available in this test.\n" + "Maintainers: Please run the helper script to generate it and add the " + "output to the MINIMAL_NVVMIR_BITCODE_STATIC dict:\n" + " ../../toolshed/build_static_bitcode_input.py" + )