Skip to content

Commit 52ab77c

Browse files
author
John Halloran
committed
refactor: drastically simplify indexing in apply_interpolation_matrix() and remove legacy MATLAB terminology
1 parent 6c1f8ec commit 52ab77c

File tree

1 file changed

+70
-71
lines changed

1 file changed

+70
-71
lines changed

src/diffpy/snmf/snmf_class.py

Lines changed: 70 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -372,88 +372,87 @@ def get_objective_function(self, residuals=None, stretch=None):
372372

373373
def apply_interpolation_matrix(self, components=None, weights=None, stretch=None):
374374
"""
375-
Applies an interpolation-based transformation to the 'components' using `stretch`,
376-
weighted by `weights`. Optionally computes first (`d_stretched_components`) and
377-
second (`dd_stretched_components`) derivatives.
375+
Interpolates each component along its sample axis according to per-(component, signal)
376+
stretch factors, then applies per-(component, signal) weights. Also computes the
377+
first and second derivatives with respect to stretch. Left and right, respectively,
378+
refer to the sample prior to and subsequent to the interpolated sample's position.
379+
380+
Inputs
381+
------
382+
components : array, shape (signal_len, n_components)
383+
Each column is a component with signal_len samples.
384+
weights : array, shape (n_components, n_signals)
385+
Per-(component, signal) weights.
386+
stretch : array, shape (n_components, n_signals)
387+
Per-(component, signal) stretch factors.
388+
389+
Outputs
390+
-------
391+
stretched_components : array, shape (signal_len, n_components * n_signals)
392+
Interpolated and weighted components.
393+
d_stretched_components : array, shape (signal_len, n_components * n_signals)
394+
First derivatives with respect to stretch.
395+
dd_stretched_components : array, shape (signal_len, n_components * n_signals)
396+
Second derivatives with respect to stretch.
378397
"""
379398

399+
# --- Defaults ---
380400
if components is None:
381401
components = self.components_
382402
if weights is None:
383403
weights = self.weights_
384404
if stretch is None:
385405
stretch = self.stretch_
386406

387-
eps = 1e-8 # guard against divide by zero/NaN stretches
388-
stretch = np.maximum(stretch, eps)
389-
390-
# Compute scaled indices
391-
stretch_flat = stretch.reshape(1, self.n_signals * self.n_components) ** -1
392-
393-
# Compute `fractional_indices`
394-
fractional_indices = np.arange(self.signal_length)[:, None] * stretch_flat
395-
396-
# Weighting matrix
397-
weights_flat = weights.reshape(1, self.n_signals * self.n_components)
398-
399-
# Bias for indexing into reshaped components
400-
# TODO break this up or describe what it does better
401-
bias = np.kron(
402-
np.arange(self.n_components) * (self.signal_length + 1),
403-
np.ones((self.signal_length, self.n_signals), dtype=int),
404-
).reshape(self.signal_length, self.n_components * self.n_signals)
405-
406-
# Handle boundary conditions for interpolation
407-
components_bounded = np.vstack(
408-
[components, components[-1, :]]
409-
) # Duplicate last row (like MATLAB, not sure why)
410-
411-
# Compute floor indices
412-
floor_indices = np.floor(fractional_indices).astype(int)
413-
414-
floor_indices_1 = np.minimum(floor_indices + 1, self.signal_length)
415-
floor_indices_2 = np.minimum(floor_indices_1 + 1, self.signal_length)
416-
417-
# Compute fractional part
418-
fractional_floor_indices = fractional_indices - floor_indices
419-
420-
# Compute offset indices
421-
offset_indices_1 = floor_indices_1 + bias
422-
offset_indices_2 = floor_indices_2 + bias
423-
424-
# Flatten components once (Fortran order, column-major)
425-
components_bounded_flat = components_bounded.ravel(order="F")
426-
427-
# Pre-compute flattened indices
428-
offset_indices_1_flat = (offset_indices_1 - 1).ravel(order="F")
429-
offset_indices_2_flat = (offset_indices_2 - 1).ravel(order="F")
430-
431-
# Extract values using pre-flattened arrays
432-
comp_values_1 = components_bounded_flat[offset_indices_1_flat].reshape(
433-
self.signal_length, self.n_components * self.n_signals, order="F"
434-
)
435-
comp_values_2 = components_bounded_flat[offset_indices_2_flat].reshape(
436-
self.signal_length, self.n_components * self.n_signals, order="F"
407+
# Dimensions
408+
signal_len = components.shape[0] # number of samples
409+
n_components = components.shape[1] # number of components
410+
n_signals = weights.shape[1] # number of signals
411+
412+
# Guard stretches
413+
eps = 1e-8
414+
stretch = np.clip(stretch, eps, None)
415+
stretch_inv = 1.0 / stretch
416+
417+
# Apply stretching to the original sample indices, represented as a "time-stretch"
418+
t = np.arange(signal_len, dtype=float)[:, None, None] * stretch_inv[None, :, :]
419+
# has shape (signal_len, n_components, n_signals)
420+
421+
# For each stretched coordinate, find its prior integer (original) index and their difference
422+
i0 = np.floor(t).astype(np.int64) # prior original index
423+
alpha = t - i0.astype(float) # fractional distance between left/right
424+
425+
# Clip indices to valid range (0, signal_len - 1) to maintain original size
426+
max_idx = signal_len - 1
427+
i0 = np.clip(i0, 0, max_idx)
428+
i1 = np.clip(i0 + 1, 0, max_idx)
429+
430+
# Gather sample values
431+
comps_3d = components[:, :, None] # expand components by a dimension for broadcasting across n_signals
432+
c0 = np.take_along_axis(comps_3d, i0, axis=0) # left sample values
433+
c1 = np.take_along_axis(comps_3d, i1, axis=0) # right sample values
434+
435+
# Linear interpolation to determine stretched sample values
436+
interp = c0 * (1.0 - alpha) + c1 * alpha
437+
interp_weighted = interp * weights[None, :, :]
438+
439+
# Derivatives
440+
di = -t * stretch_inv[None, :, :] # first-derivative coefficient
441+
ddi = -di * stretch_inv[None, :, :] * 2.0 # second-derivative coefficient
442+
443+
d_unweighted = c0 * (-di) + c1 * di
444+
dd_unweighted = c0 * (-ddi) + c1 * ddi
445+
446+
d_weighted = d_unweighted * weights[None, :, :]
447+
dd_weighted = dd_unweighted * weights[None, :, :]
448+
449+
# Flatten back to expected shape (signal_len, n_components * n_signals)
450+
return (
451+
interp_weighted.reshape(signal_len, n_components * n_signals),
452+
d_weighted.reshape(signal_len, n_components * n_signals),
453+
dd_weighted.reshape(signal_len, n_components * n_signals),
437454
)
438455

439-
# Interpolation
440-
unweighted_stretched_comps = (
441-
comp_values_1 * (1 - fractional_floor_indices) + comp_values_2 * fractional_floor_indices
442-
)
443-
stretched_components = unweighted_stretched_comps * weights_flat # Apply weighting
444-
445-
# Compute first derivative
446-
di = -fractional_indices * stretch_flat
447-
d_comps_unweighted = comp_values_1 * (-di) + comp_values_2 * di
448-
d_stretched_components = d_comps_unweighted * weights_flat
449-
450-
# Compute second derivative
451-
ddi = -di * stretch_flat * 2
452-
dd_comps_unweighted = comp_values_1 * (-ddi) + comp_values_2 * ddi
453-
dd_stretched_components = dd_comps_unweighted * weights_flat
454-
455-
return stretched_components, d_stretched_components, dd_stretched_components
456-
457456
def apply_transformation_matrix(self, stretch=None, weights=None, residuals=None):
458457
"""
459458
Computes the transformation matrix `stretch_transformed` for residuals,

0 commit comments

Comments
 (0)