From 7a9dc13c1d64ed9d3c007af583e4895b98a0e962 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 3 May 2024 13:46:49 +0200 Subject: [PATCH 1/5] Add prototype of MPS support --- napari_cellseg3d/code_models/model_framework.py | 16 ++++++++++++++++ napari_cellseg3d/utils.py | 16 ++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 7caff7b6..ad42cdb9 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -3,6 +3,8 @@ from typing import TYPE_CHECKING import torch +import torch.backends +import torch.backends.mps if TYPE_CHECKING: import napari @@ -94,6 +96,18 @@ def __init__( available_devices = ["CPU"] + [ f"GPU {i}" for i in range(torch.cuda.device_count()) ] + from napari_cellseg3d.utils import _is_mps_available + + try: + if ( + _is_mps_available() + and torch.backends.mps.is_available() + and torch.backends.mps.is_built() + ): + available_devices.append("MPS") + except Exception as e: + logger.error(f"Error while checking MPS availability : {e}") + self.device_choice = ui.DropdownMenu( available_devices, parent=self, @@ -345,6 +359,8 @@ def check_device_choice(self): elif "GPU" in choice: i = int(choice.split(" ")[1]) device = f"cuda:{i}" + elif choice == "MPS": + device = "mps" else: device = self.get_device() logger.debug(f"DEVICE choice : {device}") diff --git a/napari_cellseg3d/utils.py b/napari_cellseg3d/utils.py index 45f50582..fd9f2f32 100644 --- a/napari_cellseg3d/utils.py +++ b/napari_cellseg3d/utils.py @@ -7,6 +7,7 @@ import napari import numpy as np +import pkg_resources import torch from monai.transforms import Zoom from numpy.random import PCG64, Generator @@ -645,3 +646,18 @@ def fraction_above_threshold(volume: np.array, threshold=0.5) -> float: f"non zero in above_thresh : {np.count_nonzero(above_thresh)}" ) return np.count_nonzero(above_thresh) / np.size(flattened) + + +def _is_mps_available(torch): + available = False + if pkg_resources.parse_version( + torch.__version__ + ) >= pkg_resources.parse_version("1.12"): + LOGGER.debug("Torch version is 1.12 or higher, compatible with MPS") + if torch.backends.mps.is_available(): + LOGGER.debug("MPS is available") + if torch.backends.mps.is_built(): + LOGGER.debug("MPS is built") + available = True + + return available From 6f1a6460ae1e79c7fae7d828a3f4bbd2bc62a6a8 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 3 May 2024 13:48:49 +0200 Subject: [PATCH 2/5] Fix checking for MPS --- napari_cellseg3d/code_models/model_framework.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index ad42cdb9..66d655fd 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -99,11 +99,7 @@ def __init__( from napari_cellseg3d.utils import _is_mps_available try: - if ( - _is_mps_available() - and torch.backends.mps.is_available() - and torch.backends.mps.is_built() - ): + if _is_mps_available(torch): available_devices.append("MPS") except Exception as e: logger.error(f"Error while checking MPS availability : {e}") From f89ca8e470769be174702f625a9ca9d7007412c0 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 3 May 2024 13:49:22 +0200 Subject: [PATCH 3/5] Update model_framework.py --- napari_cellseg3d/code_models/model_framework.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 66d655fd..1d90e906 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -100,7 +100,7 @@ def __init__( try: if _is_mps_available(torch): - available_devices.append("MPS") + available_devices.append("MPS (beta)") except Exception as e: logger.error(f"Error while checking MPS availability : {e}") @@ -355,7 +355,7 @@ def check_device_choice(self): elif "GPU" in choice: i = int(choice.split(" ")[1]) device = f"cuda:{i}" - elif choice == "MPS": + elif choice == "MPS (beta)": # TODO : check if MPS is available device = "mps" else: device = self.get_device() From 3a0dd115da4186ce37ba6573acb7569f7f912d3f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 3 May 2024 14:05:00 +0200 Subject: [PATCH 4/5] Add fallback environment variable --- napari_cellseg3d/code_models/model_framework.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 1d90e906..2ac0eaad 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -100,6 +100,9 @@ def __init__( try: if _is_mps_available(torch): + from os import environ + + environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" available_devices.append("MPS (beta)") except Exception as e: logger.error(f"Error while checking MPS availability : {e}") From 822826f5a3ab82f0c8cef3ed8460ce4950a8bfb5 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 3 May 2024 14:15:41 +0200 Subject: [PATCH 5/5] Set fallback mode for MPS directly in workers --- napari_cellseg3d/code_models/model_framework.py | 3 --- napari_cellseg3d/code_models/worker_inference.py | 5 +++++ napari_cellseg3d/code_models/worker_training.py | 5 +++++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/napari_cellseg3d/code_models/model_framework.py b/napari_cellseg3d/code_models/model_framework.py index 2ac0eaad..1d90e906 100644 --- a/napari_cellseg3d/code_models/model_framework.py +++ b/napari_cellseg3d/code_models/model_framework.py @@ -100,9 +100,6 @@ def __init__( try: if _is_mps_available(torch): - from os import environ - - environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" available_devices.append("MPS (beta)") except Exception as e: logger.error(f"Error while checking MPS availability : {e}") diff --git a/napari_cellseg3d/code_models/worker_inference.py b/napari_cellseg3d/code_models/worker_inference.py index c69dfe45..1a6bd3c4 100644 --- a/napari_cellseg3d/code_models/worker_inference.py +++ b/napari_cellseg3d/code_models/worker_inference.py @@ -838,6 +838,11 @@ def inference(self): torch.set_num_threads(1) # required for threading on macOS ? self.log("Number of threads has been set to 1 for macOS") + if self.config.device == "mps": + from os import environ + + environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + try: dims = self.config.model_info.model_input_size self.log(f"MODEL DIMS : {dims}") diff --git a/napari_cellseg3d/code_models/worker_training.py b/napari_cellseg3d/code_models/worker_training.py index 9ad08d24..7c37fd0f 100644 --- a/napari_cellseg3d/code_models/worker_training.py +++ b/napari_cellseg3d/code_models/worker_training.py @@ -1127,6 +1127,11 @@ def train( weights_config = self.config.weights_info deterministic_config = self.config.deterministic_config + if self.config.device == "mps": + from os import environ + + environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + start_time = time.time() try: