|
22 | 22 | from collections.abc import Callable, Iterable, Iterator, Sequence |
23 | 23 | from dataclasses import dataclass |
24 | 24 | from pathlib import Path |
25 | | -from typing import TYPE_CHECKING, Any |
| 25 | +from typing import TYPE_CHECKING, Any, Union |
26 | 26 |
|
27 | 27 | import numpy as np |
28 | 28 | from torch.utils.data._utils.collate import np_str_obj_array_pattern |
|
57 | 57 | cp, has_cp = optional_import("cupy") |
58 | 58 | kvikio, has_kvikio = optional_import("kvikio") |
59 | 59 |
|
| 60 | +if TYPE_CHECKING: |
| 61 | + import cupy |
| 62 | + |
| 63 | + NdarrayOrCupy = Union[np.ndarray, cupy.ndarray] |
| 64 | +else: |
| 65 | + NdarrayOrCupy = Any |
| 66 | + |
60 | 67 | __all__ = ["ImageReader", "ITKReader", "NibabelReader", "NumpyReader", "PILReader", "PydicomReader", "NrrdReader"] |
61 | 68 |
|
62 | 69 |
|
@@ -663,10 +670,7 @@ def get_data(self, data) -> tuple[np.ndarray, dict]: |
663 | 670 | metadata[MetaKeys.SPATIAL_SHAPE] = data_array.shape |
664 | 671 | dicom_data.append((data_array, metadata)) |
665 | 672 |
|
666 | | - # TODO: the actual type is list[np.ndarray | cp.ndarray] |
667 | | - # should figure out how to define correct types without having cupy not found error |
668 | | - # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918 |
669 | | - img_array: list[np.ndarray] = [] |
| 673 | + img_array: list[NdarrayOrCupy] = [] |
670 | 674 | compatible_meta: dict = {} |
671 | 675 |
|
672 | 676 | for data_array, metadata in ensure_tuple(dicom_data): |
@@ -1104,10 +1108,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]: |
1104 | 1108 | img: a Nibabel image object loaded from an image file or a list of Nibabel image objects. |
1105 | 1109 |
|
1106 | 1110 | """ |
1107 | | - # TODO: the actual type is list[np.ndarray | cp.ndarray] |
1108 | | - # should figure out how to define correct types without having cupy not found error |
1109 | | - # https://github.com/Project-MONAI/MONAI/pull/8188#discussion_r1886645918 |
1110 | | - img_array: list[np.ndarray] = [] |
| 1111 | + img_array: list[NdarrayOrCupy] = [] |
1111 | 1112 | compatible_meta: dict = {} |
1112 | 1113 |
|
1113 | 1114 | for i, filename in zip(ensure_tuple(img), self.filenames): |
|
0 commit comments