@@ -52,10 +52,6 @@ class SNMFOptimizer:
5252 num_updates : int
5353 The total number of times that any of (stretch, components, and weights) have had their values changed.
5454 If not terminated by other means, this value is used to stop when reaching max_iter.
55- objective_function: float
56- The value corresponding to the minimization of the difference between the source_matrix and the
57- products of (stretch, components, and weights). For full details see the sNMF paper. Smaller corresponds to
58- better agreement and is desirable.
5955 objective_difference : float
6056 The change in the objective function value since the last update. A negative value
6157 means that the result improved.
@@ -134,7 +130,7 @@ def __init__(
134130 # Initialize weights and determine number of components
135131 if init_weights is None :
136132 self .n_components = n_components
137- self .weights = self ._rng .beta (a = 2.5 , b = 1.5 , size = (self .n_components , self .n_signals ))
133+ self .weights = self ._rng .beta (a = 2.0 , b = 2.0 , size = (self .n_components , self .n_signals ))
138134 else :
139135 self .n_components = init_weights .shape [0 ]
140136 self .weights = init_weights
@@ -165,20 +161,20 @@ def __init__(
165161
166162 # Set up residual matrix, objective function, and history
167163 self .residuals = self .get_residual_matrix ()
168- self .objective_function = self .get_objective_function ()
164+ self ._objective_history = []
165+ self .update_objective ()
169166 self .objective_difference = None
170- self ._objective_history = [self .objective_function ]
171167
172- # Set up tracking variables for updateX ()
168+ # Set up tracking variables for update_components ()
173169 self ._prev_components = None
174170 self .grad_components = np .zeros_like (self .components ) # Gradient of X (zeros for now)
175171 self ._prev_grad_components = np .zeros_like (self .components ) # Previous gradient of X (zeros for now)
176172
177173 regularization_term = 0.5 * rho * np .linalg .norm (self ._spline_smooth_operator @ self .stretch .T , "fro" ) ** 2
178174 sparsity_term = eta * np .sum (np .sqrt (self .components )) # Square root penalty
179175 print (
180- f"Start, Objective function: { self .objective_function :.5e} "
181- f", Obj - reg/sparse: { self .objective_function - regularization_term - sparsity_term :.5e} "
176+ f"Start, Objective function: { self ._objective_history [ - 1 ] :.5e} "
177+ f", Obj - reg/sparse: { self ._objective_history [ - 1 ] - regularization_term - sparsity_term :.5e} "
182178 )
183179
184180 # Main optimization loop
@@ -191,15 +187,15 @@ def __init__(
191187 sparsity_term = eta * np .sum (np .sqrt (self .components )) # Square root penalty
192188 print (
193189 f"Num_updates: { self .num_updates } , "
194- f"Obj fun: { self .objective_function :.5e} , "
195- f"Obj - reg/sparse: { self .objective_function - regularization_term - sparsity_term :.5e} , "
190+ f"Obj fun: { self ._objective_history [ - 1 ] :.5e} , "
191+ f"Obj - reg/sparse: { self ._objective_history [ - 1 ] - regularization_term - sparsity_term :.5e} , "
196192 f"Iter: { iter } "
197193 )
198194
199195 # Convergence check: decide when to terminate for small/no improvement
200- print (self .objective_difference , " < " , self .objective_function * tol )
201- if self .objective_difference < self .objective_function * tol and iter >= 20 :
196+ if self .objective_difference < self ._objective_history [- 1 ] * tol and iter >= 20 :
202197 break
198+ print (self .objective_difference , " < " , self ._objective_history [- 1 ] * tol )
203199
204200 # Normalize our results
205201 weights_row_max = np .max (self .weights , axis = 1 , keepdims = True )
@@ -214,17 +210,17 @@ def __init__(
214210 self .grad_components = np .zeros_like (self .components )
215211 self ._prev_grad_components = np .zeros_like (self .components )
216212 self .residuals = self .get_residual_matrix ()
217- self .objective_function = self .get_objective_function ()
218213 self .objective_difference = None
219- self ._objective_history = [self .objective_function ]
214+ self ._objective_history = []
215+ self .update_objective ()
220216 for norm_iter in range (100 ):
221217 self .update_components ()
222218 self .residuals = self .get_residual_matrix ()
223- self .objective_function = self . get_objective_function ()
224- print (f"Objective function after normX : { self .objective_function :.5e} " )
225- self ._objective_history .append (self .objective_function )
219+ self .update_objective ()
220+ print (f"Objective function after normalize_components : { self ._objective_history [ - 1 ] :.5e} " )
221+ self ._objective_history .append (self ._objective_history [ - 1 ] )
226222 self .objective_difference = self ._objective_history [- 2 ] - self ._objective_history [- 1 ]
227- if self .objective_difference < self .objective_function * tol and norm_iter >= 20 :
223+ if self .objective_difference < self ._objective_history [ - 1 ] * tol and norm_iter >= 20 :
228224 break
229225 # end of normalization (and program)
230226 # note that objective function may not fully recover after normalization, this is okay
@@ -233,30 +229,33 @@ def __init__(
233229 def optimize_loop (self ):
234230 # Update components first
235231 self ._prev_grad_components = self .grad_components .copy ()
232+
236233 self .update_components ()
234+
237235 self .num_updates += 1
238236 self .residuals = self .get_residual_matrix ()
239- self .objective_function = self . get_objective_function ()
240- print (f"Objective function after update_components: { self .objective_function :.5e} " )
241- self . _objective_history . append ( self . objective_function )
237+ self .update_objective ()
238+ print (f"Objective function after update_components: { self ._objective_history [ - 1 ] :.5e} " )
239+
242240 if self .objective_difference is None :
243- self .objective_difference = self ._objective_history [- 1 ] - self .objective_function
241+ self .objective_difference = self ._objective_history [- 2 ] - self ._objective_history [ - 1 ]
244242
245243 # Now we update weights
246244 self .update_weights ()
245+
247246 self .num_updates += 1
248247 self .residuals = self .get_residual_matrix ()
249- self .objective_function = self .get_objective_function ()
250- print (f"Objective function after update_weights: { self .objective_function :.5e} " )
251- self ._objective_history .append (self .objective_function )
248+ self .update_objective ()
249+ print (f"Objective function after update_weights: { self ._objective_history [- 1 ]:.5e} " )
252250
253251 # Now we update stretch
254252 self .update_stretch ()
253+
255254 self .num_updates += 1
256255 self .residuals = self .get_residual_matrix ()
257- self .objective_function = self . get_objective_function ()
258- print (f"Objective function after update_stretch: { self .objective_function :.5e} " )
259- self . _objective_history . append ( self . objective_function )
256+ self .update_objective ()
257+ print (f"Objective function after update_stretch: { self ._objective_history [ - 1 ] :.5e} " )
258+
260259 self .objective_difference = self ._objective_history [- 2 ] - self ._objective_history [- 1 ]
261260
262261 def apply_interpolation (self , a , x , return_derivatives = False ):
@@ -328,7 +327,8 @@ def get_residual_matrix(self, components=None, weights=None, stretch=None):
328327 residuals += weights [k , :] * stretched_components # Element-wise scaling and sum
329328 return residuals
330329
331- def get_objective_function (self , residuals = None , stretch = None ):
330+ def update_objective (self , residuals = None , stretch = None ):
331+ to_return = not (residuals is None and stretch is None )
332332 if residuals is None :
333333 residuals = self .residuals
334334 if stretch is None :
@@ -338,7 +338,11 @@ def get_objective_function(self, residuals=None, stretch=None):
338338 sparsity_term = self .eta * np .sum (np .sqrt (self .components )) # Square root penalty
339339 # Final objective function value
340340 function = residual_term + regularization_term + sparsity_term
341- return function
341+
342+ if to_return :
343+ return function # Get value directly for use
344+ else :
345+ self ._objective_history .append (function ) # Store value
342346
343347 def apply_interpolation_matrix (self , components = None , weights = None , stretch = None , return_derivatives = False ):
344348 """
@@ -590,7 +594,7 @@ def update_components(self):
590594 )
591595 self .components = mask * self .components
592596
593- objective_improvement = self ._objective_history [- 1 ] - self .get_objective_function (
597+ objective_improvement = self ._objective_history [- 1 ] - self .update_objective (
594598 residuals = self .get_residual_matrix ()
595599 )
596600
@@ -645,7 +649,7 @@ def regularize_function(self, stretch=None):
645649 stretch_difference = stretch_difference - self .source_matrix
646650
647651 # Compute objective function
648- reg_func = self .get_objective_function (stretch_difference , stretch )
652+ reg_func = self .update_objective (stretch_difference , stretch )
649653
650654 # Compute gradient
651655 tiled_derivative = np .sum (
0 commit comments