
Commit d194af5

Authored by John Halloran (john-halloran) and a co-author
Fixes to class attributes and style (#163)
* fix: use symmetric initial phase fractions
* style: don't store objective function in a class attribute, just use history

Co-authored-by: John Halloran <[email protected]>
1 parent 27ea989 commit d194af5

File tree

1 file changed: +38 -34 lines


src/diffpy/snmf/snmf_class.py

Lines changed: 38 additions & 34 deletions
@@ -52,10 +52,6 @@ class SNMFOptimizer:
     num_updates : int
         The total number of times that any of (stretch, components, and weights) have had their values changed.
         If not terminated by other means, this value is used to stop when reaching max_iter.
-    objective_function: float
-        The value corresponding to the minimization of the difference between the source_matrix and the
-        products of (stretch, components, and weights). For full details see the sNMF paper. Smaller corresponds to
-        better agreement and is desirable.
     objective_difference : float
         The change in the objective function value since the last update. A negative value
         means that the result improved.
@@ -134,7 +130,7 @@ def __init__(
         # Initialize weights and determine number of components
         if init_weights is None:
             self.n_components = n_components
-            self.weights = self._rng.beta(a=2.5, b=1.5, size=(self.n_components, self.n_signals))
+            self.weights = self._rng.beta(a=2.0, b=2.0, size=(self.n_components, self.n_signals))
         else:
             self.n_components = init_weights.shape[0]
             self.weights = init_weights
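
For context (not part of the diff): the old Beta(2.5, 1.5) draw has mean a/(a+b) = 0.625, biasing the initial weights toward one phase, while Beta(2.0, 2.0) has mean 0.5 with a density symmetric about it, matching the "symmetric initial phase fractions" fix in the commit message. A minimal standalone sketch, using the same NumPy generator API with illustrative sizes:

import numpy as np

rng = np.random.default_rng(0)
n_components, n_signals = 2, 5  # illustrative sizes, not from the diff

# Old draw: mean 2.5 / (2.5 + 1.5) = 0.625, skewed toward 1
old_weights = rng.beta(a=2.5, b=1.5, size=(n_components, n_signals))

# New draw: mean 2.0 / (2.0 + 2.0) = 0.5, symmetric about 0.5
new_weights = rng.beta(a=2.0, b=2.0, size=(n_components, n_signals))

print(old_weights.mean(), new_weights.mean())  # ~0.625 vs ~0.5 in expectation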
@@ -165,20 +161,20 @@ def __init__(
 
         # Set up residual matrix, objective function, and history
         self.residuals = self.get_residual_matrix()
-        self.objective_function = self.get_objective_function()
+        self._objective_history = []
+        self.update_objective()
         self.objective_difference = None
-        self._objective_history = [self.objective_function]
 
-        # Set up tracking variables for updateX()
+        # Set up tracking variables for update_components()
         self._prev_components = None
         self.grad_components = np.zeros_like(self.components)  # Gradient of X (zeros for now)
         self._prev_grad_components = np.zeros_like(self.components)  # Previous gradient of X (zeros for now)
 
         regularization_term = 0.5 * rho * np.linalg.norm(self._spline_smooth_operator @ self.stretch.T, "fro") ** 2
         sparsity_term = eta * np.sum(np.sqrt(self.components))  # Square root penalty
         print(
-            f"Start, Objective function: {self.objective_function:.5e}"
-            f", Obj - reg/sparse: {self.objective_function - regularization_term - sparsity_term:.5e}"
+            f"Start, Objective function: {self._objective_history[-1]:.5e}"
+            f", Obj - reg/sparse: {self._objective_history[-1] - regularization_term - sparsity_term:.5e}"
         )
 
         # Main optimization loop
@@ -191,15 +187,15 @@ def __init__(
             sparsity_term = eta * np.sum(np.sqrt(self.components))  # Square root penalty
             print(
                 f"Num_updates: {self.num_updates}, "
-                f"Obj fun: {self.objective_function:.5e}, "
-                f"Obj - reg/sparse: {self.objective_function - regularization_term - sparsity_term:.5e}, "
+                f"Obj fun: {self._objective_history[-1]:.5e}, "
+                f"Obj - reg/sparse: {self._objective_history[-1] - regularization_term - sparsity_term:.5e}, "
                 f"Iter: {iter}"
             )
 
             # Convergence check: decide when to terminate for small/no improvement
-            print(self.objective_difference, " < ", self.objective_function * tol)
-            if self.objective_difference < self.objective_function * tol and iter >= 20:
+            if self.objective_difference < self._objective_history[-1] * tol and iter >= 20:
                 break
+            print(self.objective_difference, " < ", self._objective_history[-1] * tol)
 
         # Normalize our results
         weights_row_max = np.max(self.weights, axis=1, keepdims=True)
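
A standalone sketch (hypothetical numbers, not from the diff) of the relative convergence test used above: the loop stops once the improvement falls below tol times the most recent objective value, which is now read from the history rather than a stored attribute:

# Hypothetical objective values and tolerance, for illustration only
objective_history = [1.000e3, 9.900e2, 9.890e2]
tol = 5e-3

objective_difference = objective_history[-2] - objective_history[-1]  # 1.0

# Relative test: compare the improvement to tol * current objective
# (the diff additionally requires iter >= 20 before breaking)
print(objective_difference < objective_history[-1] * tol)  # True: 1.0 < 4.945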
@@ -214,17 +210,17 @@ def __init__(
         self.grad_components = np.zeros_like(self.components)
         self._prev_grad_components = np.zeros_like(self.components)
         self.residuals = self.get_residual_matrix()
-        self.objective_function = self.get_objective_function()
         self.objective_difference = None
-        self._objective_history = [self.objective_function]
+        self._objective_history = []
+        self.update_objective()
         for norm_iter in range(100):
             self.update_components()
             self.residuals = self.get_residual_matrix()
-            self.objective_function = self.get_objective_function()
-            print(f"Objective function after normX: {self.objective_function:.5e}")
-            self._objective_history.append(self.objective_function)
+            self.update_objective()
+            print(f"Objective function after normalize_components: {self._objective_history[-1]:.5e}")
+            self._objective_history.append(self._objective_history[-1])
             self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
-            if self.objective_difference < self.objective_function * tol and norm_iter >= 20:
+            if self.objective_difference < self._objective_history[-1] * tol and norm_iter >= 20:
                 break
         # end of normalization (and program)
         # note that objective function may not fully recover after normalization, this is okay
@@ -233,30 +229,33 @@ def __init__(
     def optimize_loop(self):
         # Update components first
        self._prev_grad_components = self.grad_components.copy()
+
         self.update_components()
+
         self.num_updates += 1
         self.residuals = self.get_residual_matrix()
-        self.objective_function = self.get_objective_function()
-        print(f"Objective function after update_components: {self.objective_function:.5e}")
-        self._objective_history.append(self.objective_function)
+        self.update_objective()
+        print(f"Objective function after update_components: {self._objective_history[-1]:.5e}")
+
         if self.objective_difference is None:
-            self.objective_difference = self._objective_history[-1] - self.objective_function
+            self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
 
         # Now we update weights
         self.update_weights()
+
         self.num_updates += 1
         self.residuals = self.get_residual_matrix()
-        self.objective_function = self.get_objective_function()
-        print(f"Objective function after update_weights: {self.objective_function:.5e}")
-        self._objective_history.append(self.objective_function)
+        self.update_objective()
+        print(f"Objective function after update_weights: {self._objective_history[-1]:.5e}")
 
         # Now we update stretch
         self.update_stretch()
+
         self.num_updates += 1
         self.residuals = self.get_residual_matrix()
-        self.objective_function = self.get_objective_function()
-        print(f"Objective function after update_stretch: {self.objective_function:.5e}")
-        self._objective_history.append(self.objective_function)
+        self.update_objective()
+        print(f"Objective function after update_stretch: {self._objective_history[-1]:.5e}")
+
         self.objective_difference = self._objective_history[-2] - self._objective_history[-1]
 
     def apply_interpolation(self, a, x, return_derivatives=False):
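
A stripped-down sketch (assumed names, toy values, not from the diff) of the bookkeeping pattern optimize_loop now follows: each update appends the freshly computed objective to the history, the current value is read back as the last entry, and objective_difference is the drop between the last two entries:

class ObjectiveTracker:
    # Sketch only: stands in for the history handling in optimize_loop
    def __init__(self):
        self._objective_history = []
        self.objective_difference = None

    def update_objective(self, value):
        # Append instead of storing a separate objective_function attribute
        self._objective_history.append(value)
        if len(self._objective_history) >= 2:
            self.objective_difference = (
                self._objective_history[-2] - self._objective_history[-1]
            )

tracker = ObjectiveTracker()
for value in (100.0, 90.0, 85.0):
    tracker.update_objective(value)
print(tracker._objective_history[-1], tracker.objective_difference)  # 85.0 5.0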
@@ -328,7 +327,8 @@ def get_residual_matrix(self, components=None, weights=None, stretch=None):
             residuals += weights[k, :] * stretched_components  # Element-wise scaling and sum
         return residuals
 
-    def get_objective_function(self, residuals=None, stretch=None):
+    def update_objective(self, residuals=None, stretch=None):
+        to_return = not (residuals is None and stretch is None)
         if residuals is None:
             residuals = self.residuals
         if stretch is None:
@@ -338,7 +338,11 @@ def get_objective_function(self, residuals=None, stretch=None):
         sparsity_term = self.eta * np.sum(np.sqrt(self.components))  # Square root penalty
         # Final objective function value
         function = residual_term + regularization_term + sparsity_term
-        return function
+
+        if to_return:
+            return function  # Get value directly for use
+        else:
+            self._objective_history.append(function)  # Store value
 
     def apply_interpolation_matrix(self, components=None, weights=None, stretch=None, return_derivatives=False):
         """
@@ -590,7 +594,7 @@ def update_components(self):
         )
         self.components = mask * self.components
 
-        objective_improvement = self._objective_history[-1] - self.get_objective_function(
+        objective_improvement = self._objective_history[-1] - self.update_objective(
             residuals=self.get_residual_matrix()
         )
 
@@ -645,7 +649,7 @@ def regularize_function(self, stretch=None):
         stretch_difference = stretch_difference - self.source_matrix
 
         # Compute objective function
-        reg_func = self.get_objective_function(stretch_difference, stretch)
+        reg_func = self.update_objective(stretch_difference, stretch)
 
         # Compute gradient
         tiled_derivative = np.sum(
