Sample Weights introduction #652
base: master
@@ -11,10 +11,13 @@
 from scipy.integrate import odeint
 from scipy.integrate import solve_ivp
 from scipy.interpolate import interp1d
+from sklearn import set_config
 from sklearn.base import BaseEstimator
 from sklearn.metrics import r2_score
 from sklearn.pipeline import Pipeline
 from sklearn.utils.validation import check_is_fitted
+
+set_config(enable_metadata_routing=True)
 from typing_extensions import Self

 from .differentiation import BaseDifferentiation
@@ -293,6 +296,9 @@ def __init__(
             differentiation_method = FiniteDifference(axis=-2)
         self.differentiation_method = differentiation_method
         self.discrete_time = discrete_time
+        self.set_fit_request(sample_weight=True)
+        self.set_score_request(sample_weight=True)
+        self.optimizer.set_fit_request(sample_weight=True)

Member (on lines +299 to +301): Again, leave this up to the user.

     def fit(
         self,
@@ -301,6 +307,7 @@ def fit(
         x_dot=None,
         u=None,
         feature_names: Optional[list[str]] = None,
+        sample_weight=None,

Member: Provide static type.

     ):
         """
         Fit a SINDy model.
@@ -342,6 +349,11 @@ def fit(
             Names for the input features (e.g. :code:`['x', 'y', 'z']`).
             If None, will use :code:`['x0', 'x1', ...]`.

+        sample_weight : float or array-like of shape (n_samples,), optional

Member: Why allow a single float? Also, shouldn't this be (*n_spatial, n_time)? Or a list of those arrays, for when multiple trajectories are set?

+            Per-sample weights for the regression. Passed internally to
+            the optimizer (e.g. STLSQ). Supports compatibility with
+            scikit-learn tools such as GridSearchCV when using weighted data.
+
         Returns
         -------
         self: a fitted :class:`SINDy` instance
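To show how the new argument is meant to be used, here is a minimal fit sketch against this branch (which enables scikit-learn metadata routing at import). The toy data, the weight values, and the STLSQ threshold are illustrative assumptions; the weights are passed as a list with one entry per trajectory, which is the layout the `_expand_sample_weights` helper added at the bottom of this diff expects:

```python
import numpy as np
import pysindy as ps

# Toy single-trajectory data for x' = -2 x on a uniform time grid
t = np.linspace(0, 5, 200)
x = np.exp(-2 * t).reshape(-1, 1)

# One weight entry per trajectory; here a per-sample array that down-weights
# the tail of the trajectory (values are arbitrary)
sample_weight = [np.where(t < 4.0, 1.0, 0.1)]

model = ps.SINDy(optimizer=ps.STLSQ(threshold=0.1))
model.fit(x, t=t, sample_weight=sample_weight)
model.print()
```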
@@ -370,14 +382,24 @@ def fit(

         self.feature_names = feature_names

+        if sample_weight is not None:
+            mode = (
+                "weak"
+                if "Weak" in self.feature_library.__class__.__name__
+                else "standard"
+            )
+            sample_weight = _expand_sample_weights(
+                sample_weight, x, feature_library=self.feature_library, mode=mode
+            )
+
         steps = [
             ("features", self.feature_library),
             ("shaping", SampleConcatter()),
             ("model", self.optimizer),
         ]
         x_dot = concat_sample_axis(x_dot)
         self.model = Pipeline(steps)
-        self.model.fit(x, x_dot)
+        self.model.fit(x, x_dot, sample_weight=sample_weight)

Member: Let's remove the Pipeline and directly call the components. That way, we don't need to rely on metadata routing internally.

         self._fit_shape()

         return self
@@ -411,6 +433,7 @@ def predict(self, x, u=None):
         x, _, u = _comprehend_and_validate_inputs(x, 1, None, u, self.feature_library)

+        check_is_fitted(self, "model")

         if self.n_control_features_ > 0 and u is None:
             raise TypeError("Model was fit using control variables, so u is required")
         if self.n_control_features_ == 0 and u is not None:
@@ -466,7 +489,16 @@ def print(self, lhs=None, precision=3, **kwargs):
                 names = f"{lhs[i]}"
                 print(f"{names} = {eqn}", **kwargs)

-    def score(self, x, t, x_dot=None, u=None, metric=r2_score, **metric_kws):
+    def score(
+        self,
+        x,
+        t,
+        x_dot=None,
+        u=None,
+        metric=r2_score,
+        sample_weight=None,

Member: Again, static typing.

+        **metric_kws,
+    ):
         """
         Returns a score for the time derivative prediction produced by the model.
@@ -500,9 +532,14 @@ def score(self, x, t, x_dot=None, u=None, metric=r2_score, **metric_kws):
             <https://scikit-learn.org/stable/modules/model_evaluation.html>`_
             for more options.

+        sample_weight : array-like of shape (n_samples,), optional

Member: ibid - shape.

+            Per-sample weights passed directly to the metric. This is the
+            preferred way to supply weights.
+
         metric_kws: dict, optional
             Optional keyword arguments to pass to the metric function.

         Returns
         -------
         score: float
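A short sketch of the weighted score path documented here, reusing `model`, `x`, `t`, and `sample_weight` from the fit sketch above; the `mean_squared_error` call is only to show that the weights are forwarded to whichever metric is supplied:

```python
from sklearn.metrics import mean_squared_error

# Weights are expanded per trajectory, aligned with the non-NaN samples, and
# passed to the metric via its sample_weight argument
r2_weighted = model.score(x, t=t, sample_weight=sample_weight)
mse_weighted = model.score(x, t=t, metric=mean_squared_error, sample_weight=sample_weight)
print(r2_weighted, mse_weighted)
```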
@@ -522,10 +559,21 @@ def score(self, x, t, x_dot=None, u=None, metric=r2_score, **metric_kws):

         x, x_dot = self._process_trajectories(x, t, x_dot)

+        if sample_weight is not None:
+            sample_weight = _expand_sample_weights(sample_weight, x)
+
         x_dot = concat_sample_axis(x_dot)
         x_dot_predict = concat_sample_axis(x_dot_predict)

-        x_dot, x_dot_predict = drop_nan_samples(x_dot, x_dot_predict)
+        if sample_weight is not None:
+            x_dot, x_dot_predict, good_idx = drop_nan_samples(
+                x_dot, x_dot_predict, return_indices=True
+            )
+            sample_weight = sample_weight[good_idx]
+            metric_kws = {**metric_kws, "sample_weight": sample_weight}
+        else:
+            x_dot, x_dot_predict = drop_nan_samples(x_dot, x_dot_predict)

         return metric(x_dot, x_dot_predict, **metric_kws)

     def _process_trajectories(self, x, t, x_dot):
@@ -909,3 +957,116 @@ def comprehend_and_validate(arr, t):
         )
     u = [comprehend_and_validate(ui, ti) for ui, ti in _zip_like_sequence(u, t)]
     return x, x_dot, u
+
+
+def _expand_sample_weights(
+    sample_weight, trajectories, feature_library=None, mode="standard"
+):

Member (on lines +962 to +964): Can you clarify what "expand" means?
| """ | ||
| Expand per-trajectory or per-sample weights for use in SINDy estimators. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| sample_weight : sequence of scalars or array-like | ||
| Weights for each trajectory. In "standard" mode, each entry can be: | ||
| - a scalar weight (applied to all samples in that trajectory), or | ||
| - an array of length equal to the number of samples (n_time) for that | ||
| trajectory. | ||
| In "weak" mode, each entry must be a single scalar weight per trajectory. | ||
|
|
||
| trajectories : sequence | ||
| Sequence of trajectory-like objects, each having attributes `n_time` and | ||
| `n_coord`. | ||
|
|
||
| feature_library : object, optional | ||
| Library instance used in weak-form mode. Must define attribute `K` | ||
| (the number of weak test functions). If missing, assumes K=1 with a warning. | ||
|
|
||
| mode : {'standard', 'weak'}, default='standard' | ||
| - "standard": Expand per-sample weights to match concatenated samples. | ||
| - "weak": Repeat each trajectory’s single scalar weight `K` times. | ||
|
|
||
| Returns | ||
| ------- | ||
| np.ndarray or None | ||
| A 1D numpy array of concatenated and expanded sample weights, | ||
| or None if `sample_weight` is None. | ||
| """ | ||
+    # -------------------------------------------------------------
+    # Early exit for None
+    # -------------------------------------------------------------
+    if sample_weight is None:
+        return None
+
+    if not (
+        isinstance(sample_weight, Sequence)
+        and not isinstance(sample_weight, np.ndarray)

Member (on lines +1002 to +1003): I don't believe np.ndarrays are

+    ):
+        raise ValueError(
+            "sample_weight must be a list or tuple, not a scalar or numpy array."
+        )
+
+    if len(sample_weight) != len(trajectories):
+        raise ValueError("sample_weight length must match number of trajectories.")
+    # -------------------------------------------------------------
+    # Weak mode: one weight per trajectory, repeated K times
+    # -------------------------------------------------------------
+    if mode == "weak":
+        if feature_library is None:
+            raise ValueError("feature_library is required in weak mode.")
+
+        K = getattr(feature_library, "K", None)
+        if K is None:
+            warnings.warn("feature_library missing 'K'; assuming K=1.", UserWarning)
+            K = 1
+
+        validated = []
+        for w, traj in zip(sample_weight, trajectories):
+            arr = np.asarray(w)
+            if arr.ndim > 0 and arr.size > 1:
+                raise ValueError(
+                    "Weak mode expects exactly one weight per trajectory (scalar), "
+                    f"but got shape {arr.shape} for trajectory with {traj.n_time} "
+                    "samples."
+                )
+            validated.append(float(arr))
+        return np.repeat(validated, K)
+
+    # -------------------------------------------------------------
+    # Standard mode: expand scalars or per-sample arrays
+    # -------------------------------------------------------------
+    expanded = []
+    for w, traj in zip(sample_weight, trajectories):
+        arr = np.asarray(w)
+
+        # Scalar → expand to all samples in trajectory
+        if arr.ndim == 0:
+            arr = np.full(traj.n_time, arr, dtype=float)
+
+        # 1D array → must match number of samples
+        elif arr.ndim == 1:
+            if arr.shape[0] != traj.n_time:
+                raise ValueError(
+                    f"sample_weight length {arr.shape[0]} does"
+                    f" not match trajectory length {traj.n_time}."
+                )
+
+        # 2D array → only (n,1) allowed
+        elif arr.ndim == 2:
+            if arr.shape[1] != 1:
+                raise ValueError(
+                    "sample_weight 2D arrays must have second dimension = 1."
+                )
+            if arr.shape[0] != traj.n_time:
+                raise ValueError(
+                    "sample_weight 2D array length does not match trajectory length."
+                )
+            arr = arr.ravel()

Member: Why

+        else:
+            raise ValueError("Invalid sample_weight shape.")
+
+        expanded.append(arr.ravel())
+
+    return np.concatenate(expanded)
Member: As I mentioned, if you're doing reshaping to match axes, this should probably go in
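To make the "expand" semantics concrete (the question raised above), here is a small standalone illustration of the behaviour the docstring describes, using plain NumPy rather than importing the private helper; the trajectory lengths, weights, and K value are arbitrary:

```python
import numpy as np

# Standard mode: a scalar weight broadcasts over its trajectory, a per-sample
# array must match that trajectory's n_time, and results are concatenated.
n_times = [3, 2]                        # lengths of two trajectories (arbitrary)
weights = [2.0, np.array([0.5, 1.5])]   # scalar for the first, per-sample for the second

expanded = np.concatenate(
    [
        np.full(n, w, dtype=float) if np.ndim(w) == 0 else np.asarray(w, dtype=float)
        for w, n in zip(weights, n_times)
    ]
)
print(expanded)  # [2.  2.  2.  0.5 1.5]

# Weak mode: one scalar per trajectory, repeated K times, where K is the number
# of weak-form test functions on the library (K = 4 is illustrative).
K = 4
print(np.repeat([1.0, 0.25], K))  # [1.   1.   1.   1.   0.25 0.25 0.25 0.25]
```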
Comment: Leave this up to the user. scikit-learn documentation explains that if they're using a pipeline or composite transform, they'll need this.
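As a rough sketch of what leaving this to the user would look like (not code from this PR): with the module-level set_config call dropped, a user who wants sample_weight routed through a Pipeline or other composite estimator would opt in themselves. The set_*_request methods are generated by scikit-learn 1.3+ for estimators whose fit/score accept sample_weight; their availability on SINDy and STLSQ is assumed from this branch:

```python
import pysindy as ps
from sklearn import set_config

# User-side opt-in instead of the library enabling routing at import time
set_config(enable_metadata_routing=True)

model = ps.SINDy(optimizer=ps.STLSQ())
# Request that composite estimators route sample_weight to fit/score
# (these methods are assumed to exist; see the note above)
model.set_fit_request(sample_weight=True)
model.set_score_request(sample_weight=True)
model.optimizer.set_fit_request(sample_weight=True)
```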