@@ -372,88 +372,87 @@ def get_objective_function(self, residuals=None, stretch=None):
372372
373373 def apply_interpolation_matrix (self , components = None , weights = None , stretch = None ):
374374 """
375- Applies an interpolation-based transformation to the 'components' using `stretch`,
376- weighted by `weights`. Optionally computes first (`d_stretched_components`) and
377- second (`dd_stretched_components`) derivatives.
375+ Interpolates each component along its sample axis according to per-(component, signal)
376+ stretch factors, then applies per-(component, signal) weights. Also computes the
377+ first and second derivatives with respect to stretch. Left and right, respectively,
378+ refer to the sample prior to and subsequent to the interpolated sample's position.
379+
380+ Inputs
381+ ------
382+ components : array, shape (signal_len, n_components)
383+ Each column is a component with signal_len samples.
384+ weights : array, shape (n_components, n_signals)
385+ Per-(component, signal) weights.
386+ stretch : array, shape (n_components, n_signals)
387+ Per-(component, signal) stretch factors.
388+
389+ Outputs
390+ -------
391+ stretched_components : array, shape (signal_len, n_components * n_signals)
392+ Interpolated and weighted components.
393+ d_stretched_components : array, shape (signal_len, n_components * n_signals)
394+ First derivatives with respect to stretch.
395+ dd_stretched_components : array, shape (signal_len, n_components * n_signals)
396+ Second derivatives with respect to stretch.
378397 """
379398
399+ # --- Defaults ---
380400 if components is None :
381401 components = self .components_
382402 if weights is None :
383403 weights = self .weights_
384404 if stretch is None :
385405 stretch = self .stretch_
386406
387- eps = 1e-8 # guard against divide by zero/NaN stretches
388- stretch = np .maximum (stretch , eps )
389-
390- # Compute scaled indices
391- stretch_flat = stretch .reshape (1 , self .n_signals * self .n_components ) ** - 1
392-
393- # Compute `fractional_indices`
394- fractional_indices = np .arange (self .signal_length )[:, None ] * stretch_flat
395-
396- # Weighting matrix
397- weights_flat = weights .reshape (1 , self .n_signals * self .n_components )
398-
399- # Bias for indexing into reshaped components
400- # TODO break this up or describe what it does better
401- bias = np .kron (
402- np .arange (self .n_components ) * (self .signal_length + 1 ),
403- np .ones ((self .signal_length , self .n_signals ), dtype = int ),
404- ).reshape (self .signal_length , self .n_components * self .n_signals )
405-
406- # Handle boundary conditions for interpolation
407- components_bounded = np .vstack (
408- [components , components [- 1 , :]]
409- ) # Duplicate last row (like MATLAB, not sure why)
410-
411- # Compute floor indices
412- floor_indices = np .floor (fractional_indices ).astype (int )
413-
414- floor_indices_1 = np .minimum (floor_indices + 1 , self .signal_length )
415- floor_indices_2 = np .minimum (floor_indices_1 + 1 , self .signal_length )
416-
417- # Compute fractional part
418- fractional_floor_indices = fractional_indices - floor_indices
419-
420- # Compute offset indices
421- offset_indices_1 = floor_indices_1 + bias
422- offset_indices_2 = floor_indices_2 + bias
423-
424- # Flatten components once (Fortran order, column-major)
425- components_bounded_flat = components_bounded .ravel (order = "F" )
426-
427- # Pre-compute flattened indices
428- offset_indices_1_flat = (offset_indices_1 - 1 ).ravel (order = "F" )
429- offset_indices_2_flat = (offset_indices_2 - 1 ).ravel (order = "F" )
430-
431- # Extract values using pre-flattened arrays
432- comp_values_1 = components_bounded_flat [offset_indices_1_flat ].reshape (
433- self .signal_length , self .n_components * self .n_signals , order = "F"
434- )
435- comp_values_2 = components_bounded_flat [offset_indices_2_flat ].reshape (
436- self .signal_length , self .n_components * self .n_signals , order = "F"
407+ # Dimensions
408+ signal_len = components .shape [0 ] # number of samples
409+ n_components = components .shape [1 ] # number of components
410+ n_signals = weights .shape [1 ] # number of signals
411+
412+ # Guard stretches
413+ eps = 1e-8
414+ stretch = np .clip (stretch , eps , None )
415+ stretch_inv = 1.0 / stretch
416+
417+ # Apply stretching to the original sample indices, represented as a "time-stretch"
418+ t = np .arange (signal_len , dtype = float )[:, None , None ] * stretch_inv [None , :, :]
419+ # has shape (signal_len, n_components, n_signals)
420+
421+ # For each stretched coordinate, find its prior integer (original) index and their difference
422+ i0 = np .floor (t ).astype (np .int64 ) # prior original index
423+ alpha = t - i0 .astype (float ) # fractional distance between left/right
424+
425+ # Clip indices to valid range (0, signal_len - 1) to maintain original size
426+ max_idx = signal_len - 1
427+ i0 = np .clip (i0 , 0 , max_idx )
428+ i1 = np .clip (i0 + 1 , 0 , max_idx )
429+
430+ # Gather sample values
431+ comps_3d = components [:, :, None ] # expand components by a dimension for broadcasting across n_signals
432+ c0 = np .take_along_axis (comps_3d , i0 , axis = 0 ) # left sample values
433+ c1 = np .take_along_axis (comps_3d , i1 , axis = 0 ) # right sample values
434+
435+ # Linear interpolation to determine stretched sample values
436+ interp = c0 * (1.0 - alpha ) + c1 * alpha
437+ interp_weighted = interp * weights [None , :, :]
438+
439+ # Derivatives
440+ di = - t * stretch_inv [None , :, :] # first-derivative coefficient
441+ ddi = - di * stretch_inv [None , :, :] * 2.0 # second-derivative coefficient
442+
443+ d_unweighted = c0 * (- di ) + c1 * di
444+ dd_unweighted = c0 * (- ddi ) + c1 * ddi
445+
446+ d_weighted = d_unweighted * weights [None , :, :]
447+ dd_weighted = dd_unweighted * weights [None , :, :]
448+
449+ # Flatten back to expected shape (signal_len, n_components * n_signals)
450+ return (
451+ interp_weighted .reshape (signal_len , n_components * n_signals ),
452+ d_weighted .reshape (signal_len , n_components * n_signals ),
453+ dd_weighted .reshape (signal_len , n_components * n_signals ),
437454 )
438455
439- # Interpolation
440- unweighted_stretched_comps = (
441- comp_values_1 * (1 - fractional_floor_indices ) + comp_values_2 * fractional_floor_indices
442- )
443- stretched_components = unweighted_stretched_comps * weights_flat # Apply weighting
444-
445- # Compute first derivative
446- di = - fractional_indices * stretch_flat
447- d_comps_unweighted = comp_values_1 * (- di ) + comp_values_2 * di
448- d_stretched_components = d_comps_unweighted * weights_flat
449-
450- # Compute second derivative
451- ddi = - di * stretch_flat * 2
452- dd_comps_unweighted = comp_values_1 * (- ddi ) + comp_values_2 * ddi
453- dd_stretched_components = dd_comps_unweighted * weights_flat
454-
455- return stretched_components , d_stretched_components , dd_stretched_components
456-
457456 def apply_transformation_matrix (self , stretch = None , weights = None , residuals = None ):
458457 """
459458 Computes the transformation matrix `stretch_transformed` for residuals,
0 commit comments