diff --git a/docs/api/core.md b/docs/api/core.md index e830563..aa51869 100644 --- a/docs/api/core.md +++ b/docs/api/core.md @@ -30,6 +30,14 @@ Low-level optical flow computation engine implementing variational optical flow .. autofunction:: pyflowreg.core.optical_flow.get_motion_tensor_gc ``` +```{eval-rst} +.. autofunction:: pyflowreg.core.optical_flow.get_motion_tensor_gray +``` + +```{eval-rst} +.. autofunction:: pyflowreg.core.optical_flow.get_motion_tensor_cs +``` + ### Boundary Handling ```{eval-rst} diff --git a/docs/api/session.md b/docs/api/session.md index 3dc543f..59e08cd 100644 --- a/docs/api/session.md +++ b/docs/api/session.md @@ -37,6 +37,7 @@ from pyflowreg.session.config import SessionConfig - sigma_smooth - alpha_between - iterations_between + - stage2_constancy_assumption **Configuration File Support** @@ -78,6 +79,7 @@ cc_upsample = 4 sigma_smooth = 6.0 alpha_between = 25.0 iterations_between = 100 +stage2_constancy_assumption = "gc" # Options: "gc", "gray", "cs" ``` ## Stage 1: Per-Recording Compensation @@ -321,13 +323,15 @@ config = SessionConfig( cc_upsample=4, sigma_smooth=6.0, alpha_between=25.0, - iterations_between=100 + iterations_between=100, + stage2_constancy_assumption="gc", ) # Stage 1: Motion correct each recording print("Running Stage 1...") config.flow_options = { "quality_setting": "balanced", + "constancy_assumption": "gc", "save_valid_idx": True, "save_w": False, } diff --git a/docs/user_guide/configuration.md b/docs/user_guide/configuration.md index 7e61ad4..eaff986 100644 --- a/docs/user_guide/configuration.md +++ b/docs/user_guide/configuration.md @@ -47,9 +47,17 @@ options = OFOptions( # Nonlinear diffusion parameters a_smooth=1.0, # Smoothness diffusion parameter a_data=0.45, # Data term diffusion parameter + + # Data term, default preserves MATLAB Flow-Registration behavior + constancy_assumption="gc", # Options: "gc", "gray", "cs" ) ``` +`constancy_assumption="gc"` is the default gradient constancy data term used by +the MATLAB Flow-Registration reference. `"gray"` selects gray-value constancy, +and `"cs"` selects census constancy. These data terms are implemented by the +native `flowreg` backend; the `diso` backend rejects non-default values. + ### Alpha (Smoothness Weight) Controls the tradeoff between fitting the data and enforcing smooth flow fields: diff --git a/docs/user_guide/multi_session.md b/docs/user_guide/multi_session.md index 14b1e92..b95c7ba 100644 --- a/docs/user_guide/multi_session.md +++ b/docs/user_guide/multi_session.md @@ -62,6 +62,7 @@ cc_upsample = 4 # Subpixel accuracy sigma_smooth = 6.0 # Gaussian smoothing alpha_between = 25.0 # Regularization iterations_between = 100 +stage2_constancy_assumption = "gc" # Options: "gc", "gray", "cs" ``` ### 3. Run Processing @@ -118,6 +119,9 @@ The `pyflowreg.session` pipeline always runs the same three deterministic stages - Temporal averages are reloaded from disk and the reference recording (center) is selected automatically or from `SessionConfig.center`. - `compute_between_displacement()` smooths both averages, applies phase cross-correlation for a rigid guess, then refines with the configured flow backend (`src/pyflowreg/session/stage2_between_avgs.py`). +- `stage2_constancy_assumption` controls the Stage 2 data term. The default + `"gc"` preserves MATLAB Flow-Registration behavior; `"cs"` enables the + census term for the native `flowreg` backend. - Results are written to `w_to_reference.npz` (separate `u`/`v` arrays) so MATLAB users can load them directly. **Outputs:** `w_to_reference.npz`, per-recording `status.json` updates, and `middle_idx` (0-based) pointing to the reference average. @@ -170,6 +174,7 @@ quality_setting = "fast" # Options: fast, balanced, quality buffer_size = 1000 # Frames per batch save_w = false # Don't save displacement fields save_valid_idx = true # Required for Stage 3 +# constancy_assumption = "cs" # Optional Stage 1 data term override ``` Alternatively, point to a saved MATLAB/Python options file: diff --git a/examples/session_config.toml b/examples/session_config.toml index ffefce4..a8c7170 100644 --- a/examples/session_config.toml +++ b/examples/session_config.toml @@ -42,6 +42,7 @@ sigma_smooth = 6.0 # Sigma for Gaussian filter # Optical flow refinement alpha_between = 25.0 # Regularization strength (higher = smoother) iterations_between = 100 # Solver iterations (higher = more accurate) +stage2_constancy_assumption = "gc" # Options: "gc", "gray", "cs" # === Stage 1 Flow Options (Optional) === # Provide inline overrides passed to OFOptions @@ -51,6 +52,7 @@ buffer_size = 1000 # Frames per batch save_w = false # Save displacement fields save_valid_idx = true # Required for Stage 3 save_meta_info = true # Save statistics +# constancy_assumption = "gc" # Options: "gc", "gray", "cs" # Alternatively reference a saved JSON file: # flow_options = "./saved_options/session_stage1.json" diff --git a/examples/session_config.yml b/examples/session_config.yml index 311f7c5..6d35c3e 100644 --- a/examples/session_config.yml +++ b/examples/session_config.yml @@ -31,6 +31,7 @@ cc_upsample: 4 # Subpixel accuracy (higher = more precise) sigma_smooth: 6.0 # Sigma for Gaussian filter alpha_between: 25.0 # Regularization strength (higher = smoother) iterations_between: 100 # Solver iterations (higher = more accurate) +stage2_constancy_assumption: gc # Options: gc, gray, cs # Stage 1 Flow Options (Optional) # Provide inline overrides passed to OFOptions @@ -40,6 +41,7 @@ flow_options: save_w: false # Save displacement fields save_valid_idx: true # Required for Stage 3 save_meta_info: true # Save statistics + # constancy_assumption: gc # Options: gc, gray, cs # Alternatively, reference a JSON file saved via OF_options: # flow_options: "./saved_options/session_stage1.json" diff --git a/src/pyflowreg/core/optical_flow.py b/src/pyflowreg/core/optical_flow.py index c9faf5a..19ff5e1 100644 --- a/src/pyflowreg/core/optical_flow.py +++ b/src/pyflowreg/core/optical_flow.py @@ -14,6 +14,10 @@ Main API function for computing optical flow between two frames get_motion_tensor_gc Compute motion tensor components for gradient constancy +get_motion_tensor_gray + Compute motion tensor components for gray-value constancy +get_motion_tensor_cs + Compute motion tensor components for census constancy imregister_wrapper Warp an image using computed displacement fields warpingDepth @@ -222,7 +226,7 @@ def get_motion_tensor_gray(f1, f2, hy, hx): return J11, J22, J33, J12, J13, J23 -def get_motion_tensor_cs(f1, f2, hy, hx): +def get_motion_tensor_cs(f1, f2, hy, hx, eps=None): """ Compute motion tensor components for census-based constancy assumption. @@ -242,6 +246,14 @@ def get_motion_tensor_cs(f1, f2, hy, hx): Spatial grid spacing in y-direction. hx : float Spatial grid spacing in x-direction. + eps : float, optional + Smoothing width for the smoothed Heaviside function applied to + directional differences ``r = (neighbor - center) / dist``. If None, + uses ``0.1 / 255.0``, matching the Hafner/Demetz/Weickert + ``epsilon = 0.1`` convention for images scaled from ``[0, 255]`` to + approximately ``[0, 1]``. When ``hx`` or ``hy`` are physical units + rather than pixel-like units, callers may need to scale ``eps`` + consistently. Returns ------- @@ -252,26 +264,33 @@ def get_motion_tensor_cs(f1, f2, hy, hx): Notes ----- - The census transform is invariant to monotonically increasing grey-level - changes, improving robustness to illumination variations compared to - gray-value or gradient constancy. A smoothed Heaviside approximation - controlled by `eps` stabilizes the derivatives while preserving the - ordering information that drives the census constraint. - - Symmetric padding enforces Neumann boundary conditions, and the border - is zeroed after aggregation to avoid wrap-around effects from the - cyclic shifts used during neighbor comparisons. + The hard census transform is invariant to global monotonically increasing + grey-value transforms because it depends only on ordering. This + implementation uses finite differences, Gaussian-preprocessed inputs, a + smoothed Heaviside function, and a linearized motion tensor, so invariance + is approximate. + + Additive offsets cancel exactly in neighbor-center differences. Positive + multiplicative changes are approximately handled only when ``eps`` is small + relative to the directional-difference scale, or when ``eps`` is scaled + consistently with image intensity scale. Gamma and other nonlinear + monotone transforms preserve hard ordering but not exact smoothed + Heaviside values. References ---------- .. [1] Hafner, D., Demetz, O., and Weickert, J. "Why is the Census Transform Good for Robust Optic Flow Computation?", SSVM 2013. """ - eps = 0.1 + if eps is None: + eps = 0.1 / 255.0 eps2 = eps * eps + H, W = f1.shape f1p = np.pad(f1, ((1, 1), (1, 1)), mode="symmetric") f2p = np.pad(f2, ((1, 1), (1, 1)), mode="symmetric") + center1 = f1p[1:-1, 1:-1] + center2 = f2p[1:-1, 1:-1] offsets = [ (dy, dx) for dy in (-1, 0, 1) for dx in (-1, 0, 1) if not (dy == 0 and dx == 0) @@ -289,22 +308,22 @@ def get_motion_tensor_cs(f1, f2, hy, hx): for dy, dx in offsets: dist = float(np.sqrt((hy * dy) * (hy * dy) + (hx * dx) * (hx * dx))) - r1 = (np.roll(f1p, shift=(-dy, -dx), axis=(0, 1)) - f1p) / dist - r2 = (np.roll(f2p, shift=(-dy, -dx), axis=(0, 1)) - f2p) / dist + neigh1 = f1p[1 + dy : 1 + dy + H, 1 + dx : 1 + dx + W] + neigh2 = f2p[1 + dy : 1 + dy + H, 1 + dx : 1 + dx + W] + + r1_core = (neigh1 - center1) / dist + r2_core = (neigh2 - center2) / dist - s1 = 0.5 * (1.0 + r1 / np.sqrt(r1 * r1 + eps2)) - s2 = 0.5 * (1.0 + r2 / np.sqrt(r2 * r2 + eps2)) + s1_core = 0.5 * (1.0 + r1_core / np.sqrt(r1_core * r1_core + eps2)) + s2_core = 0.5 * (1.0 + r2_core / np.sqrt(r2_core * r2_core + eps2)) - s1[0, :] = s1[1, :] - s1[-1, :] = s1[-2, :] - s1[:, 0] = s1[:, 1] - s1[:, -1] = s1[:, -2] - s2[0, :] = s2[1, :] - s2[-1, :] = s2[-2, :] - s2[:, 0] = s2[:, 1] - s2[:, -1] = s2[:, -2] + s1 = np.pad(s1_core, 1, mode="edge") + s2 = np.pad(s2_core, 1, mode="edge") - sy, sx = np.gradient(s1, hy, hx) + sy1, sx1 = np.gradient(s1, hy, hx) + sy2, sx2 = np.gradient(s2, hy, hx) + sx = 0.5 * (sx1 + sx2) + sy = 0.5 * (sy1 + sy2) st = s2 - s1 J11 += sx * sx @@ -367,6 +386,36 @@ def level_solver( return du, dv +def _resolve_motion_tensor_func(const_assumption): + """ + Resolve a constancy-assumption selector to a motion tensor function. + + The default ``gc`` path is the MATLAB Flow-Registration behavior. Census + and gray-value constancy are explicit opt-in alternatives. + """ + if hasattr(const_assumption, "value"): + const_assumption = const_assumption.value + + key = str(const_assumption).strip().lower() + tensor_funcs = { + "gc": get_motion_tensor_gc, + "gradient": get_motion_tensor_gc, + "gray": get_motion_tensor_gray, + "brightness": get_motion_tensor_gray, + "cs": get_motion_tensor_cs, + "census": get_motion_tensor_cs, + } + + try: + return tensor_funcs[key] + except KeyError as e: + supported = "', '".join(sorted(tensor_funcs)) + raise ValueError( + f"Unknown constancy assumption: '{const_assumption}'. " + f"Supported values are: '{supported}'." + ) from e + + def get_displacement( fixed, moving, @@ -426,7 +475,9 @@ def get_displacement( - a = 0.5: linear (L1) penalty - a = 0.45: sublinear, robust to noisy microscopy data const_assumption : str, default='gc' - Constancy assumption: 'gc' for gradient constancy (only option implemented) + Constancy assumption. Supported values are 'gc'/'gradient' for gradient + constancy, 'gray'/'brightness' for gray-value constancy, and + 'cs'/'census' for census constancy. uv : np.ndarray, optional Initial displacement field (H, W, 2) with [u, v] components to initialize the coarsest (highest) pyramid level. If None, initializes with zeros. @@ -477,6 +528,7 @@ def get_displacement( assert ( fixed.ndim == moving.ndim ), f"Fixed and moving must have same dimensions: fixed.shape={fixed.shape}, moving.shape={moving.shape}" + motion_tensor_func = _resolve_motion_tensor_func(const_assumption) fixed = fixed.astype(np.float64) moving = moving.astype(np.float64) if fixed.ndim == 3: @@ -578,7 +630,7 @@ def get_displacement( J13 = np.zeros(J_size, dtype=np.float64) J23 = np.zeros(J_size, dtype=np.float64) for ch in range(n_channels): - J11_ch, J22_ch, J33_ch, J12_ch, J13_ch, J23_ch = get_motion_tensor_gc( + J11_ch, J22_ch, J33_ch, J12_ch, J13_ch, J23_ch = motion_tensor_func( f1_level[:, :, ch], tmp[:, :, ch], current_hx, current_hy ) J11[:, :, ch] = J11_ch diff --git a/src/pyflowreg/motion_correction/OF_options.py b/src/pyflowreg/motion_correction/OF_options.py index 16008ff..f234a35 100644 --- a/src/pyflowreg/motion_correction/OF_options.py +++ b/src/pyflowreg/motion_correction/OF_options.py @@ -78,6 +78,22 @@ class InterpolationMethod(str, Enum): class ConstancyAssumption(str, Enum): GRAY = "gray" GRADIENT = "gc" + CENSUS = "cs" + + +def _normalize_constancy_assumption_value(v): + """Normalize constancy assumption aliases to serialized option values.""" + if hasattr(v, "value"): + v = v.value + if isinstance(v, str): + aliases = { + "gradient": ConstancyAssumption.GRADIENT.value, + "brightness": ConstancyAssumption.GRAY.value, + "census": ConstancyAssumption.CENSUS.value, + } + key = v.strip().lower() + return aliases.get(key, key) + return v class NamingConvention(str, Enum): @@ -176,7 +192,8 @@ class OFOptions(BaseModel): NamingConvention.DEFAULT, description="Output filename style" ) constancy_assumption: ConstancyAssumption = Field( - ConstancyAssumption.GRADIENT, description="Constancy assumption" + ConstancyAssumption.GRADIENT, + description="Optical-flow data term: 'gc', 'gray', or 'cs'", ) # Backend configuration @@ -278,6 +295,12 @@ def normalize_sigma(cls, v): raise ValueError("Sigma must be [sx,sy,st] or (n_channels, 3)") return v + @field_validator("constancy_assumption", mode="before") + @classmethod + def normalize_constancy_assumption(cls, v): + """Normalize constancy assumption aliases to serialized option values.""" + return _normalize_constancy_assumption_value(v) + @model_validator(mode="after") def validate_and_normalize(self) -> "OFOptions": """Normalize fields and maintain MATLAB parity.""" @@ -691,6 +714,16 @@ def resolve_get_displacement(self) -> Callable: # Priority 3: Registry backend from pyflowreg.core.backend_registry import get_backend + constancy_assumption = _normalize_constancy_assumption_value( + self.constancy_assumption + ) + if self.flow_backend == "diso" and constancy_assumption != "gc": + raise ValueError( + "The 'diso' backend does not support variational constancy " + f"assumption '{constancy_assumption}'. Use " + "flow_backend='flowreg' for 'gray' or 'cs'." + ) + factory = get_backend(self.flow_backend) return factory(**self.backend_params) @@ -706,7 +739,9 @@ def to_dict(self) -> dict: "update_lag": self.update_lag, "a_data": self.a_data, "a_smooth": self.a_smooth, - "const_assumption": self.constancy_assumption.value, # Fixed: use const_assumption for API compatibility + "const_assumption": _normalize_constancy_assumption_value( + self.constancy_assumption + ), } def __repr__(self) -> str: diff --git a/src/pyflowreg/motion_correction/compensate_recording.py b/src/pyflowreg/motion_correction/compensate_recording.py index ac629a8..8ad4d10 100644 --- a/src/pyflowreg/motion_correction/compensate_recording.py +++ b/src/pyflowreg/motion_correction/compensate_recording.py @@ -211,6 +211,34 @@ def _resolve_displacement_func(self): """Resolve the displacement function to use based on options.""" self._get_disp = self.options.resolve_get_displacement() + def _get_flow_params(self) -> Dict[str, Any]: + """Build optical-flow parameters for the configured displacement backend.""" + if hasattr(self.options, "to_dict"): + flow_params = dict(self.options.to_dict()) + else: + flow_params = { + "alpha": self.options.alpha, + "levels": self.options.levels, + "min_level": getattr( + self.options, + "effective_min_level", + getattr(self.options, "min_level", 0), + ), + "eta": self.options.eta, + "update_lag": self.options.update_lag, + "iterations": self.options.iterations, + "a_smooth": self.options.a_smooth, + "a_data": self.options.a_data, + } + const_assumption = getattr(self.options, "constancy_assumption", None) + if const_assumption is not None: + flow_params["const_assumption"] = getattr( + const_assumption, "value", const_assumption + ) + + flow_params["weight"] = self.weight + return flow_params + def register_progress_callback(self, callback: Callable[[int, int], None]) -> None: """ Register a progress callback function. @@ -396,21 +424,7 @@ def _compute_flow_single( w_init: Optional[np.ndarray] = None, ) -> np.ndarray: """Compute flow for a single frame.""" - flow_params = { - "alpha": self.options.alpha, - "weight": self.weight, - "levels": self.options.levels, - "min_level": getattr( - self.options, - "effective_min_level", - getattr(self.options, "min_level", 0), - ), - "eta": self.options.eta, - "update_lag": self.options.update_lag, - "iterations": self.options.iterations, - "a_smooth": self.options.a_smooth, - "a_data": self.options.a_data, - } + flow_params = self._get_flow_params() if w_init is not None: flow_params["uv"] = w_init @@ -433,22 +447,7 @@ def _process_batch_parallel( w_init: Initial displacement field task_id: Task identifier for progress tracking (default: "main") """ - # Build flow parameters dictionary - flow_params = { - "alpha": self.options.alpha, - "weight": self.weight, - "levels": self.options.levels, - "min_level": getattr( - self.options, - "effective_min_level", - getattr(self.options, "min_level", 0), - ), - "eta": self.options.eta, - "update_lag": self.options.update_lag, - "iterations": self.options.iterations, - "a_smooth": self.options.a_smooth, - "a_data": self.options.a_data, - } + flow_params = self._get_flow_params() # Get interpolation method interp_method = getattr(self.options, "interpolation_method", "cubic") diff --git a/src/pyflowreg/session/config.py b/src/pyflowreg/session/config.py index e95c422..a7709b6 100644 --- a/src/pyflowreg/session/config.py +++ b/src/pyflowreg/session/config.py @@ -57,6 +57,9 @@ class SessionConfig(BaseModel): Regularization for inter-sequence optical flow iterations_between : int, default=100 Iterations for inter-sequence optical flow + stage2_constancy_assumption : str, default="gc" + Constancy assumption for Stage 2 optical flow. The default "gc" + preserves the MATLAB Flow-Registration behavior. align_chunk_size : int, default=64 Number of frames to process per batch during Stage 3 video alignment align_output_format : str, default="TIFF" @@ -88,6 +91,7 @@ class SessionConfig(BaseModel): sigma_smooth: float = 6.0 alpha_between: float = 25.0 iterations_between: int = 100 + stage2_constancy_assumption: str = "gc" # Stage 3 parameters align_chunk_size: int = 64 diff --git a/src/pyflowreg/session/stage2_between_avgs.py b/src/pyflowreg/session/stage2_between_avgs.py index 2960eed..e886d84 100644 --- a/src/pyflowreg/session/stage2_between_avgs.py +++ b/src/pyflowreg/session/stage2_between_avgs.py @@ -26,6 +26,32 @@ from pyflowreg.util.xcorr_prealignment import estimate_rigid_xcorr_2d +def normalize_constancy_assumption(value) -> str: + """ + Normalize a constancy-assumption selector to the backend API value. + + The default ``gc`` value is the MATLAB Flow-Registration behavior. Other + data terms are explicit opt-in extensions for the native flowreg backend. + """ + if hasattr(value, "value"): + value = value.value + + key = str(value).strip().lower() + aliases = { + "gradient": "gc", + "brightness": "gray", + "census": "cs", + } + key = aliases.get(key, key) + if key not in {"gc", "gray", "cs"}: + raise ValueError( + f"Unknown constancy assumption: '{value}'. " + "Supported values are: 'gc', 'gradient', 'gray', 'brightness', " + "'cs', and 'census'." + ) + return key + + def mat2gray_ref(img: np.ndarray, ref: np.ndarray = None) -> np.ndarray: """ Normalize image to [0, 1] range. @@ -117,6 +143,16 @@ def compute_between_displacement( img2 = img2.reshape(img2_dims) # Get displacement function based on configured backend + constancy_assumption = normalize_constancy_assumption( + config.stage2_constancy_assumption + ) + if config.flow_backend == "diso" and constancy_assumption != "gc": + raise ValueError( + "The 'diso' backend does not support variational constancy " + f"assumption '{constancy_assumption}'. Use flow_backend='flowreg' " + "for 'gray' or 'cs'." + ) + backend_factory = get_backend(config.flow_backend) get_displacement_func = backend_factory(**config.backend_params) @@ -131,6 +167,7 @@ def compute_between_displacement( img2, alpha=alpha, iterations=config.iterations_between, + const_assumption=constancy_assumption, ) return w + w_init diff --git a/tests/core/test_optical_flow.py b/tests/core/test_optical_flow.py new file mode 100644 index 0000000..5f1d15a --- /dev/null +++ b/tests/core/test_optical_flow.py @@ -0,0 +1,87 @@ +"""Tests for core optical-flow tensor helpers.""" + +import inspect + +import numpy as np + +from pyflowreg.core.optical_flow import get_motion_tensor_cs + + +def _sample_images(shape=(8, 9)): + y = np.linspace(0.0, 1.0, shape[0])[:, np.newaxis] + x = np.linspace(0.0, 1.0, shape[1])[np.newaxis, :] + f1 = 0.25 + 0.35 * x + 0.20 * y + f2 = 0.30 + 0.25 * x * x + 0.15 * y + return f1.astype(np.float64), f2.astype(np.float64) + + +def _assert_zero_border(tensors): + for tensor in tensors: + assert np.array_equal(tensor[0, :], np.zeros_like(tensor[0, :])) + assert np.array_equal(tensor[-1, :], np.zeros_like(tensor[-1, :])) + assert np.array_equal(tensor[:, 0], np.zeros_like(tensor[:, 0])) + assert np.array_equal(tensor[:, -1], np.zeros_like(tensor[:, -1])) + + +def test_get_motion_tensor_cs_shape_and_zero_border(): + """Returned tensors match the solver contract.""" + f1, f2 = _sample_images() + + tensors = get_motion_tensor_cs(f1, f2, hy=1.0, hx=1.0) + + assert len(tensors) == 6 + for tensor in tensors: + assert tensor.shape == (f1.shape[0] + 2, f1.shape[1] + 2) + _assert_zero_border(tensors) + + +def test_get_motion_tensor_cs_constant_images_near_zero(): + """Constant frames should not create census tensor energy.""" + f1 = np.full((7, 6), 0.25, dtype=np.float64) + f2 = np.full((7, 6), 0.75, dtype=np.float64) + + tensors = get_motion_tensor_cs(f1, f2, hy=1.0, hx=1.0) + + for tensor in tensors: + np.testing.assert_allclose(tensor, 0.0, atol=1e-14) + + +def test_get_motion_tensor_cs_additive_shift_invariance(): + """Neighbor-center differences should cancel common additive offsets.""" + f1, f2 = _sample_images() + + base = get_motion_tensor_cs(f1, f2, hy=1.0, hx=1.0) + shifted = get_motion_tensor_cs(f1 + 4.0, f2 + 4.0, hy=1.0, hx=1.0) + + for base_tensor, shifted_tensor in zip(base, shifted): + np.testing.assert_allclose(shifted_tensor, base_tensor, atol=1e-12) + + +def test_get_motion_tensor_cs_default_eps_matches_normalized_convention(): + """Default epsilon should be the normalized 0.1 / 255.0 value.""" + f1, f2 = _sample_images() + + default = get_motion_tensor_cs(f1, f2, hy=1.0, hx=1.0) + explicit = get_motion_tensor_cs(f1, f2, hy=1.0, hx=1.0, eps=0.1 / 255.0) + + for default_tensor, explicit_tensor in zip(default, explicit): + np.testing.assert_allclose(default_tensor, explicit_tensor) + + +def test_get_motion_tensor_cs_does_not_use_np_roll(): + """Neighbor access should not use cyclic shifts.""" + source = inspect.getsource(get_motion_tensor_cs) + + assert "np.roll" not in source + + +def test_get_motion_tensor_cs_anisotropic_spacing(): + """Anisotropic spacing should preserve tensor shape and zero borders.""" + f1, f2 = _sample_images(shape=(6, 10)) + + tensors = get_motion_tensor_cs(f1, f2, hy=0.5, hx=2.0) + + for tensor in tensors: + assert tensor.shape == (f1.shape[0] + 2, f1.shape[1] + 2) + assert np.all(np.isfinite(tensor)) + _assert_zero_border(tensors) diff --git a/tests/motion_correction/test_OF_options.py b/tests/motion_correction/test_OF_options.py index 7efa4d1..14df2f8 100644 --- a/tests/motion_correction/test_OF_options.py +++ b/tests/motion_correction/test_OF_options.py @@ -8,6 +8,7 @@ import numpy as np from pyflowreg.motion_correction.OF_options import ( + ConstancyAssumption, OFOptions, QualitySetting, ) @@ -186,6 +187,42 @@ def test_custom_min_level_override(self): assert opts.effective_min_level == 8 +class TestConstancyAssumption: + """Test optical-flow data term configuration.""" + + @pytest.mark.parametrize( + ("value", "expected"), + [ + ("gc", ConstancyAssumption.GRADIENT), + ("gradient", ConstancyAssumption.GRADIENT), + ("gray", ConstancyAssumption.GRAY), + ("brightness", ConstancyAssumption.GRAY), + ("cs", ConstancyAssumption.CENSUS), + ("census", ConstancyAssumption.CENSUS), + ], + ) + def test_constancy_assumption_aliases(self, value, expected): + """Aliases normalize to the enum value used by get_displacement.""" + opts = OFOptions(constancy_assumption=value) + + assert opts.constancy_assumption == expected + assert opts.to_dict()["const_assumption"] == expected.value + + def test_diso_rejects_non_default_constancy_assumption(self): + """DISO backend should not silently accept flowreg-only data terms.""" + opts = OFOptions(flow_backend="diso", constancy_assumption="census") + + with pytest.raises(ValueError, match="does not support"): + opts.resolve_get_displacement() + + def test_to_dict_normalizes_assignment_alias(self): + """Assignment-time aliases should serialize to backend API values.""" + opts = OFOptions() + opts.constancy_assumption = "census" + + assert opts.to_dict()["const_assumption"] == "cs" + + class TestGetWeightAt: """Test get_weight_at method.""" diff --git a/tests/motion_correction/test_compensate_recording.py b/tests/motion_correction/test_compensate_recording.py index 52e89e2..b2fd13d 100644 --- a/tests/motion_correction/test_compensate_recording.py +++ b/tests/motion_correction/test_compensate_recording.py @@ -13,7 +13,7 @@ RegistrationConfig, compensate_recording, ) -from pyflowreg.motion_correction.OF_options import OutputFormat +from pyflowreg.motion_correction.OF_options import OFOptions, OutputFormat from pyflowreg._runtime import RuntimeContext from pyflowreg.util.io.factory import get_video_file_reader @@ -111,6 +111,26 @@ def test_initialization_with_basic_options(self, basic_of_options): assert len(pipeline.mean_disp) == 0 assert len(pipeline.max_disp) == 0 + def test_flow_params_include_constancy_assumption(self, tmp_path): + """Batch flow calls should receive the configured data term.""" + options = OFOptions( + output_path=tmp_path, + quality_setting="fast", + constancy_assumption="census", + levels=1, + iterations=2, + ) + config = RegistrationConfig( + n_jobs=1, verbose=True, parallelization="sequential" + ) + pipeline = BatchMotionCorrector(options, config) + pipeline.weight = np.ones((8, 8, 1), dtype=np.float64) + + flow_params = pipeline._get_flow_params() + + assert flow_params["const_assumption"] == "cs" + assert flow_params["weight"] is pipeline.weight + class TestExecutorTypes: """Test different executor types work correctly.""" diff --git a/tests/session/test_config.py b/tests/session/test_config.py index 8ed16fe..ee9370b 100644 --- a/tests/session/test_config.py +++ b/tests/session/test_config.py @@ -320,6 +320,7 @@ def test_default_stage2_parameters(self, tmp_path): assert config.sigma_smooth == 6.0 assert config.alpha_between == 25.0 assert config.iterations_between == 100 + assert config.stage2_constancy_assumption == "gc" def test_custom_stage2_parameters(self, tmp_path): """Test setting custom Stage 2 parameters.""" @@ -329,12 +330,14 @@ def test_custom_stage2_parameters(self, tmp_path): sigma_smooth=4.5, alpha_between=20.0, iterations_between=150, + stage2_constancy_assumption="census", ) assert config.cc_upsample == 8 assert config.sigma_smooth == 4.5 assert config.alpha_between == 20.0 assert config.iterations_between == 150 + assert config.stage2_constancy_assumption == "census" class TestSessionConfigBackendParameters: