|
13 | 13 |
|
14 | 14 | import itertools |
15 | 15 | from collections.abc import Callable, Iterable, Mapping, Sequence |
| 16 | +from numbers import Integral |
16 | 17 | from typing import Any |
17 | 18 |
|
18 | 19 | import numpy as np |
|
33 | 34 | optional_import, |
34 | 35 | ) |
35 | 36 |
|
| 37 | + |
36 | 38 | tqdm, _ = optional_import("tqdm", name="tqdm") |
37 | 39 | _nearest_mode = "nearest-exact" |
38 | 40 |
|
@@ -131,11 +133,27 @@ def sliding_window_inference( |
131 | 133 | kwargs: optional keyword args to be passed to ``predictor``. |
132 | 134 |
|
133 | 135 | Note: |
134 | | - - input must be channel-first and have a batch dim, supports N-D sliding window. |
| 136 | + - Inputs must be channel-first and have a batch dim (NCHW / NCDHW). |
| 137 | + - If your data is NHWC/NDHWC, please apply `EnsureChannelFirst` / `EnsureChannelFirstd` upstream. |
135 | 138 |
|
136 | 139 | """ |
137 | | - buffered = buffer_steps is not None and buffer_steps > 0 |
138 | 140 | num_spatial_dims = len(inputs.shape) - 2 |
| 141 | + |
| 142 | + # Only perform strict shape validation if roi_size is a sequence (explicit dimensions). |
| 143 | + # If roi_size is an integer, it is broadcast to all dimensions, so we cannot |
| 144 | + # infer the expected dimensionality to enforce a strict check here. |
| 145 | + if not isinstance(roi_size, Integral): |
| 146 | + roi_dims = len(roi_size) |
| 147 | + if num_spatial_dims != roi_dims: |
| 148 | + raise ValueError( |
| 149 | + f"inputs must have {roi_dims + 2} dimensions for {roi_dims}D roi_size " |
| 150 | + f"(Batch, Channel, {', '.join(['Spatial'] * roi_dims)}), " |
| 151 | + f"but got inputs shape {inputs.shape}.\n" |
| 152 | + "If you have channel-last data (e.g. B, D, H, W, C), please use " |
| 153 | + "monai.transforms.EnsureChannelFirst or EnsureChannelFirstd upstream." |
| 154 | + ) |
| 155 | + # ----------------------------------------------------------------- |
| 156 | + buffered = buffer_steps is not None and buffer_steps > 0 |
139 | 157 | if buffered: |
140 | 158 | if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims: |
141 | 159 | raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.") |
|
0 commit comments