33 | 33 |     optional_import, |
34 | 34 | ) |
35 | 35 |
36 | | -tqdm, _ = optional_import("tqdm", name="tqdm") |
| 36 | + |
| 37 | +tqdm, _ = optional_import("tqdm", name="tqdm") |
37 | 38 | _nearest_mode = "nearest-exact" |
38 | 39 |
39 | 40 | __all__ = ["sliding_window_inference"] |
40 | 41 |
41 | 42 |
| 43 | +def _validate_channel_first(t, name: str, num_classes: int | None = None, allow_binary_two_channel: bool = False) -> None: |
| 44 | +    """ |
| 45 | +    Enforce channel-first layout (NCHW/NCDHW) without guessing. |
| 46 | + |
| 47 | +    Args: |
| 48 | +        t: Input tensor to validate. |
| 49 | +        name: Name of the tensor for error messages. |
| 50 | +        num_classes: Optional expected channel count at dim=1. |
| 51 | +        allow_binary_two_channel: If True and num_classes==2, accept C=2. |
| 52 | + |
| 53 | +    Raises: |
| 54 | +        ValueError: If tensor shape doesn't match channel-first layout or |
| 55 | +            num_classes constraint. |
| 56 | + |
| 57 | +    Note: |
| 58 | +        - Accepts only 4D (NCHW) or 5D (NCDHW) tensors with channel at dim=1. |
| 59 | +        - Non-tensors and tensors with other dimensionalities are silently ignored. |
| 60 | +        - Users must apply EnsureChannelFirst/EnsureChannelFirstd upstream for |
| 61 | +          channel-last data. |
| 62 | +    """ |
| 63 | +    if not isinstance(t, torch.Tensor): |
| 64 | +        return |
| 65 | +    if t.ndim not in (4, 5): |
| 66 | +        return |
| 67 | + |
| 68 | +    c = int(t.shape[1]) |
| 69 | +    layout = "NCHW" if t.ndim == 4 else "NCDHW" |
| 70 | +    layout_last = "NHWC" if t.ndim == 4 else "NDHWC" |
| 71 | + |
| 72 | +    if num_classes is not None: |
| 73 | +        ok = (c == num_classes) or (num_classes == 1 and c == 1) |
| 74 | +        if allow_binary_two_channel and num_classes == 2: |
| 75 | +            ok = ok or (c == 2) |
| 76 | +        if not ok: |
| 77 | +            raise ValueError( |
| 78 | +                f"{name}: expected {layout} with C(dim=1)==num_classes, " |
| 79 | +                f"but got shape={tuple(t.shape)} (C={c}) and num_classes={num_classes}. " |
| 80 | +                f"If your data is {layout_last}, please apply EnsureChannelFirst/EnsureChannelFirstd upstream." |
| 81 | +            ) |
| 82 | +    # No guessing when num_classes is None; we simply require channel at dim=1. |
| 83 | +    # If callers provided NHWC/NDHWC, they must convert upstream. |
| 85 | + |
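For orientation, a minimal usage sketch of the validator above (assuming it is in scope as `_validate_channel_first`; the tensor shapes and `num_classes` values are illustrative, not part of this diff):

```python
import torch

# Channel-first NCHW passes; the same data in channel-last NHWC fails fast
# with a message pointing at EnsureChannelFirst/EnsureChannelFirstd.
nchw = torch.zeros(1, 3, 64, 64)   # C=3 sits at dim=1
nhwc = torch.zeros(1, 64, 64, 3)   # dim=1 is 64, not the channel count

_validate_channel_first(nchw, "inputs", num_classes=3)  # returns silently

try:
    _validate_channel_first(nhwc, "inputs", num_classes=3)
except ValueError as err:
    print(err)  # inputs: expected NCHW with C(dim=1)==num_classes, ...

# Per the Note above: non-4D/5D tensors are silently ignored, not rejected.
_validate_channel_first(torch.zeros(64, 64, 64), "inputs", num_classes=3)
```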
42 | 86 | def sliding_window_inference( |
43 | 87 |     inputs: torch.Tensor | MetaTensor, |
44 | 88 |     roi_size: Sequence[int] | int, |
@@ -131,11 +175,29 @@ def sliding_window_inference( |
131 | 175 |         kwargs: optional keyword args to be passed to ``predictor``. |
132 | 176 |
133 | 177 |     Note: |
134 |     | -        - input must be channel-first and have a batch dim, supports N-D sliding window. |
    | 178 | +        - Inputs must be channel-first and have a batch dim (NCHW / NCDHW). |
    | 179 | +        - If your data is NHWC/NDHWC, please apply `EnsureChannelFirst` / `EnsureChannelFirstd` upstream. |
135 | 180 |
136 | 181 |     """ |
137 |     | -    buffered = buffer_steps is not None and buffer_steps > 0 |
138 | 182 |     num_spatial_dims = len(inputs.shape) - 2 |
    | 183 | + |
    | 184 | +    # ----------------------------------------------------------------- |
    | 185 | +    # ---- Strict validation: do NOT guess or permute layouts ---- |
    | 186 | +    # Only perform strict shape validation if roi_size is a sequence (explicit |
    | 187 | +    # dimensions). If roi_size is an integer, it is broadcast to all spatial |
    | 188 | +    # dimensions, so we cannot infer the expected dimensionality to check here. |
    | 189 | +    if not isinstance(roi_size, int): |
    | 190 | +        roi_dims = len(roi_size) |
    | 191 | +        if num_spatial_dims != roi_dims: |
    | 192 | +            raise ValueError( |
    | 193 | +                f"inputs must have {roi_dims + 2} dimensions for {roi_dims}D roi_size " |
    | 194 | +                f"(Batch, Channel, {', '.join(['Spatial'] * roi_dims)}), " |
    | 195 | +                f"but got inputs shape {inputs.shape}.\n" |
    | 196 | +                "If you have channel-last data (e.g. B, D, H, W, C), please use " |
    | 197 | +                "monai.transforms.EnsureChannelFirst or EnsureChannelFirstd upstream." |
    | 198 | +            ) |
    | 199 | + |
    | 200 | +    buffered = buffer_steps is not None and buffer_steps > 0 |
139 | 201 |     if buffered: |
140 | 202 |         if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims: |
141 | 203 |             raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.") |
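To see what the stricter `roi_size` check changes in practice, a sketch with this branch applied (the shapes and the identity predictor are assumptions for illustration only):

```python
import torch
from monai.inferers import sliding_window_inference

def predictor(x):
    return x  # identity stand-in for a real network

# A 3D volume missing its channel dim (B, D, H, W) combined with a 3-element
# roi_size now fails fast instead of silently treating depth as channels.
vol = torch.zeros(1, 96, 96, 96)
try:
    sliding_window_inference(vol, roi_size=(64, 64, 64), sw_batch_size=1, predictor=predictor)
except ValueError as err:
    print(err)  # inputs must have 5 dimensions for 3D roi_size ...

# Fix upstream: restore channel-first layout (B, C, D, H, W), then run as usual.
out = sliding_window_inference(vol.unsqueeze(1), roi_size=(64, 64, 64), sw_batch_size=1, predictor=predictor)
```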