@@ -303,18 +303,48 @@ def outer_loop(self):
303303 )
304304
305305 def get_residual_matrix (self , components = None , weights = None , stretch = None ):
306- # Initialize residual matrix as negative of source_matrix
306+ """
307+ Return the residuals (difference) between the source matrix and its reconstruction
308+ from the given components, weights, and stretch factors.
309+
310+ Each component profile is stretched, interpolated to fractional positions,
311+ weighted per signal, and summed to form the reconstruction. The residuals
312+ are the source matrix minus this reconstruction.
313+
314+ Parameters
315+ ----------
316+ components : (signal_len, n_components) array, optional
317+ weights : (n_components, n_signals) array, optional
318+ stretch : (n_components, n_signals) array, optional
319+
320+ Returns
321+ -------
322+ residuals : (signal_len, n_signals) array
323+ """
324+
307325 if components is None :
308326 components = self .components
309327 if weights is None :
310328 weights = self .weights
311329 if stretch is None :
312330 stretch = self .stretch
331+
313332 residuals = - self .source_matrix .copy ()
314- # Compute transformed components for all (k, m) pairs
315- for k in range (weights .shape [0 ]): # K
316- stretched_components , _ , _ = apply_interpolation (stretch [k , :], components [:, k ]) # Only use Ax
317- residuals += weights [k , :] * stretched_components # Element-wise scaling and sum
333+ sample_indices = np .arange (components .shape [0 ]) # (signal_len,)
334+
335+ for comp in range (components .shape [1 ]): # loop over components
336+ residuals += (
337+ np .interp (
338+ sample_indices [:, None ]
339+ / stretch [comp ][None , :], # fractional positions (signal_len, n_signals)
340+ sample_indices , # (signal_len,)
341+ components [:, comp ], # component profile (signal_len,)
342+ left = components [0 , comp ],
343+ right = components [- 1 , comp ],
344+ )
345+ * weights [comp ][None , :] # broadcast (n_signals,) over rows
346+ )
347+
318348 return residuals
319349
320350 def get_objective_function (self , residuals = None , stretch = None ):
@@ -579,42 +609,47 @@ def update_components(self):
579609
580610 def update_weights (self ):
581611 """
582- Updates weights using matrix operations, solving a quadratic program to do so.
612+ Updates weights by building the stretched component matrix `stretched_comps` with np.interp
613+ and solving a quadratic program for each signal.
583614 """
584615
585- signal_length = self .signal_length
586- n_signals = self .n_signals
587-
588- for m in range (n_signals ):
589- t = np .zeros ((signal_length , self .n_components ))
590-
591- # Populate t using apply_interpolation
592- for k in range (self .n_components ):
593- t [:, k ] = apply_interpolation (self .stretch [k , m ], self .components [:, k ])[0 ].squeeze ()
594-
595- # Solve quadratic problem for y
596- y = self .solve_quadratic_program (t = t , m = m )
616+ sample_indices = np .arange (self .signal_length )
617+ for signal in range (self .n_signals ):
618+ # Stretch factors for this signal across components:
619+ this_stretch = self .stretch [:, signal ]
620+ # Build stretched_comps[:, k] by interpolating component at frac. pos. index / this_stretch[comp]
621+ stretched_comps = np .empty ((self .signal_length , self .n_components ), dtype = self .components .dtype )
622+ for comp in range (self .n_components ):
623+ pos = sample_indices / this_stretch [comp ]
624+ stretched_comps [:, comp ] = np .interp (
625+ pos ,
626+ sample_indices ,
627+ self .components [:, comp ],
628+ left = self .components [0 , comp ],
629+ right = self .components [- 1 , comp ],
630+ )
597631
598- # Update Y
599- self .weights [:, m ] = y
632+ # Solve quadratic problem for a given signal and update its weight
633+ new_weight = self .solve_quadratic_program (t = stretched_comps , m = signal )
634+ self .weights [:, signal ] = new_weight
600635
601636 def regularize_function (self , stretch = None ):
602637 if stretch is None :
603638 stretch = self .stretch
604639
605- K = self .n_components
606- M = self .n_signals
607- N = self .signal_length
608-
609640 stretched_components , d_stretch_comps , dd_stretch_comps = self .apply_interpolation_matrix (stretch = stretch )
610- intermediate = stretched_components .flatten (order = "F" ).reshape ((N * M , K ), order = "F" )
611- residuals = intermediate .sum (axis = 1 ).reshape ((N , M ), order = "F" ) - self .source_matrix
641+ intermediate = stretched_components .flatten (order = "F" ).reshape (
642+ (self .signal_length * self .n_signals , self .n_components ), order = "F"
643+ )
644+ residuals = (
645+ intermediate .sum (axis = 1 ).reshape ((self .signal_length , self .n_signals ), order = "F" ) - self .source_matrix
646+ )
612647
613648 fun = self .get_objective_function (residuals , stretch )
614649
615- tiled_res = np .tile (residuals , (1 , K ))
650+ tiled_res = np .tile (residuals , (1 , self . n_components ))
616651 grad_flat = np .sum (d_stretch_comps * tiled_res , axis = 0 )
617- gra = grad_flat .reshape ((M , K ), order = "F" ).T
652+ gra = grad_flat .reshape ((self . n_signals , self . n_components ), order = "F" ).T
618653 gra += self .rho * stretch @ (self ._spline_smooth_operator .T @ self ._spline_smooth_operator )
619654
620655 # Hessian would go here
@@ -623,10 +658,10 @@ def regularize_function(self, stretch=None):
623658
624659 def update_stretch (self ):
625660 """
626- Updates matrix A using constrained optimization (equivalent to fmincon in MATLAB).
661+ Updates stretching matrix using constrained optimization (equivalent to fmincon in MATLAB).
627662 """
628663
629- # Flatten A for compatibility with the optimizer (since SciPy expects 1D inputs )
664+ # Flatten stretch for compatibility with the optimizer (since SciPy expects 1D input )
630665 stretch_flat_initial = self .stretch .flatten ()
631666
632667 # Define the optimization function
@@ -648,7 +683,7 @@ def objective(stretch_vec):
648683 bounds = bounds ,
649684 )
650685
651- # Update A with the optimized values
686+ # Update stretch with the optimized values
652687 self .stretch = result .x .reshape (self .stretch .shape )
653688
654689
@@ -683,48 +718,3 @@ def cubic_largest_real_root(p, q):
683718 y = np .max (real_roots , axis = 0 ) * (delta < 0 ) # Keep only real roots when delta < 0
684719
685720 return y
686-
687-
688- def apply_interpolation (a , x ):
689- """
690- Applies an interpolation-based transformation to `x` based on scaling `a`.
691- Also computes first (`d_intr_x`) and second (`dd_intr_x`) derivatives.
692- """
693- x_len = len (x )
694-
695- # Ensure `a` is an array and reshape for broadcasting
696- a = np .atleast_1d (np .asarray (a )) # Ensures a is at least 1D
697-
698- # Compute fractional indices, broadcasting over `a`
699- fractional_indices = np .arange (x_len )[:, None ] / a # Shape (N, M)
700-
701- integer_indices = np .floor (fractional_indices ).astype (int ) # Integer part (still (N, M))
702- valid_mask = integer_indices < (x_len - 1 ) # Ensure indices are within bounds
703-
704- # Apply valid_mask to keep correct indices
705- idx_int = np .where (valid_mask , integer_indices , x_len - 2 ) # Prevent out-of-bounds indexing (previously "I")
706- idx_frac = np .where (valid_mask , fractional_indices , integer_indices ) # Keep aligned (previously "i")
707-
708- # Ensure x is a 1D array
709- x = np .asarray (x ).ravel ()
710-
711- # Compute interpolated_x (linear interpolation)
712- interpolated_x = x [idx_int ] * (1 - idx_frac + idx_int ) + x [np .minimum (idx_int + 1 , x_len - 1 )] * (
713- idx_frac - idx_int
714- )
715-
716- # Fill the tail with the last valid value
717- intr_x_tail = np .full ((x_len - len (idx_int ), interpolated_x .shape [1 ]), interpolated_x [- 1 , :])
718- interpolated_x = np .vstack ([interpolated_x , intr_x_tail ])
719-
720- # Compute first derivative (d_intr_x)
721- di = - idx_frac / a
722- d_intr_x = x [idx_int ] * (- di ) + x [np .minimum (idx_int + 1 , x_len - 1 )] * di
723- d_intr_x = np .vstack ([d_intr_x , np .zeros ((x_len - len (idx_int ), d_intr_x .shape [1 ]))])
724-
725- # Compute second derivative (dd_intr_x)
726- ddi = - di / a + idx_frac * a ** - 2
727- dd_intr_x = x [idx_int ] * (- ddi ) + x [np .minimum (idx_int + 1 , x_len - 1 )] * ddi
728- dd_intr_x = np .vstack ([dd_intr_x , np .zeros ((x_len - len (idx_int ), dd_intr_x .shape [1 ]))])
729-
730- return interpolated_x , d_intr_x , dd_intr_x
0 commit comments