From 06e3a99deb1a65e4bca2bbb66aa42ed03d916772 Mon Sep 17 00:00:00 2001 From: AishwaryaBadlani Date: Tue, 12 Aug 2025 17:20:02 +0500 Subject: [PATCH 1/5] ENH: Add DoubleQuadratic datafit for asymmetric loss --- skglm/datafits/__init__.py | 5 +- skglm/datafits/_double_quadratic.py | 294 +++++++++++++++++++++++++++ skglm/tests/test_double_quadratic.py | 52 +++++ 3 files changed, 349 insertions(+), 2 deletions(-) create mode 100644 skglm/datafits/_double_quadratic.py create mode 100644 skglm/tests/test_double_quadratic.py diff --git a/skglm/datafits/__init__.py b/skglm/datafits/__init__.py index 0c6a5973..67df316c 100644 --- a/skglm/datafits/__init__.py +++ b/skglm/datafits/__init__.py @@ -1,6 +1,7 @@ from .base import BaseDatafit, BaseMultitaskDatafit from .single_task import (Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma, Cox, WeightedQuadratic, QuadraticHessian,) +from ._double_quadratic import DoubleQuadratic from .multi_task import QuadraticMultiTask from .group import QuadraticGroup, LogisticGroup, PoissonGroup @@ -10,5 +11,5 @@ Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma, Cox, QuadraticMultiTask, QuadraticGroup, LogisticGroup, PoissonGroup, WeightedQuadratic, - QuadraticHessian -] + QuadraticHessian, DoubleQuadratic # Add this +] \ No newline at end of file diff --git a/skglm/datafits/_double_quadratic.py b/skglm/datafits/_double_quadratic.py new file mode 100644 index 00000000..c4abe393 --- /dev/null +++ b/skglm/datafits/_double_quadratic.py @@ -0,0 +1,294 @@ +import numpy as np +from numba import float64 +from .base import BaseDatafit + + +class DoubleQuadratic(BaseDatafit): + """Double Quadratic datafit with asymmetric loss. + + The datafit reads: + + .. math:: 1 / (2 \\times n_\\text{samples}) \\sum_i (\\alpha + (1-2\\alpha) \\cdot 1[\\epsilon_i > 0]) \\epsilon_i^2 + + where :math:`\\epsilon_i = (Xw)_i - y_i` are the residuals. + + Parameters + ---------- + alpha : float, default=0.5 + Asymmetry parameter controlling the relative weighting of positive vs + negative residuals: + - alpha = 0.5: symmetric loss (equivalent to standard Quadratic) + - alpha < 0.5: penalize positive residuals (overestimation) more heavily + - alpha > 0.5: penalize negative residuals (underestimation) more heavily + + Attributes + ---------- + Xty : array, shape (n_features,) + Pre-computed quantity used during the gradient evaluation. + Equal to ``X.T @ y``. + + Note + ---- + The class is jit compiled at fit time using Numba compiler. + This allows for faster computations. + """ + + def __init__(self, alpha=0.5): + if not 0 <= alpha <= 1: + raise ValueError(f"alpha must be between 0 and 1, got {alpha}") + self.alpha = alpha + + def get_spec(self): + spec = ( + ('alpha', float64), + ('Xty', float64[:]), + ) + return spec + + def params_to_dict(self): + return dict(alpha=self.alpha) + + def get_lipschitz(self, X, y): + """Compute per-coordinate Lipschitz constants. + + For DoubleQuadratic with scaling factor 2, the Lipschitz constant + for coordinate j is bounded by 2 * max_weight * ||X[:, j]||^2 / n_samples. 
+ """ + n_features = X.shape[1] + + # Compute weight bounds (after scaling by 2) + weight_pos = 2 * (1 - self.alpha) # weight for positive residuals + weight_neg = 2 * self.alpha # weight for negative residuals + max_weight = max(weight_pos, weight_neg) + + lipschitz = np.zeros(n_features, dtype=X.dtype) + for j in range(n_features): + lipschitz[j] = max_weight * (X[:, j] ** 2).sum() / len(y) + + return lipschitz + + def get_lipschitz_sparse(self, X_data, X_indptr, X_indices, y): + """Sparse version of get_lipschitz.""" + n_features = len(X_indptr) - 1 + + # Compute weight bounds (after scaling by 2) + weight_pos = 2 * (1 - self.alpha) + weight_neg = 2 * self.alpha + max_weight = max(weight_pos, weight_neg) + + lipschitz = np.zeros(n_features, dtype=X_data.dtype) + + for j in range(n_features): + nrm2 = 0. + for idx in range(X_indptr[j], X_indptr[j + 1]): + nrm2 += X_data[idx] ** 2 + + lipschitz[j] = max_weight * nrm2 / len(y) + + return lipschitz + + def get_global_lipschitz(self, X, y): + """Global Lipschitz constant.""" + weight_pos = 2 * (1 - self.alpha) + weight_neg = 2 * self.alpha + max_weight = max(weight_pos, weight_neg) + + from scipy.linalg import norm + return max_weight * norm(X, ord=2) ** 2 / len(y) + + def get_global_lipschitz_sparse(self, X_data, X_indptr, X_indices, y): + """Sparse version of global Lipschitz constant.""" + weight_pos = 2 * (1 - self.alpha) + weight_neg = 2 * self.alpha + max_weight = max(weight_pos, weight_neg) + + from .utils import spectral_norm + return max_weight * spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2 / len(y) + + def initialize(self, X, y): + """Pre-compute X.T @ y for efficient gradient computation.""" + self.Xty = X.T @ y + + def initialize_sparse(self, X_data, X_indptr, X_indices, y): + """Sparse version of initialize.""" + n_features = len(X_indptr) - 1 + self.Xty = np.zeros(n_features, dtype=X_data.dtype) + + for j in range(n_features): + xty = 0 + for idx in range(X_indptr[j], X_indptr[j + 1]): + xty += X_data[idx] * y[X_indices[idx]] + + self.Xty[j] = xty + + def value(self, y, w, Xw): + """Compute the asymmetric quadratic loss value. + + When alpha=0.5, this should be identical to Quadratic loss. + The formula needs to be: (1/2n) * Σ weights * (y - Xw)² + where weights are normalized so that alpha=0.5 gives weight=1.0 + """ + # Match Quadratic exactly: use (y - Xw) for loss computation + residuals = y - Xw + + # For asymmetric weighting, check sign of (Xw - y) + prediction_residuals = Xw - y + + # Compute weights, normalized so alpha=0.5 gives weight=1.0 + # Original formula: α + (1-2α) * 1[εᵢ>0] + # At α=0.5: 0.5 + 0 = 0.5, but we want 1.0 + # So we need to scale by 2: 2 * (α + (1-2α) * 1[εᵢ>0]) + weights = 2 * (self.alpha + (1 - 2*self.alpha) * (prediction_residuals > 0)) + + # Return normalized loss + return np.sum(weights * residuals**2) / (2 * len(y)) + + def gradient_scalar(self, X, y, w, Xw, j): + """Compute gradient w.r.t. 
coordinate j.""" + prediction_residuals = Xw - y # For gradient computation + + # Compute weights with same scaling as value() + weights = 2 * (self.alpha + (1 - 2*self.alpha) * (prediction_residuals > 0)) + + # Gradient: X[:, j].T @ (weights * (Xw - y)) / n_samples + return (X[:, j] @ (weights * prediction_residuals)) / len(y) + + def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j): + """Sparse version of gradient_scalar.""" + prediction_residuals = Xw - y + + # Compute weights with same scaling + weights = 2 * (self.alpha + (1 - 2*self.alpha) * (prediction_residuals > 0)) + + # Compute X[:, j].T @ (weights * prediction_residuals) for sparse X + XjT_weighted_residuals = 0. + for i in range(X_indptr[j], X_indptr[j+1]): + sample_idx = X_indices[i] + XjT_weighted_residuals += X_data[i] * weights[sample_idx] * prediction_residuals[sample_idx] + + return XjT_weighted_residuals / len(y) + + def gradient(self, X, y, Xw): + """Compute full gradient vector.""" + prediction_residuals = Xw - y + + # Compute weights with same scaling as value() + weights = 2 * (self.alpha + (1 - 2*self.alpha) * (prediction_residuals > 0)) + + # Return X.T @ (weights * prediction_residuals) / n_samples + return X.T @ (weights * prediction_residuals) / len(y) + + def raw_grad(self, y, Xw): + """Compute gradient of datafit w.r.t Xw.""" + prediction_residuals = Xw - y + + # Compute weights with same scaling + weights = 2 * (self.alpha + (1 - 2*self.alpha) * (prediction_residuals > 0)) + + return weights * prediction_residuals / len(y) + + def raw_hessian(self, y, Xw): + """Compute Hessian of datafit w.r.t Xw.""" + prediction_residuals = Xw - y + + # Compute weights with same scaling + weights = 2 * (self.alpha + (1 - 2*self.alpha) * (prediction_residuals > 0)) + + return weights / len(y) + + def full_grad_sparse(self, X_data, X_indptr, X_indices, y, Xw): + """Sparse version of full gradient computation.""" + n_features = X_indptr.shape[0] - 1 + n_samples = y.shape[0] + prediction_residuals = Xw - y + + # Compute weights with same scaling + weights = 2 * (self.alpha + (1 - 2*self.alpha) * (prediction_residuals > 0)) + + grad = np.zeros(n_features, dtype=Xw.dtype) + for j in range(n_features): + XjT_weighted_residuals = 0. 
+ for i in range(X_indptr[j], X_indptr[j + 1]): + sample_idx = X_indices[i] + XjT_weighted_residuals += X_data[i] * weights[sample_idx] * prediction_residuals[sample_idx] + grad[j] = XjT_weighted_residuals / n_samples + return grad + + def intercept_update_step(self, y, Xw): + """Compute intercept update step.""" + prediction_residuals = Xw - y + + # Compute weights with same scaling + weights = 2 * (self.alpha + (1 - 2*self.alpha) * (prediction_residuals > 0)) + + return np.mean(weights * prediction_residuals) + + +# Test function to validate our implementation +def _test_double_quadratic(): + """Test DoubleQuadratic implementation.""" + import numpy as np + from .single_task import Quadratic + + print("Testing DoubleQuadratic implementation...") + + # Test data + np.random.seed(42) + n_samples, n_features = 50, 10 + X = np.random.randn(n_samples, n_features) + y = np.random.randn(n_samples) + w = np.random.randn(n_features) + Xw = X @ w + + # Test 1: alpha=0.5 should match standard Quadratic + print("\n=== Test 1: alpha=0.5 vs Quadratic ===") + + quad = Quadratic() + quad.initialize(X, y) + + dquad = DoubleQuadratic(alpha=0.5) + dquad.initialize(X, y) + + loss_quad = quad.value(y, w, Xw) + loss_dquad = dquad.value(y, w, Xw) + + print(f"Quadratic loss: {loss_quad:.8f}") + print(f"DoubleQuadratic: {loss_dquad:.8f}") + print(f"Difference: {abs(loss_quad - loss_dquad):.2e}") + + # Test gradients + grad_quad = quad.gradient(X, y, Xw) + grad_dquad = dquad.gradient(X, y, Xw) + grad_diff = np.linalg.norm(grad_quad - grad_dquad) + + print(f"Gradient difference: {grad_diff:.2e}") + + # Test case 2: Asymmetric behavior + print("\n=== Test 2: Asymmetric behavior ===") + + # Create simple test case with known residuals + X_simple = np.eye(4) # Identity matrix + y_simple = np.array([0., 0., 0., 0.]) + w_simple = np.array([1., -1., 2., -2.]) # Predictions: [1, -1, 2, -2], so prediction_residuals = [1, -1, 2, -2] + Xw_simple = X_simple @ w_simple + + dquad_asym = DoubleQuadratic(alpha=0.3) # Penalize positive residuals more + dquad_asym.initialize(X_simple, y_simple) + + loss_asym = dquad_asym.value(y_simple, w_simple, Xw_simple) + + # Manual calculation: + # prediction_residuals = [1, -1, 2, -2] (Xw - y) + # weights = 0.3 + 0.4 * [1, 0, 1, 0] = [0.7, 0.3, 0.7, 0.3] + # loss = (1/(2*4)) * (0.7*1² + 0.3*1² + 0.7*4² + 0.3*4²) + expected = (1/8) * (0.7*1 + 0.3*1 + 0.7*4 + 0.3*4) + + print(f"Asymmetric loss: {loss_asym:.6f}") + print(f"Expected: {expected:.6f}") + print(f"Difference: {abs(loss_asym - expected):.2e}") + + print("\n=== All tests completed ===") + + +if __name__ == "__main__": + _test_double_quadratic() \ No newline at end of file diff --git a/skglm/tests/test_double_quadratic.py b/skglm/tests/test_double_quadratic.py new file mode 100644 index 00000000..be83bff5 --- /dev/null +++ b/skglm/tests/test_double_quadratic.py @@ -0,0 +1,52 @@ +import numpy as np +import pytest +from skglm.datafits import DoubleQuadratic, Quadratic + + +class TestDoubleQuadratic: + + def test_alpha_half_matches_quadratic(self): + """Test that alpha=0.5 gives same results as Quadratic.""" + np.random.seed(42) + X = np.random.randn(20, 5) + y = np.random.randn(20) + w = np.random.randn(5) + Xw = X @ w + + quad = Quadratic() + quad.initialize(X, y) + + dquad = DoubleQuadratic(alpha=0.5) + dquad.initialize(X, y) + + # Test loss values + assert np.allclose(quad.value(y, w, Xw), dquad.value(y, w, Xw)) + + # Test gradients + assert np.allclose(quad.gradient(X, y, Xw), dquad.gradient(X, y, Xw)) + + def 
test_asymmetric_behavior(self): + """Test that asymmetric behavior works correctly.""" + # Simple test case with known residuals + X = np.eye(4) + y = np.zeros(4) + w = np.array([1., -1., 2., -2.]) # residuals = [1, -1, 2, -2] + Xw = X @ w + + dquad = DoubleQuadratic(alpha=0.3) + dquad.initialize(X, y) + + loss = dquad.value(y, w, Xw) + + # Manual calculation with scaling: weights = 2 * [0.7, 0.3, 0.7, 0.3] = [1.4, 0.6, 1.4, 0.6] + expected = (1/8) * (1.4*1 + 0.6*1 + 1.4*4 + 0.6*4) + + assert np.allclose(loss, expected) + + def test_parameter_validation(self): + """Test parameter validation.""" + with pytest.raises(ValueError): + DoubleQuadratic(alpha=-0.1) + + with pytest.raises(ValueError): + DoubleQuadratic(alpha=1.1) \ No newline at end of file From cda40e1e5b65638b254ae1b130afd5525f419c3c Mon Sep 17 00:00:00 2001 From: AishwaryaBadlani Date: Thu, 14 Aug 2025 17:43:32 +0500 Subject: [PATCH 2/5] DOC: Update derivation equations of positive Group Lasso penalty - Add complete Lagrangian derivation for case w = 0 - Include rigorous KKT conditions and optimality analysis - Replace incomplete derivations with full mathematical proofs Fixes #243 --- doc/tutorials/prox_nn_group_lasso.rst | 37 +++++++++++++++++++-------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/doc/tutorials/prox_nn_group_lasso.rst b/doc/tutorials/prox_nn_group_lasso.rst index 0611ac82..c48e3837 100644 --- a/doc/tutorials/prox_nn_group_lasso.rst +++ b/doc/tutorials/prox_nn_group_lasso.rst @@ -68,7 +68,6 @@ Using the Moreau decomposition, Equations :eq:`fenchel` and :eq:`prox_projection A similar formula can be derived for the group Lasso with nonnegative constraints. - Proximity operator of the group Lasso with positivity constraints ================================================================= @@ -135,8 +134,6 @@ and thus, combined with Equations :eq:`prox_projection_nn_Sc` and :eq:`prox_proj (1 - \frac{\lambda}{\norm{x_S}})_{+} x_S . - - .. _subdiff_positive_group_lasso: Subdifferential of the positive Group Lasso penalty @@ -184,7 +181,7 @@ Minimizing over :math:`n` then over :math:`u`, thanks to [`1 0`, taking a non zero :math:`n_i` will only increase the quantity that :math:`u_i` needs to bring closer to 0. -For a rigorous derivation of this, introduce the Lagrangian on a squared objective +**Rigorous derivation:** Consider the Lagrangian (where we have squared the objective and the :math:`u` constraint for convenience when taking derivatives): .. math:: @@ -192,12 +189,33 @@ For a rigorous derivation of this, introduce the Lagrangian on a squared objecti \frac{1}{2}\norm{u + n - v}^2 + \nu(\frac{1}{2} \norm{u}^2 - \lambda^2 / 2) + \langle \mu, n \rangle , -and write down the optimality condition with respect to :math:`u` and :math:`n`. -Treat the case :math:`nu = 0` separately; in the other case show that :\math:`u` must be positive, and that :math:`v = (1 + \nu) u + n`, together with :math:`u = \mu / \nu` and complementary slackness, to reach the conclusion. +with a positive scalar :math:`\nu` and a positive vector :math:`\mu`. + +Slater's condition is met (assuming :math:`\lambda > 0`), so the KKT conditions are necessary and sufficient. Considering the optimality with respect to :math:`u` and :math:`n` respectively, we obtain: + +.. math:: + + u + n - v + \nu u &= 0 \\ + u + n - v + \mu &= 0 + +Hence :math:`\mu = \nu u`. If :math:`\nu = 0`, then :math:`v = u + n` and the optimal objective is 0. 
Else, :math:`\nu > 0` and :math:`\mu \geq 0`, so any solution :math:`u = \frac{1}{\nu}\mu` must be positive. By complementary slackness, :math:`\mu_j n_j = 0 = \nu u_j n_j`. So :math:`u` and :math:`n` have disjoint supports. + +Since :math:`v = (1 + \nu)u + n`, it is clear that: + +- If :math:`v_j > 0`, it is :math:`u_j` which is nonzero, equal to :math:`v_j/(1 + \nu)` +- If :math:`v_j < 0`, it is :math:`n_j` which is nonzero and equal to :math:`v_j` + +We have :math:`v_j > 0 \Rightarrow n_j = 0` and :math:`v_j < 0 \Rightarrow u_j = 0`, so we can rewrite the problem as: + +.. math:: + + \min_{u} \sum_{j: v_j > 0} (u_j - v_j)^2 \quad \text{s.t.} \quad \sum_{j: v_j > 0} u_j^2 \leq \lambda^2 + +which is the projection problem yielding the final result. Case :math:`|| w || \ne 0` --------------------------- -The subdifferential in that case is :math:`\lambda w / {|| w ||} + C_1 \times \ldots \times C_g` where :math:`C_j = {0}` if :math:`w_j > 0` and :math:`C_j = mathbb{R}_-` otherwise (:math:`w_j =0`). +The subdifferential in that case is :math:`\lambda w / {|| w ||} + C_1 \times \ldots \times C_g` where :math:`C_j = {0}` if :math:`w_j > 0` and :math:`C_j = \mathbb{R}_-` otherwise (:math:`w_j =0`). By letting :math:`p` denotes the projection of :math:`v` onto this set, one has @@ -216,13 +234,12 @@ The distance to the subdifferential is then: .. math:: - D(v) = || v - p || = \sqrt{\sum_{j, w_j > 0} (v_j - \lambda \frac{w_j}{||w||})^2 + \sum_{j, w_j=0} \max(0, v_j)^2 + D(v) = || v - p || = \sqrt{\sum_{j, w_j > 0} (v_j - \lambda \frac{w_j}{||w||})^2 + \sum_{j, w_j=0} \max(0, v_j)^2} since :math:`v_j - \min(v_j, 0) = v_j + \max(-v_j, 0) = \max(0, v_j)`. - References ========== -[1] ``_ +[1] ``_ \ No newline at end of file From 0446aac73106a917cf45b7d3c8fe6dca71c66530 Mon Sep 17 00:00:00 2001 From: AishwaryaBadlani Date: Thu, 14 Aug 2025 18:19:06 +0500 Subject: [PATCH 3/5] Remove unrelated files from PR From 290fc59f5c9b19fc7b225aae6b10954a7189e970 Mon Sep 17 00:00:00 2001 From: AishwaryaBadlani Date: Thu, 14 Aug 2025 18:40:05 +0500 Subject: [PATCH 4/5] DOC: Update derivation equations of positive Group Lasso penalty - Add complete Lagrangian derivation for case w = 0 - Include rigorous KKT conditions and optimality analysis - Replace incomplete derivations with full mathematical proofs Fixes #243 --- skglm/datafits/__init__.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/skglm/datafits/__init__.py b/skglm/datafits/__init__.py index 67df316c..496d5db4 100644 --- a/skglm/datafits/__init__.py +++ b/skglm/datafits/__init__.py @@ -1,15 +1,14 @@ from .base import BaseDatafit, BaseMultitaskDatafit from .single_task import (Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma, Cox, WeightedQuadratic, QuadraticHessian,) -from ._double_quadratic import DoubleQuadratic from .multi_task import QuadraticMultiTask from .group import QuadraticGroup, LogisticGroup, PoissonGroup __all__ = [ - BaseDatafit, BaseMultitaskDatafit, - Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma, Cox, - QuadraticMultiTask, - QuadraticGroup, LogisticGroup, PoissonGroup, WeightedQuadratic, - QuadraticHessian, DoubleQuadratic # Add this -] \ No newline at end of file + 'BaseDatafit', 'BaseMultitaskDatafit', + 'Quadratic', 'QuadraticSVC', 'Logistic', 'Huber', 'Poisson', 'Gamma', 'Cox', + 'QuadraticMultiTask', + 'QuadraticGroup', 'LogisticGroup', 'PoissonGroup', 'WeightedQuadratic', + 'QuadraticHessian' +] From afeaa30fb82d5cfb22cdfa2507bf1ddc70ee02c2 Mon Sep 17 00:00:00 2001 From: 
malakazlan Date: Thu, 16 Oct 2025 15:27:46 +0500 Subject: [PATCH 5/5] Remove unrelated files from docs PR --- skglm/datafits/__init__.py | 10 +- skglm/datafits/_double_quadratic.py | 294 --------------------------- skglm/tests/test_double_quadratic.py | 52 ----- 3 files changed, 5 insertions(+), 351 deletions(-) delete mode 100644 skglm/datafits/_double_quadratic.py delete mode 100644 skglm/tests/test_double_quadratic.py diff --git a/skglm/datafits/__init__.py b/skglm/datafits/__init__.py index 496d5db4..0c6a5973 100644 --- a/skglm/datafits/__init__.py +++ b/skglm/datafits/__init__.py @@ -6,9 +6,9 @@ __all__ = [ - 'BaseDatafit', 'BaseMultitaskDatafit', - 'Quadratic', 'QuadraticSVC', 'Logistic', 'Huber', 'Poisson', 'Gamma', 'Cox', - 'QuadraticMultiTask', - 'QuadraticGroup', 'LogisticGroup', 'PoissonGroup', 'WeightedQuadratic', - 'QuadraticHessian' + BaseDatafit, BaseMultitaskDatafit, + Quadratic, QuadraticSVC, Logistic, Huber, Poisson, Gamma, Cox, + QuadraticMultiTask, + QuadraticGroup, LogisticGroup, PoissonGroup, WeightedQuadratic, + QuadraticHessian ] diff --git a/skglm/datafits/_double_quadratic.py b/skglm/datafits/_double_quadratic.py deleted file mode 100644 index c4abe393..00000000 --- a/skglm/datafits/_double_quadratic.py +++ /dev/null @@ -1,294 +0,0 @@ -import numpy as np -from numba import float64 -from .base import BaseDatafit - - -class DoubleQuadratic(BaseDatafit): - """Double Quadratic datafit with asymmetric loss. - - The datafit reads: - - .. math:: 1 / (2 \\times n_\\text{samples}) \\sum_i (\\alpha + (1-2\\alpha) \\cdot 1[\\epsilon_i > 0]) \\epsilon_i^2 - - where :math:`\\epsilon_i = (Xw)_i - y_i` are the residuals. - - Parameters - ---------- - alpha : float, default=0.5 - Asymmetry parameter controlling the relative weighting of positive vs - negative residuals: - - alpha = 0.5: symmetric loss (equivalent to standard Quadratic) - - alpha < 0.5: penalize positive residuals (overestimation) more heavily - - alpha > 0.5: penalize negative residuals (underestimation) more heavily - - Attributes - ---------- - Xty : array, shape (n_features,) - Pre-computed quantity used during the gradient evaluation. - Equal to ``X.T @ y``. - - Note - ---- - The class is jit compiled at fit time using Numba compiler. - This allows for faster computations. - """ - - def __init__(self, alpha=0.5): - if not 0 <= alpha <= 1: - raise ValueError(f"alpha must be between 0 and 1, got {alpha}") - self.alpha = alpha - - def get_spec(self): - spec = ( - ('alpha', float64), - ('Xty', float64[:]), - ) - return spec - - def params_to_dict(self): - return dict(alpha=self.alpha) - - def get_lipschitz(self, X, y): - """Compute per-coordinate Lipschitz constants. - - For DoubleQuadratic with scaling factor 2, the Lipschitz constant - for coordinate j is bounded by 2 * max_weight * ||X[:, j]||^2 / n_samples. 
- """ - n_features = X.shape[1] - - # Compute weight bounds (after scaling by 2) - weight_pos = 2 * (1 - self.alpha) # weight for positive residuals - weight_neg = 2 * self.alpha # weight for negative residuals - max_weight = max(weight_pos, weight_neg) - - lipschitz = np.zeros(n_features, dtype=X.dtype) - for j in range(n_features): - lipschitz[j] = max_weight * (X[:, j] ** 2).sum() / len(y) - - return lipschitz - - def get_lipschitz_sparse(self, X_data, X_indptr, X_indices, y): - """Sparse version of get_lipschitz.""" - n_features = len(X_indptr) - 1 - - # Compute weight bounds (after scaling by 2) - weight_pos = 2 * (1 - self.alpha) - weight_neg = 2 * self.alpha - max_weight = max(weight_pos, weight_neg) - - lipschitz = np.zeros(n_features, dtype=X_data.dtype) - - for j in range(n_features): - nrm2 = 0. - for idx in range(X_indptr[j], X_indptr[j + 1]): - nrm2 += X_data[idx] ** 2 - - lipschitz[j] = max_weight * nrm2 / len(y) - - return lipschitz - - def get_global_lipschitz(self, X, y): - """Global Lipschitz constant.""" - weight_pos = 2 * (1 - self.alpha) - weight_neg = 2 * self.alpha - max_weight = max(weight_pos, weight_neg) - - from scipy.linalg import norm - return max_weight * norm(X, ord=2) ** 2 / len(y) - - def get_global_lipschitz_sparse(self, X_data, X_indptr, X_indices, y): - """Sparse version of global Lipschitz constant.""" - weight_pos = 2 * (1 - self.alpha) - weight_neg = 2 * self.alpha - max_weight = max(weight_pos, weight_neg) - - from .utils import spectral_norm - return max_weight * spectral_norm(X_data, X_indptr, X_indices, len(y)) ** 2 / len(y) - - def initialize(self, X, y): - """Pre-compute X.T @ y for efficient gradient computation.""" - self.Xty = X.T @ y - - def initialize_sparse(self, X_data, X_indptr, X_indices, y): - """Sparse version of initialize.""" - n_features = len(X_indptr) - 1 - self.Xty = np.zeros(n_features, dtype=X_data.dtype) - - for j in range(n_features): - xty = 0 - for idx in range(X_indptr[j], X_indptr[j + 1]): - xty += X_data[idx] * y[X_indices[idx]] - - self.Xty[j] = xty - - def value(self, y, w, Xw): - """Compute the asymmetric quadratic loss value. - - When alpha=0.5, this should be identical to Quadratic loss. - The formula needs to be: (1/2n) * Σ weights * (y - Xw)² - where weights are normalized so that alpha=0.5 gives weight=1.0 - """ - # Match Quadratic exactly: use (y - Xw) for loss computation - residuals = y - Xw - - # For asymmetric weighting, check sign of (Xw - y) - prediction_residuals = Xw - y - - # Compute weights, normalized so alpha=0.5 gives weight=1.0 - # Original formula: α + (1-2α) * 1[εᵢ>0] - # At α=0.5: 0.5 + 0 = 0.5, but we want 1.0 - # So we need to scale by 2: 2 * (α + (1-2α) * 1[εᵢ>0]) - weights = 2 * (self.alpha + (1 - 2*self.alpha) * (prediction_residuals > 0)) - - # Return normalized loss - return np.sum(weights * residuals**2) / (2 * len(y)) - - def gradient_scalar(self, X, y, w, Xw, j): - """Compute gradient w.r.t. 
coordinate j.""" - prediction_residuals = Xw - y # For gradient computation - - # Compute weights with same scaling as value() - weights = 2 * (self.alpha + (1 - 2*self.alpha) * (prediction_residuals > 0)) - - # Gradient: X[:, j].T @ (weights * (Xw - y)) / n_samples - return (X[:, j] @ (weights * prediction_residuals)) / len(y) - - def gradient_scalar_sparse(self, X_data, X_indptr, X_indices, y, Xw, j): - """Sparse version of gradient_scalar.""" - prediction_residuals = Xw - y - - # Compute weights with same scaling - weights = 2 * (self.alpha + (1 - 2*self.alpha) * (prediction_residuals > 0)) - - # Compute X[:, j].T @ (weights * prediction_residuals) for sparse X - XjT_weighted_residuals = 0. - for i in range(X_indptr[j], X_indptr[j+1]): - sample_idx = X_indices[i] - XjT_weighted_residuals += X_data[i] * weights[sample_idx] * prediction_residuals[sample_idx] - - return XjT_weighted_residuals / len(y) - - def gradient(self, X, y, Xw): - """Compute full gradient vector.""" - prediction_residuals = Xw - y - - # Compute weights with same scaling as value() - weights = 2 * (self.alpha + (1 - 2*self.alpha) * (prediction_residuals > 0)) - - # Return X.T @ (weights * prediction_residuals) / n_samples - return X.T @ (weights * prediction_residuals) / len(y) - - def raw_grad(self, y, Xw): - """Compute gradient of datafit w.r.t Xw.""" - prediction_residuals = Xw - y - - # Compute weights with same scaling - weights = 2 * (self.alpha + (1 - 2*self.alpha) * (prediction_residuals > 0)) - - return weights * prediction_residuals / len(y) - - def raw_hessian(self, y, Xw): - """Compute Hessian of datafit w.r.t Xw.""" - prediction_residuals = Xw - y - - # Compute weights with same scaling - weights = 2 * (self.alpha + (1 - 2*self.alpha) * (prediction_residuals > 0)) - - return weights / len(y) - - def full_grad_sparse(self, X_data, X_indptr, X_indices, y, Xw): - """Sparse version of full gradient computation.""" - n_features = X_indptr.shape[0] - 1 - n_samples = y.shape[0] - prediction_residuals = Xw - y - - # Compute weights with same scaling - weights = 2 * (self.alpha + (1 - 2*self.alpha) * (prediction_residuals > 0)) - - grad = np.zeros(n_features, dtype=Xw.dtype) - for j in range(n_features): - XjT_weighted_residuals = 0. 
- for i in range(X_indptr[j], X_indptr[j + 1]): - sample_idx = X_indices[i] - XjT_weighted_residuals += X_data[i] * weights[sample_idx] * prediction_residuals[sample_idx] - grad[j] = XjT_weighted_residuals / n_samples - return grad - - def intercept_update_step(self, y, Xw): - """Compute intercept update step.""" - prediction_residuals = Xw - y - - # Compute weights with same scaling - weights = 2 * (self.alpha + (1 - 2*self.alpha) * (prediction_residuals > 0)) - - return np.mean(weights * prediction_residuals) - - -# Test function to validate our implementation -def _test_double_quadratic(): - """Test DoubleQuadratic implementation.""" - import numpy as np - from .single_task import Quadratic - - print("Testing DoubleQuadratic implementation...") - - # Test data - np.random.seed(42) - n_samples, n_features = 50, 10 - X = np.random.randn(n_samples, n_features) - y = np.random.randn(n_samples) - w = np.random.randn(n_features) - Xw = X @ w - - # Test 1: alpha=0.5 should match standard Quadratic - print("\n=== Test 1: alpha=0.5 vs Quadratic ===") - - quad = Quadratic() - quad.initialize(X, y) - - dquad = DoubleQuadratic(alpha=0.5) - dquad.initialize(X, y) - - loss_quad = quad.value(y, w, Xw) - loss_dquad = dquad.value(y, w, Xw) - - print(f"Quadratic loss: {loss_quad:.8f}") - print(f"DoubleQuadratic: {loss_dquad:.8f}") - print(f"Difference: {abs(loss_quad - loss_dquad):.2e}") - - # Test gradients - grad_quad = quad.gradient(X, y, Xw) - grad_dquad = dquad.gradient(X, y, Xw) - grad_diff = np.linalg.norm(grad_quad - grad_dquad) - - print(f"Gradient difference: {grad_diff:.2e}") - - # Test case 2: Asymmetric behavior - print("\n=== Test 2: Asymmetric behavior ===") - - # Create simple test case with known residuals - X_simple = np.eye(4) # Identity matrix - y_simple = np.array([0., 0., 0., 0.]) - w_simple = np.array([1., -1., 2., -2.]) # Predictions: [1, -1, 2, -2], so prediction_residuals = [1, -1, 2, -2] - Xw_simple = X_simple @ w_simple - - dquad_asym = DoubleQuadratic(alpha=0.3) # Penalize positive residuals more - dquad_asym.initialize(X_simple, y_simple) - - loss_asym = dquad_asym.value(y_simple, w_simple, Xw_simple) - - # Manual calculation: - # prediction_residuals = [1, -1, 2, -2] (Xw - y) - # weights = 0.3 + 0.4 * [1, 0, 1, 0] = [0.7, 0.3, 0.7, 0.3] - # loss = (1/(2*4)) * (0.7*1² + 0.3*1² + 0.7*4² + 0.3*4²) - expected = (1/8) * (0.7*1 + 0.3*1 + 0.7*4 + 0.3*4) - - print(f"Asymmetric loss: {loss_asym:.6f}") - print(f"Expected: {expected:.6f}") - print(f"Difference: {abs(loss_asym - expected):.2e}") - - print("\n=== All tests completed ===") - - -if __name__ == "__main__": - _test_double_quadratic() \ No newline at end of file diff --git a/skglm/tests/test_double_quadratic.py b/skglm/tests/test_double_quadratic.py deleted file mode 100644 index be83bff5..00000000 --- a/skglm/tests/test_double_quadratic.py +++ /dev/null @@ -1,52 +0,0 @@ -import numpy as np -import pytest -from skglm.datafits import DoubleQuadratic, Quadratic - - -class TestDoubleQuadratic: - - def test_alpha_half_matches_quadratic(self): - """Test that alpha=0.5 gives same results as Quadratic.""" - np.random.seed(42) - X = np.random.randn(20, 5) - y = np.random.randn(20) - w = np.random.randn(5) - Xw = X @ w - - quad = Quadratic() - quad.initialize(X, y) - - dquad = DoubleQuadratic(alpha=0.5) - dquad.initialize(X, y) - - # Test loss values - assert np.allclose(quad.value(y, w, Xw), dquad.value(y, w, Xw)) - - # Test gradients - assert np.allclose(quad.gradient(X, y, Xw), dquad.gradient(X, y, Xw)) - - def 
test_asymmetric_behavior(self): - """Test that asymmetric behavior works correctly.""" - # Simple test case with known residuals - X = np.eye(4) - y = np.zeros(4) - w = np.array([1., -1., 2., -2.]) # residuals = [1, -1, 2, -2] - Xw = X @ w - - dquad = DoubleQuadratic(alpha=0.3) - dquad.initialize(X, y) - - loss = dquad.value(y, w, Xw) - - # Manual calculation with scaling: weights = 2 * [0.7, 0.3, 0.7, 0.3] = [1.4, 0.6, 1.4, 0.6] - expected = (1/8) * (1.4*1 + 0.6*1 + 1.4*4 + 0.6*4) - - assert np.allclose(loss, expected) - - def test_parameter_validation(self): - """Test parameter validation.""" - with pytest.raises(ValueError): - DoubleQuadratic(alpha=-0.1) - - with pytest.raises(ValueError): - DoubleQuadratic(alpha=1.1) \ No newline at end of file
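
Since patch 5 removes `_double_quadratic.py` again, the weighting convention used in the draft can be checked without pulling the branch. Below is a minimal standalone NumPy sketch of the asymmetric quadratic loss and its gradient as written in patch 1 (weights scaled by 2 so that `alpha=0.5` reduces to the standard `Quadratic` datafit). The helper names `asymmetric_quadratic_loss` and `asymmetric_quadratic_grad` are illustrative only and are not part of skglm.

import numpy as np


def asymmetric_quadratic_loss(y, Xw, alpha=0.5):
    # value from patch 1: (1 / 2n) * sum_i weight_i * ((Xw)_i - y_i)^2,
    # with weight_i = 2 * (alpha + (1 - 2*alpha) * 1[(Xw)_i - y_i > 0])
    residuals = Xw - y
    weights = 2 * (alpha + (1 - 2 * alpha) * (residuals > 0))
    return np.sum(weights * residuals ** 2) / (2 * len(y))


def asymmetric_quadratic_grad(X, y, Xw, alpha=0.5):
    # gradient w.r.t. w: X.T @ (weights * (Xw - y)) / n_samples
    residuals = Xw - y
    weights = 2 * (alpha + (1 - 2 * alpha) * (residuals > 0))
    return X.T @ (weights * residuals) / len(y)


rng = np.random.default_rng(0)
X = rng.standard_normal((50, 10))
y = rng.standard_normal(50)
Xw = X @ rng.standard_normal(10)

# alpha = 0.5 gives unit weights, i.e. the plain quadratic datafit (1 / 2n) ||y - Xw||^2
assert np.isclose(asymmetric_quadratic_loss(y, Xw, alpha=0.5),
                  np.sum((y - Xw) ** 2) / (2 * len(y)))
assert np.allclose(asymmetric_quadratic_grad(X, y, Xw, alpha=0.5),
                   X.T @ (Xw - y) / len(y))

# manual check mirroring the test in patch 1: residuals [1, -1, 2, -2], alpha = 0.3
# -> weights 2 * [0.7, 0.3, 0.7, 0.3] = [1.4, 0.6, 1.4, 0.6]
# -> loss = (1.4*1 + 0.6*1 + 1.4*4 + 0.6*4) / (2*4) = 1.25
X4, y4, w4 = np.eye(4), np.zeros(4), np.array([1., -1., 2., -2.])
assert np.isclose(asymmetric_quadratic_loss(y4, X4 @ w4, alpha=0.3), 1.25)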
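
The subdifferential-distance formulas derived in the updated `prox_nn_group_lasso.rst` can also be sanity-checked numerically. The sketch below is a standalone illustration only: the function names are made up for this example, and it assumes, as the derivation suggests, that the case :math:`||w|| = 0` reduces to projecting the positive part of :math:`v` onto the :math:`\lambda`-ball, giving a distance of :math:`(||\max(v, 0)|| - \lambda)_+`; the case :math:`||w|| \ne 0` uses the closed form stated in the tutorial.

import numpy as np


def dist_subdiff_zero(v, lam):
    # case ||w|| = 0: distance from v to {u + n : ||u|| <= lam, n <= 0};
    # the derivation reduces this to projecting max(v, 0) onto the lam-ball,
    # hence the distance (||max(v, 0)|| - lam)_+
    return max(0.0, np.linalg.norm(np.maximum(v, 0)) - lam)


def dist_subdiff_nonzero(v, w, lam):
    # case ||w|| != 0, with w >= 0 as in the nonnegative group Lasso:
    # D(v) = sqrt(sum_{j: w_j > 0} (v_j - lam * w_j / ||w||)^2
    #             + sum_{j: w_j = 0} max(0, v_j)^2)
    norm_w = np.linalg.norm(w)
    pos = w > 0
    return np.sqrt(np.sum((v[pos] - lam * w[pos] / norm_w) ** 2)
                   + np.sum(np.maximum(v[~pos], 0) ** 2))


rng = np.random.default_rng(1)
v, lam = rng.standard_normal(5), 0.7

# brute-force check of the ||w|| = 0 case against its definition,
# min ||u + n - v|| over ||u|| <= lam, n <= 0, via alternating minimization
u, n = np.zeros(5), np.zeros(5)
for _ in range(100):
    n = np.minimum(v - u, 0)              # exact minimizer in n for fixed u
    u = v - n                             # unconstrained minimizer in u ...
    if np.linalg.norm(u) > lam:
        u *= lam / np.linalg.norm(u)      # ... projected onto the lam-ball
print(dist_subdiff_zero(v, lam), np.linalg.norm(u + n - v))  # should agree

w_grp = np.array([0.5, 0.0, 1.2, 0.0, 0.3])   # some zero coordinates, ||w|| != 0
print(dist_subdiff_nonzero(v, w_grp, lam))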