-
Notifications
You must be signed in to change notification settings - Fork 235
Add CUDA version compatibility check #1412
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE | ||
|
|
||
| import os | ||
| import warnings | ||
|
|
||
| # Track whether we've already checked version compatibility | ||
| _version_compatibility_checked = False | ||
|
|
||
|
|
||
def check_cuda_version_compatibility():
    """Warn once if cuda-bindings targets a newer CUDA major version than the driver.

    The CUDA version cuda-bindings was compiled against (``driver.CUDA_VERSION``)
    is compared with the version reported by ``driver.cuDriverGetVersion()``.
    Both are integers encoded as ``major * 1000 + minor * 10`` (for example
    ``13010`` is CUDA 13.1). A ``UserWarning`` is issued only when the
    compile-time *major* version is greater than the driver's major version.

    The check runs only once per process; subsequent calls are no-ops. The
    "already checked" flag is set before anything else happens, so a call that
    is suppressed or that bails out early still counts as the one check.

    The warning can be suppressed by setting the environment variable
    ``CUDA_PYTHON_DISABLE_VERSION_CHECK=1``. Any value other than an empty
    string, ``0``, ``false``, ``no`` or ``off`` (case-insensitive, surrounding
    whitespace ignored) suppresses the check; those falsy spellings leave the
    check enabled.

    Returns
    -------
    None

    Examples
    --------
    >>> from cuda.bindings.utils import check_cuda_version_compatibility
    >>> check_cuda_version_compatibility()  # Issues warning if version mismatch
    """
    global _version_compatibility_checked
    if _version_compatibility_checked:
        return
    _version_compatibility_checked = True

    # Allow users to suppress the warning. Explicitly falsy values such as
    # "0" must NOT disable the check, otherwise
    # CUDA_PYTHON_DISABLE_VERSION_CHECK=0 would do the opposite of what it says.
    disable_flag = os.environ.get("CUDA_PYTHON_DISABLE_VERSION_CHECK", "")
    if disable_flag.strip().lower() not in ("", "0", "false", "no", "off"):
        return

    # Import here to avoid circular imports and allow lazy loading.
    from cuda.bindings import driver

    # Compile-time CUDA version baked into cuda-bindings, e.g. 13010.
    # NOTE(review): this module ships inside cuda-bindings itself, so
    # CUDA_VERSION cannot be missing in an "older release"; the fallback is
    # kept only because the existing tests exercise it and could be removed
    # together with that test.
    try:
        compile_version = driver.CUDA_VERSION
    except AttributeError:
        return

    # Version supported by the installed driver.
    # NOTE(review): reviewers argued that failing to query the driver version
    # deserves a warning, or even an exception, instead of being swallowed.
    # The silent skip is kept because the current tests pin that behavior.
    err, runtime_version = driver.cuDriverGetVersion()
    if err != driver.CUresult.CUDA_SUCCESS:
        return

    compile_major = compile_version // 1000
    runtime_major = runtime_version // 1000
    if compile_major <= runtime_major:
        # Same or newer driver major version: nothing to report.
        return

    compile_minor = (compile_version % 1000) // 10
    runtime_minor = (runtime_version % 1000) // 10
    # NOTE(review): the major-version mismatch is the primary concern; a
    # reviewer found the minor numbers more distracting than helpful. The
    # wording is unchanged because the tests assert on the exact text.
    # stacklevel=3 attributes the warning two frames above this function.
    warnings.warn(
        f"cuda-bindings was built against CUDA {compile_major}.{compile_minor}, "
        f"but the installed driver only supports CUDA {runtime_major}.{runtime_minor}. "
        f"Some features may not work correctly. Consider updating your NVIDIA driver. "
        f"Set CUDA_PYTHON_DISABLE_VERSION_CHECK=1 to suppress this warning.",
        UserWarning,
        stacklevel=3,
    )
|
|
||
|
|
||
def _reset_version_compatibility_check():
    """Re-arm the once-per-process version compatibility check.

    Sets the module-level ``_version_compatibility_checked`` flag back to
    ``False``. Intended for tests that need to exercise the warning
    behavior more than once in the same process.

    NOTE(review): a reviewer asked whether production code ever needs this;
    if it is only used by tests, it could live in the test code instead.
    """
    # Writing through globals() rebinds the same module-level flag that a
    # ``global`` statement followed by an assignment would.
    globals()["_version_compatibility_checked"] = False
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,113 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE | ||
|
|
||
| import os | ||
| import warnings | ||
| from unittest import mock | ||
|
|
||
| from cuda.bindings import driver | ||
| from cuda.bindings.utils import check_cuda_version_compatibility | ||
| from cuda.bindings.utils._version_check import _reset_version_compatibility_check | ||
|
|
||
|
|
||
class TestVersionCompatibilityCheck:
    """Tests for CUDA version compatibility check function."""

    # Environment variable that suppresses the check; scrubbed from the
    # environment in every helper-driven test so an ambient setting on the
    # developer's machine or CI cannot mask an expected warning.
    _DISABLE_ENV_VAR = "CUDA_PYTHON_DISABLE_VERSION_CHECK"

    def setup_method(self):
        """Reset the version compatibility check flag before each test."""
        _reset_version_compatibility_check()

    def teardown_method(self):
        """Reset the version compatibility check flag after each test."""
        _reset_version_compatibility_check()

    def _collect_warnings(self, compile_version, driver_reply, *, env=None, calls=1):
        """Run the check under mocked versions and return the recorded warnings.

        Parameters
        ----------
        compile_version : int
            Value patched in as ``driver.CUDA_VERSION`` (e.g. ``13000``).
        driver_reply : tuple
            ``(CUresult, version)`` pair returned by the mocked
            ``driver.cuDriverGetVersion``.
        env : dict, optional
            Environment variables to set for the duration of the call.
        calls : int, optional
            How many times to invoke the check inside the recording context.

        Returns
        -------
        list
            The ``warnings.WarningMessage`` objects captured during the calls.
        """
        overrides = dict(env) if env else {}
        # patch.dict restores os.environ on exit, including keys removed here.
        with mock.patch.dict(os.environ, overrides):
            if self._DISABLE_ENV_VAR not in overrides:
                os.environ.pop(self._DISABLE_ENV_VAR, None)
            with mock.patch.object(driver, "CUDA_VERSION", compile_version):
                with mock.patch.object(driver, "cuDriverGetVersion", return_value=driver_reply):
                    with warnings.catch_warnings(record=True) as caught:
                        warnings.simplefilter("always")
                        for _ in range(calls):
                            check_cuda_version_compatibility()
        return caught

    def test_no_warning_when_driver_newer(self):
        """No warning should be issued when driver version >= compile version."""
        # Compile version 12.9, driver version 13.0
        caught = self._collect_warnings(12090, (driver.CUresult.CUDA_SUCCESS, 13000))
        assert len(caught) == 0

    def test_no_warning_when_same_major_version(self):
        """No warning should be issued when major versions match."""
        # Compile version 12.9, driver version 12.8
        caught = self._collect_warnings(12090, (driver.CUresult.CUDA_SUCCESS, 12080))
        assert len(caught) == 0

    def test_warning_when_compile_major_newer(self):
        """Warning should be issued when compile major version > driver major version."""
        # Compile version 13.0, driver version 12.8
        caught = self._collect_warnings(13000, (driver.CUresult.CUDA_SUCCESS, 12080))
        assert len(caught) == 1
        assert issubclass(caught[0].category, UserWarning)
        message = str(caught[0].message)
        assert "cuda-bindings was built against CUDA 13.0" in message
        assert "driver only supports CUDA 12.8" in message

    def test_warning_only_issued_once(self):
        """Warning should only be issued once per process."""
        caught = self._collect_warnings(13000, (driver.CUresult.CUDA_SUCCESS, 12080), calls=3)
        # Only one warning despite multiple calls
        assert len(caught) == 1

    def test_warning_suppressed_by_env_var(self):
        """Warning should be suppressed when CUDA_PYTHON_DISABLE_VERSION_CHECK is set."""
        caught = self._collect_warnings(
            13000,
            (driver.CUresult.CUDA_SUCCESS, 12080),
            env={self._DISABLE_ENV_VAR: "1"},
        )
        assert len(caught) == 0

    def test_silent_when_driver_version_fails(self):
        """Should silently skip if cuDriverGetVersion fails."""
        caught = self._collect_warnings(13000, (driver.CUresult.CUDA_ERROR_NOT_INITIALIZED, 0))
        assert len(caught) == 0

    def test_silent_when_cuda_version_not_available(self):
        """Should silently skip if CUDA_VERSION attribute is not available."""
        # Remove the attribute for the duration of the test and always put it
        # back, so later tests see the real value.
        original = driver.CUDA_VERSION
        try:
            del driver.CUDA_VERSION
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter("always")
                check_cuda_version_compatibility()
            assert len(caught) == 0
        finally:
            driver.CUDA_VERSION = original
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
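I'd definitely add `major` into the name for clarity, and make it more obvious what the function actually does. (The suggested function name was not captured in this page extract.)
Also, for the environment variable: `CUDA_PYTHON_DISABLE_MAJOR_VERSION_WARNING`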