@@ -370,86 +370,89 @@ def get_objective_function(self, residuals=None, stretch=None):
370370 function = residual_term + regularization_term + sparsity_term
371371 return function
372372
373- def apply_interpolation_matrix (self , components = None , weights = None , stretch = None ):
373+ def compute_stretched_components (self , components = None , weights = None , stretch = None ):
374374 """
375- Applies an interpolation-based transformation to the 'components' using `stretch`,
376- weighted by `weights`. Optionally computes first (`d_stretched_components`) and
377- second (`dd_stretched_components`) derivatives.
375+ Interpolates each component along its sample axis according to per-(component, signal)
376+ stretch factors, then applies per-(component, signal) weights. Also computes the
377+ first and second derivatives with respect to stretch. Left and right, respectively,
378+ refer to the sample prior to and subsequent to the interpolated sample's position.
379+
380+ Inputs
381+ ------
382+ components : array, shape (signal_len, n_components)
383+ Each column is a component with signal_len samples.
384+ weights : array, shape (n_components, n_signals)
385+ Per-(component, signal) weights.
386+ stretch : array, shape (n_components, n_signals)
387+ Per-(component, signal) stretch factors.
388+
389+ Outputs
390+ -------
391+ stretched_components : array, shape (signal_len, n_components * n_signals)
392+ Interpolated and weighted components.
393+ d_stretched_components : array, shape (signal_len, n_components * n_signals)
394+ First derivatives with respect to stretch.
395+ dd_stretched_components : array, shape (signal_len, n_components * n_signals)
396+ Second derivatives with respect to stretch.
378397 """
379398
399+ # --- Defaults ---
380400 if components is None :
381401 components = self .components_
382402 if weights is None :
383403 weights = self .weights_
384404 if stretch is None :
385405 stretch = self .stretch_
386406
387- # Compute scaled indices
388- stretch_flat = stretch .reshape (1 , self .n_signals * self .n_components ) ** - 1
389- stretch_tiled = np .tile (stretch_flat , (self .signal_length , 1 ))
390-
391- # Compute `fractional_indices`
392- fractional_indices = (
393- np .tile (np .arange (self .signal_length )[:, None ], (1 , self .n_signals * self .n_components ))
394- * stretch_tiled
407+ # Dimensions
408+ signal_len = components .shape [0 ] # number of samples
409+ n_components = components .shape [1 ] # number of components
410+ n_signals = weights .shape [1 ] # number of signals
411+
412+ # Guard stretches
413+ eps = 1e-8
414+ stretch = np .clip (stretch , eps , None )
415+ stretch_inv = 1.0 / stretch
416+
417+ # Apply stretching to the original sample indices, represented as a "time-stretch"
418+ t = np .arange (signal_len , dtype = float )[:, None , None ] * stretch_inv [None , :, :]
419+ # has shape (signal_len, n_components, n_signals)
420+
421+ # For each stretched coordinate, find its prior integer (original) index and their difference
422+ i0 = np .floor (t ).astype (np .int64 ) # prior original index
423+ alpha = t - i0 .astype (float ) # fractional distance between left/right
424+
425+ # Clip indices to valid range (0, signal_len - 1) to maintain original size
426+ max_idx = signal_len - 1
427+ i0 = np .clip (i0 , 0 , max_idx )
428+ i1 = np .clip (i0 + 1 , 0 , max_idx )
429+
430+ # Gather sample values
431+ comps_3d = components [:, :, None ] # expand components by a dimension for broadcasting across n_signals
432+ c0 = np .take_along_axis (comps_3d , i0 , axis = 0 ) # left sample values
433+ c1 = np .take_along_axis (comps_3d , i1 , axis = 0 ) # right sample values
434+
435+ # Linear interpolation to determine stretched sample values
436+ interp = c0 * (1.0 - alpha ) + c1 * alpha
437+ interp_weighted = interp * weights [None , :, :]
438+
439+ # Derivatives
440+ di = - t * stretch_inv [None , :, :] # first-derivative coefficient
441+ ddi = - di * stretch_inv [None , :, :] * 2.0 # second-derivative coefficient
442+
443+ d_unweighted = c0 * (- di ) + c1 * di
444+ dd_unweighted = c0 * (- ddi ) + c1 * ddi
445+
446+ d_weighted = d_unweighted * weights [None , :, :]
447+ dd_weighted = dd_unweighted * weights [None , :, :]
448+
449+ # Flatten back to expected shape (signal_len, n_components * n_signals)
450+ return (
451+ interp_weighted .reshape (signal_len , n_components * n_signals ),
452+ d_weighted .reshape (signal_len , n_components * n_signals ),
453+ dd_weighted .reshape (signal_len , n_components * n_signals ),
395454 )
396455
397- # Weighting matrix
398- weights_flat = weights .reshape (1 , self .n_signals * self .n_components )
399- weights_tiled = np .tile (weights_flat , (self .signal_length , 1 ))
400-
401- # Bias for indexing into reshaped components
402- # TODO break this up or describe what it does better
403- bias = np .kron (
404- np .arange (self .n_components ) * (self .signal_length + 1 ),
405- np .ones ((self .signal_length , self .n_signals ), dtype = int ),
406- ).reshape (self .signal_length , self .n_components * self .n_signals )
407-
408- # Handle boundary conditions for interpolation
409- components_bounded = np .vstack (
410- [components , components [- 1 , :]]
411- ) # Duplicate last row (like MATLAB, not sure why)
412-
413- # Compute floor indices
414- floor_indices = np .floor (fractional_indices ).astype (int )
415-
416- floor_indices_1 = np .minimum (floor_indices + 1 , self .signal_length )
417- floor_indices_2 = np .minimum (floor_indices_1 + 1 , self .signal_length )
418-
419- # Compute fractional part
420- fractional_floor_indices = fractional_indices - floor_indices
421-
422- # Compute offset indices
423- offset_indices_1 = floor_indices_1 + bias
424- offset_indices_2 = floor_indices_2 + bias
425-
426- # Extract values
427- # Note: this "-1" corrects an off-by-one error that may have originated in an earlier line
428- comp_values_1 = components_bounded .flatten (order = "F" )[(offset_indices_1 - 1 ).ravel (order = "F" )].reshape (
429- self .signal_length , self .n_components * self .n_signals , order = "F"
430- ) # order = F uses FORTRAN, column major order
431- comp_values_2 = components_bounded .flatten (order = "F" )[(offset_indices_2 - 1 ).ravel (order = "F" )].reshape (
432- self .signal_length , self .n_components * self .n_signals , order = "F"
433- )
434-
435- # Interpolation
436- unweighted_stretched_comps = (
437- comp_values_1 * (1 - fractional_floor_indices ) + comp_values_2 * fractional_floor_indices
438- )
439- stretched_components = unweighted_stretched_comps * weights_tiled # Apply weighting
440-
441- # Compute first derivative
442- di = - fractional_indices * stretch_tiled
443- d_comps_unweighted = comp_values_1 * (- di ) + comp_values_2 * di
444- d_stretched_components = d_comps_unweighted * weights_tiled
445-
446- # Compute second derivative
447- ddi = - di * stretch_tiled * 2
448- dd_comps_unweighted = comp_values_1 * (- ddi ) + comp_values_2 * ddi
449- dd_stretched_components = dd_comps_unweighted * weights_tiled
450-
451- return stretched_components , d_stretched_components , dd_stretched_components
452-
453456 def apply_transformation_matrix (self , stretch = None , weights = None , residuals = None ):
454457 """
455458 Computes the transformation matrix `stretch_transformed` for residuals,
@@ -560,7 +563,7 @@ def update_components(self):
560563 Updates `components` using gradient-based optimization with adaptive step size.
561564 """
562565 # Compute stretched components using the interpolation function
563- stretched_components , _ , _ = self .apply_interpolation_matrix () # Discard the derivatives
566+ stretched_components , _ , _ = self .compute_stretched_components () # Discard the derivatives
564567 # Compute reshaped_stretched_components and component_residuals
565568 intermediate_reshaped = stretched_components .flatten (order = "F" ).reshape (
566569 (self .signal_length * self .n_signals , self .n_components ), order = "F"
@@ -648,7 +651,9 @@ def regularize_function(self, stretch=None):
648651 if stretch is None :
649652 stretch = self .stretch_
650653
651- stretched_components , d_stretch_comps , dd_stretch_comps = self .apply_interpolation_matrix (stretch = stretch )
654+ stretched_components , d_stretch_comps , dd_stretch_comps = self .compute_stretched_components (
655+ stretch = stretch
656+ )
652657 intermediate = stretched_components .flatten (order = "F" ).reshape (
653658 (self .signal_length * self .n_signals , self .n_components ), order = "F"
654659 )
@@ -751,8 +756,8 @@ def reconstruct_matrix(components, weights, stretch):
751756 """
752757
753758 signal_len = components .shape [0 ]
754- n_signals = weights .shape [1 ]
755759 n_components = components .shape [1 ]
760+ n_signals = weights .shape [1 ]
756761
757762 reconstructed_matrix = np .zeros ((signal_len , n_signals ))
758763 sample_indices = np .arange (signal_len )
0 commit comments