diff --git a/nitransforms/nonlinear.py b/nitransforms/nonlinear.py index 1f5b14f..0cb40f3 100644 --- a/nitransforms/nonlinear.py +++ b/nitransforms/nonlinear.py @@ -36,9 +36,9 @@ class DenseFieldTransform(TransformBase): """Represents dense field (voxel-wise) transforms.""" - __slots__ = ("_field", "_deltas", "_is_deltas") + __slots__ = ("_field", "_deltas", "_is_deltas", "_filtered_field") - def __init__(self, field=None, is_deltas=True, reference=None): + def __init__(self, field=None, is_deltas=True, reference=None, do_prefilter=True): """ Create a dense field transform. @@ -107,6 +107,17 @@ def __init__(self, field=None, is_deltas=True, reference=None): else: self._field = _data.copy() + self._filtered_field = None + if do_prefilter: + # pre-cache filtered field to accelerate later mapping + from scipy.ndimage import spline_filter + for i in range(self.reference.ndim): + filtered_field_i = spline_filter(self._field[..., i], order=3, output=np.float64, mode='constant') + if self._filtered_field is None: + self._filtered_field = np.repeat(filtered_field_i[..., np.newaxis], self.reference.ndim, axis=-1) + else: + self._filtered_field[..., i] = filtered_field_i + def __repr__(self): """Beautify the python representation.""" return f"<{self.__class__.__name__}[{self._field.shape[-1]}D] {self._field.shape[:3]}>" @@ -193,12 +204,12 @@ def map(self, x, inverse=False): mapped_coords = np.vstack( tuple( map_coordinates( - self._field[..., i], + self._field[..., i] if self._filtered_field is None else self._filtered_field[..., i], ijk.T, order=3, mode="constant", cval=np.nan, - prefilter=True, + prefilter=self._filtered_field is None, ) for i in range(self.reference.ndim) )