33from scipy .optimize import minimize
44from scipy .sparse import block_diag , coo_matrix , diags
55
6- # from scipy.sparse import csr_matrix, spdiags (needed for hessian once fixed)
7-
86
97class SNMFOptimizer :
108 def __init__ (self , MM , Y0 = None , X0 = None , A = None , rho = 1e12 , eta = 610 , maxiter = 300 , components = None ):
@@ -67,6 +65,7 @@ def __init__(self, MM, Y0=None, X0=None, A=None, rho=1e12, eta=610, maxiter=300,
6765 f", Obj - reg/sparse: { self .objective_function - regularization_term - sparsity_term :.5e} "
6866 )
6967
68+ # Main optimization loop
7069 for outiter in range (self .maxiter ):
7170 self .outiter = outiter
7271 self .outer_loop ()
@@ -81,10 +80,18 @@ def __init__(self, MM, Y0=None, X0=None, A=None, rho=1e12, eta=610, maxiter=300,
8180 )
8281
8382 # Convergence check: Stop if diffun is small and at least 20 iterations have passed
84- # This check is not working, so have temporarily set 20->120 instead
85- if self .objective_difference < self .objective_function * 1e-6 and outiter >= 120 :
83+ print ( self . objective_difference , " < " , self . objective_function * 1e-6 )
84+ if self .objective_difference < self .objective_function * 1e-6 and outiter >= 20 :
8685 break
8786
87+ # Normalize our results
88+ Y_row_max = np .max (self .Y , axis = 1 , keepdims = True )
89+ self .Y = self .Y / Y_row_max
90+ A_row_max = np .max (self .A , axis = 1 , keepdims = True )
91+ self .A = self .A / A_row_max
92+ # TODO loop to normalize X (currently not normalized)
93+ # effectively just re-running class with non-normalized X, normalized Y/A as inputs, then only update X
94+
8895 def outer_loop (self ):
8996 # This inner loop runs up to four times per outer loop, making updates to X, Y
9097 for iter in range (4 ):
@@ -108,25 +115,19 @@ def outer_loop(self):
108115 self .objective_history .append (self .objective_function )
109116
110117 # Check whether to break out early
118+ # TODO this condition has not been tested, and may have issues
111119 if len (self .objective_history ) >= 3 : # Ensure at least 3 values exist
112120 if self .objective_history [- 3 ] - self .objective_function < self .objective_difference * 1e-3 :
113121 break # Stop if improvement is too small
114122
115- if self .outiter == 0 :
116- print ("Testing regularize_function:" )
117- test_fun , test_gra , test_hess = self .regularize_function ()
118- print (f"Fun: { test_fun :.5e} " )
119- np .savetxt ("output/py_test_gra.txt" , test_gra , fmt = "%.8g" , delimiter = " " )
120- np .savetxt ("output/py_test_hess.txt" , test_hess , fmt = "%.8g" , delimiter = " " )
121-
122123 self .updateA2 ()
123124
124125 self .num_updates += 1
125126 self .R = self .get_residual_matrix ()
126127 self .objective_function = self .get_objective_function ()
127128 print (f"Objective function after updateA2: { self .objective_function :.5e} " )
128129 self .objective_history .append (self .objective_function )
129- self .objective_difference = self .objective_history [- 1 ] - self .objective_function
130+ self .objective_difference = self .objective_history [- 2 ] - self .objective_history [ - 1 ]
130131
131132 def apply_interpolation (self , a , x ):
132133 """
0 commit comments