Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 164 additions & 3 deletions pysindy/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
from scipy.integrate import odeint
from scipy.integrate import solve_ivp
from scipy.interpolate import interp1d
from sklearn import set_config
from sklearn.base import BaseEstimator
from sklearn.metrics import r2_score
from sklearn.pipeline import Pipeline
from sklearn.utils.validation import check_is_fitted

set_config(enable_metadata_routing=True)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leave this up to the user. scikit-learn documentation explains that if they're using a pipeline or composite transform, they'll need this.

from typing_extensions import Self

from .differentiation import BaseDifferentiation
Expand Down Expand Up @@ -293,6 +296,9 @@ def __init__(
differentiation_method = FiniteDifference(axis=-2)
self.differentiation_method = differentiation_method
self.discrete_time = discrete_time
self.set_fit_request(sample_weight=True)
self.set_score_request(sample_weight=True)
self.optimizer.set_fit_request(sample_weight=True)
Comment on lines +299 to +301
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, leave this up to the user


def fit(
self,
Expand All @@ -301,6 +307,7 @@ def fit(
x_dot=None,
u=None,
feature_names: Optional[list[str]] = None,
sample_weight=None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Provide static type

):
"""
Fit a SINDy model.
Expand Down Expand Up @@ -342,6 +349,11 @@ def fit(
Names for the input features (e.g. :code:`['x', 'y', 'z']`).
If None, will use :code:`['x0', 'x1', ...]`.

sample_weight : float or array-like of shape (n_samples,), optional
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why allow a single float?

Also, shouldn't this be (*n_spatial, n_time)? Or a list of those arrays, for when multiple trajectories are set?

Per-sample weights for the regression. Passed internally to
the optimizer (e.g. STLSQ). Supports compatibility with
scikit-learn tools such as GridSearchCV when using weighted data.

Returns
-------
self: a fitted :class:`SINDy` instance
Expand Down Expand Up @@ -370,14 +382,24 @@ def fit(

self.feature_names = feature_names

if sample_weight is not None:
mode = (
"weak"
if "Weak" in self.feature_library.__class__.__name__
else "standard"
)
sample_weight = _expand_sample_weights(
sample_weight, x, feature_library=self.feature_library, mode=mode
)

steps = [
("features", self.feature_library),
("shaping", SampleConcatter()),
("model", self.optimizer),
]
x_dot = concat_sample_axis(x_dot)
self.model = Pipeline(steps)
self.model.fit(x, x_dot)
self.model.fit(x, x_dot, sample_weight=sample_weight)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove the Pipeline and directly call the components. That way, we don't need to rely on metadata routing internally

self._fit_shape()

return self
Expand Down Expand Up @@ -411,6 +433,7 @@ def predict(self, x, u=None):
x, _, u = _comprehend_and_validate_inputs(x, 1, None, u, self.feature_library)

check_is_fitted(self, "model")

if self.n_control_features_ > 0 and u is None:
raise TypeError("Model was fit using control variables, so u is required")
if self.n_control_features_ == 0 and u is not None:
Expand Down Expand Up @@ -466,7 +489,16 @@ def print(self, lhs=None, precision=3, **kwargs):
names = f"{lhs[i]}"
print(f"{names} = {eqn}", **kwargs)

def score(self, x, t, x_dot=None, u=None, metric=r2_score, **metric_kws):
def score(
self,
x,
t,
x_dot=None,
u=None,
metric=r2_score,
sample_weight=None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, static typing

**metric_kws,
):
"""
Returns a score for the time derivative prediction produced by the model.

Expand Down Expand Up @@ -500,9 +532,14 @@ def score(self, x, t, x_dot=None, u=None, metric=r2_score, **metric_kws):
<https://scikit-learn.org/stable/modules/model_evaluation.html>`_
for more options.

sample_weight : array-like of shape (n_samples,), optional
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ibid - shape

Per-sample weights passed directly to the metric. This is the
preferred way to supply weights.

metric_kws: dict, optional
Optional keyword arguments to pass to the metric function.


Returns
-------
score: float
Expand All @@ -522,10 +559,21 @@ def score(self, x, t, x_dot=None, u=None, metric=r2_score, **metric_kws):

x, x_dot = self._process_trajectories(x, t, x_dot)

if sample_weight is not None:
sample_weight = _expand_sample_weights(sample_weight, x)

x_dot = concat_sample_axis(x_dot)
x_dot_predict = concat_sample_axis(x_dot_predict)

x_dot, x_dot_predict = drop_nan_samples(x_dot, x_dot_predict)
if sample_weight is not None:
x_dot, x_dot_predict, good_idx = drop_nan_samples(
x_dot, x_dot_predict, return_indices=True
)
sample_weight = sample_weight[good_idx]
metric_kws = {**metric_kws, "sample_weight": sample_weight}
else:
x_dot, x_dot_predict = drop_nan_samples(x_dot, x_dot_predict)

return metric(x_dot, x_dot_predict, **metric_kws)

def _process_trajectories(self, x, t, x_dot):
Expand Down Expand Up @@ -909,3 +957,116 @@ def comprehend_and_validate(arr, t):
)
u = [comprehend_and_validate(ui, ti) for ui, ti in _zip_like_sequence(u, t)]
return x, x_dot, u


def _expand_sample_weights(
sample_weight, trajectories, feature_library=None, mode="standard"
):
Comment on lines +962 to +964
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify what "expand" means?

"""
Expand per-trajectory or per-sample weights for use in SINDy estimators.

Parameters
----------
sample_weight : sequence of scalars or array-like
Weights for each trajectory. In "standard" mode, each entry can be:
- a scalar weight (applied to all samples in that trajectory), or
- an array of length equal to the number of samples (n_time) for that
trajectory.
In "weak" mode, each entry must be a single scalar weight per trajectory.

trajectories : sequence
Sequence of trajectory-like objects, each having attributes `n_time` and
`n_coord`.

feature_library : object, optional
Library instance used in weak-form mode. Must define attribute `K`
(the number of weak test functions). If missing, assumes K=1 with a warning.

mode : {'standard', 'weak'}, default='standard'
- "standard": Expand per-sample weights to match concatenated samples.
- "weak": Repeat each trajectory’s single scalar weight `K` times.

Returns
-------
np.ndarray or None
A 1D numpy array of concatenated and expanded sample weights,
or None if `sample_weight` is None.
"""
# -------------------------------------------------------------
# Early exit for None
# -------------------------------------------------------------
if sample_weight is None:
return None

if not (
isinstance(sample_weight, Sequence)
and not isinstance(sample_weight, np.ndarray)
Comment on lines +1002 to +1003
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe np.ndarrays are Sequences, so the second is redundant. Also, in these cases, distribute the not across conditions

):
raise ValueError(
"sample_weight must be a list or tuple, not a scalar or numpy array."
)

if len(sample_weight) != len(trajectories):
raise ValueError("sample_weight length must match number of trajectories.")

# -------------------------------------------------------------
# Weak mode: one weight per trajectory, repeated K times
# -------------------------------------------------------------
if mode == "weak":
if feature_library is None:
raise ValueError("feature_library is required in weak mode.")

K = getattr(feature_library, "K", None)
if K is None:
warnings.warn("feature_library missing 'K'; assuming K=1.", UserWarning)
K = 1

validated = []
for w, traj in zip(sample_weight, trajectories):
arr = np.asarray(w)
if arr.ndim > 0 and arr.size > 1:
raise ValueError(
"Weak mode expects exactly one weight per trajectory (scalar), "
f"but got shape {arr.shape} for trajectory with {traj.n_time}"
f"samples."
)
validated.append(float(arr))
return np.repeat(validated, K)

# -------------------------------------------------------------
# Standard mode: expand scalars or per-sample arrays
# -------------------------------------------------------------
expanded = []
for w, traj in zip(sample_weight, trajectories):
arr = np.asarray(w)

# Scalar → expand to all samples in trajectory
if arr.ndim == 0:
arr = np.full(traj.n_time, arr, dtype=float)

# 1D array → must match number of samples
elif arr.ndim == 1:
if arr.shape[0] != traj.n_time:
raise ValueError(
f"sample_weight length {arr.shape[0]} does"
f" not match trajectory length {traj.n_time}."
)

# 2D array → only (n,1) allowed
elif arr.ndim == 2:
if arr.shape[1] != 1:
raise ValueError(
"sample_weight 2D arrays must have second dimension = 1."
)
if arr.shape[0] != traj.n_time:
raise ValueError(
"sample_weight 2D array length does not match trajectory length."
)
arr = arr.ravel()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why ravel()? This is an internal function, so don't try to help out the user. If data isn't shaped the way we expect it raise an error.


else:
raise ValueError("Invalid sample_weight shape.")

expanded.append(arr.ravel())

return np.concatenate(expanded)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I mentioned, if you're doing reshaping to match axes, this should probably go in SampleConcatter

2 changes: 2 additions & 0 deletions pysindy/feature_library/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .polynomial_library import PolynomialLibrary
from .sindy_pi_library import SINDyPILibrary
from .weak_pde_library import WeakPDELibrary
from .weighted_weak_pde_library import WeightedWeakPDELibrary

__all__ = [
"ConcatLibrary",
Expand All @@ -21,6 +22,7 @@
"PolynomialLibrary",
"PDELibrary",
"WeakPDELibrary",
"WeightedWeakPDELibrary",
"SINDyPILibrary",
"ParameterizedLibrary",
"base",
Expand Down
Loading
Loading