|
13 | 13 |
|
14 | 14 | import itertools |
15 | 15 | from collections.abc import Callable, Iterable, Mapping, Sequence |
16 | | -from typing import Any |
| 16 | +from typing import Any, Optional |
17 | 17 |
|
18 | 18 | import numpy as np |
19 | 19 | import torch |
|
33 | 33 | optional_import, |
34 | 34 | ) |
35 | 35 |
|
| 36 | + |
36 | 37 | tqdm, _ = optional_import("tqdm", name="tqdm") |
37 | 38 | _nearest_mode = "nearest-exact" |
38 | 39 |
|
39 | 40 | __all__ = ["sliding_window_inference"] |
40 | 41 |
|
41 | 42 |
|
| 43 | + |
| 44 | +def assert_channel_first( |
| 45 | + t: torch.Tensor, |
| 46 | + name: str, |
| 47 | + num_classes: Optional[int] = None, |
| 48 | + allow_binary_two_channel: bool = False, |
| 49 | +) -> None: |
| 50 | + """ |
| 51 | + Enforce channel-first layout (NCHW/NCDHW) without guessing. |
| 52 | +
|
| 53 | + Args: |
| 54 | + t: Input tensor to validate. |
| 55 | + name: Name of the tensor for error messages. |
| 56 | + num_classes: Optional expected channel count at dim=1. |
| 57 | + allow_binary_two_channel: If True and num_classes==2, accept C=2. |
| 58 | +
|
| 59 | + Raises: |
| 60 | + ValueError: If tensor shape doesn't match channel-first layout or |
| 61 | + num_classes constraint. |
| 62 | +
|
| 63 | + Note: |
| 64 | +        - Only 4D (NCHW) or 5D (NCDHW) tensors are validated; non-tensors and |
| 65 | +          tensors with other dimensionalities are silently ignored. |
| 66 | +        - When num_classes is None, no check is performed, so users must apply |
| 67 | +          EnsureChannelFirst/EnsureChannelFirstd upstream for channel-last data. |
| 68 | + """ |
| 69 | + if not isinstance(t, torch.Tensor): |
| 70 | + return |
| 71 | + if t.ndim not in (4, 5): |
| 72 | + return |
| 73 | + |
| 74 | + c = int(t.shape[1]) |
| 75 | + layout = "NCHW" if t.ndim == 4 else "NCDHW" |
| 76 | + layout_last = "NHWC" if t.ndim == 4 else "NDHWC" |
| 77 | + |
| 78 | + if num_classes is not None: |
| 79 | + ok = (c == num_classes) or (num_classes == 1 and c == 1) |
| 80 | + if allow_binary_two_channel and num_classes == 2: |
| 81 | + ok = ok or (c == 2) |
| 82 | + if not ok: |
| 83 | + raise ValueError( |
| 84 | + f"{name}: expected {layout} with C(dim=1)==num_classes, " |
| 85 | + f"but got shape={tuple(t.shape)} (C={c}) and num_classes={num_classes}. " |
| 86 | + f"If your data is {layout_last}, please apply EnsureChannelFirst/EnsureChannelFirstd upstream." |
| 87 | + ) |
| 88 | +    # When num_classes is None, no check is performed: channel-first cannot be |
| 89 | +    # verified from the shape alone, so NHWC/NDHWC callers must convert upstream. |
| 90 | + |
| 91 | + |
42 | 92 | def sliding_window_inference( |
43 | 93 | inputs: torch.Tensor | MetaTensor, |
44 | 94 | roi_size: Sequence[int] | int, |
@@ -131,11 +181,30 @@ def sliding_window_inference( |
131 | 181 | kwargs: optional keyword args to be passed to ``predictor``. |
132 | 182 |
|
133 | 183 | Note: |
134 | | - - input must be channel-first and have a batch dim, supports N-D sliding window. |
| 184 | + - Inputs must be channel-first and have a batch dim (e.g. NCHW / NCDHW); N-D sliding window is supported. |
| 185 | + - If your data is NHWC/NDHWC, please apply `EnsureChannelFirst` / `EnsureChannelFirstd` upstream. |
135 | 186 |
|
136 | 187 | """ |
137 | | - buffered = buffer_steps is not None and buffer_steps > 0 |
138 | 188 | num_spatial_dims = len(inputs.shape) - 2 |
| 189 | + |
| 190 | + # Only perform strict shape validation if roi_size is a sequence (explicit dimensions). |
| 191 | + # If roi_size is an integer, it is broadcast to all dimensions, so we cannot |
| 192 | + # infer the expected dimensionality to enforce a strict check here. |
| 193 | + if not isinstance(roi_size, int): |
| 194 | + roi_dims = len(roi_size) |
| 195 | + if num_spatial_dims != roi_dims: |
| 196 | + raise ValueError( |
| 197 | + f"inputs must have {roi_dims + 2} dimensions for {roi_dims}D roi_size " |
| 198 | + f"(Batch, Channel, {', '.join(['Spatial'] * roi_dims)}), " |
| 199 | + f"but got inputs shape {inputs.shape}.\n" |
| 200 | + "If you have channel-last data (e.g. B, D, H, W, C), please use " |
| 201 | + "monai.transforms.EnsureChannelFirst or EnsureChannelFirstd upstream." |
| 202 | + ) |
| 203 | +    # Strict layout validation: do NOT guess or permute layouts; channel-last |
| 204 | +    # (NHWC/NDHWC) inputs must be converted upstream via EnsureChannelFirst. |
| 205 | + if isinstance(inputs, torch.Tensor): |
| 206 | + assert_channel_first(inputs, "inputs") |
| 207 | + buffered = buffer_steps is not None and buffer_steps > 0 |
139 | 208 | if buffered: |
140 | 209 | if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims: |
141 | 210 | raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.") |
|
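A minimal usage sketch of the new behaviour (not part of the diff): it assumes this branch is installed, uses a placeholder `torch.nn.Identity` predictor instead of a real network, and uses a plain `permute` where a pipeline would normally apply the `EnsureChannelFirst` / `EnsureChannelFirstd` transforms recommended by the new error message.

```python
import torch

from monai.inferers import sliding_window_inference

# Hypothetical channel-last 3D volume: (batch, D, H, W, C).
channel_last = torch.rand(1, 64, 64, 32, 1)

# The inferer expects the channel at dim=1 (NCDHW); move it there explicitly.
# In a real pipeline, EnsureChannelFirst / EnsureChannelFirstd does this upstream.
channel_first = channel_last.permute(0, 4, 1, 2, 3)  # -> (1, 1, 64, 64, 32)

predictor = torch.nn.Identity()  # placeholder for a real segmentation network

output = sliding_window_inference(
    inputs=channel_first,
    roi_size=(32, 32, 16),
    sw_batch_size=2,
    predictor=predictor,
)
print(output.shape)  # torch.Size([1, 1, 64, 64, 32])

# A tensor missing the batch or channel dim, e.g. (D, H, W, C) with a
# 3-element roi_size, now fails fast with a ValueError that points the
# user at EnsureChannelFirst instead of erroring deep inside the loop.
```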