diff --git a/xtuner/v1/utils/env_check.py b/xtuner/v1/utils/env_check.py index dad82a4b4..c7383bfc5 100644 --- a/xtuner/v1/utils/env_check.py +++ b/xtuner/v1/utils/env_check.py @@ -1,6 +1,5 @@ from typing import Any, Callable, List - def check_torch_accelerator_available(): """Check if PyTorch is installed and the torch accelerator is available. @@ -14,13 +13,17 @@ def check_torch_accelerator_available(): except Exception: return False - def check_triton_available(): """Check if Triton is installed. Returns: bool: True if Triton is installed, False otherwise. """ + import os + + if os.environ.get("XTUNER_USE_TRITON", "1") == "0": + return False + try: import triton # noqa: F401 @@ -28,7 +31,6 @@ def check_triton_available(): except ImportError: return False - def get_env_not_available_func(env_name_list: List[str]) -> Callable: """Get a function that raises an error indicating the environment is not available. @@ -42,7 +44,6 @@ def env_not_available_func(*args: Any, **kwargs: Any) -> Any: return env_not_available_func - def get_rollout_engine_version() -> dict: import os