174 changes: 121 additions & 53 deletions aeon/classification/distance_based/_elastic_ensemble.py
@@ -251,62 +251,130 @@ def _fit(self, X, y):
                     f"Currently evaluating {self._distance_measures[dm]}"
                 )
 
-            # If 100 parameter options are being considered per measure,
-            # use a GridSearchCV
-            if self.proportion_of_param_options == 1:
-                grid = GridSearchCV(
-                    estimator=KNeighborsTimeSeriesClassifier(
-                        distance=this_measure, n_neighbors=1
-                    ),
-                    param_grid=ElasticEnsemble._get_100_param_options(
-                        self._distance_measures[dm], X
-                    ),
-                    cv=LeaveOneOut(),
-                    scoring="accuracy",
-                    n_jobs=self._n_jobs,
-                    verbose=self.verbose,
-                )
-                grid.fit(param_train_to_use, param_train_y)
-
-            # Else, used RandomizedSearchCV to randomly sample parameter
-            # options for each measure
-            else:
-                grid = RandomizedSearchCV(
-                    estimator=KNeighborsTimeSeriesClassifier(
-                        distance=this_measure, n_neighbors=1
-                    ),
-                    param_distributions=ElasticEnsemble._get_100_param_options(
-                        self._distance_measures[dm], X
-                    ),
-                    n_iter=math.ceil(100 * self.proportion_of_param_options),
-                    cv=LeaveOneOut(),
-                    scoring="accuracy",
-                    n_jobs=self._n_jobs,
-                    random_state=rand,
-                    verbose=self.verbose,
-                )
-                grid.fit(param_train_to_use, param_train_y)
-
-            if self.majority_vote:
-                acc = 1
-            # once the best parameter option has been estimated on the
-            # training data, perform a final pass with this parameter option
-            # to get the individual predictions with cross_cal_predict (
-            # Note: optimisation potentially possible here if a GridSearchCV
-            # was used previously. TO-DO: determine how to extract
-            # predictions for the best param option from GridSearchCV)
-            else:
-                best_model = KNeighborsTimeSeriesClassifier(
-                    n_neighbors=1,
-                    distance=this_measure,
-                    distance_params=grid.best_params_["distance_params"],
-                    n_jobs=self._n_jobs,
-                )
-                preds = cross_val_predict(
-                    best_model, full_train_to_use, y, cv=LeaveOneOut()
-                )
-                acc = accuracy_score(y, preds)
+            best_distance_params = None
+            acc = 1.0  # Default for majority vote
+
+            # Optimized path:
+            # If we use all training data for param finding AND
+            # we are not using majority_vote (which needs weighting),
+            # we can combine the param search and accuracy estimation
+            # into a single loop to avoid the redundant CV pass.
+            if self.proportion_train_in_param_finding == 1.0 and not self.majority_vote:
+                if self.verbose > 0:
+                    print(  # noqa: T201
+                        f"Using optimized manual CV path for "
+                        f"{self._distance_measures[dm]}"
+                    )
+
+                param_grid = ElasticEnsemble._get_100_param_options(
+                    self._distance_measures[dm], X
+                )
+                all_params = param_grid["distance_params"]
+
+                # Handle randomized search (proportion_of_param_options)
+                if self.proportion_of_param_options < 1:
+                    n_iter = math.ceil(
+                        len(all_params) * self.proportion_of_param_options
+                    )
+                    # Use a copy and shuffle to mimic RandomizedSearchCV
+                    params_to_search = all_params.copy()
+                    rand.shuffle(params_to_search)
+                    params_to_search = params_to_search[:n_iter]
+                else:
+                    params_to_search = all_params
+
+                best_acc = -1.0
+                best_distance_params = None
+
+                for params in params_to_search:
+                    model = KNeighborsTimeSeriesClassifier(
+                        n_neighbors=1,
+                        distance=this_measure,
+                        distance_params=params,
+                        n_jobs=self._n_jobs,
+                    )
+                    # This CV is run on the FULL training set
+                    preds = cross_val_predict(
+                        model, full_train_to_use, y, cv=LeaveOneOut()
+                    )
+                    current_acc = accuracy_score(y, preds)
+
+                    if current_acc > best_acc:
+                        best_acc = current_acc
+                        best_distance_params = params
+
+                acc = best_acc  # Set the final accuracy for weighting
+
+            # Standard (original) path:
+            # This path is used if:
+            # 1. We are using a SUBSET of data for param finding, OR
+            # 2. We ARE using majority_vote.
+            else:
+                if self.verbose > 0:
+                    print(  # noqa: T201
+                        f"Using standard GridSearchCV/RandomizedSearchCV "
+                        f"path for {self._distance_measures[dm]}"
+                    )
+
+                # If 100 parameter options are being considered per measure,
+                # use a GridSearchCV
+                if self.proportion_of_param_options == 1:
+                    grid = GridSearchCV(
+                        estimator=KNeighborsTimeSeriesClassifier(
+                            distance=this_measure, n_neighbors=1
+                        ),
+                        param_grid=ElasticEnsemble._get_100_param_options(
+                            self._distance_measures[dm], X
+                        ),
+                        cv=LeaveOneOut(),
+                        scoring="accuracy",
+                        n_jobs=self._n_jobs,
+                        verbose=self.verbose,
+                    )
+                    grid.fit(param_train_to_use, param_train_y)
+
+                # Else, use RandomizedSearchCV to randomly sample parameter
+                # options for each measure
+                else:
+                    grid = RandomizedSearchCV(
+                        estimator=KNeighborsTimeSeriesClassifier(
+                            distance=this_measure, n_neighbors=1
+                        ),
+                        param_distributions=ElasticEnsemble._get_100_param_options(
+                            self._distance_measures[dm], X
+                        ),
+                        n_iter=math.ceil(100 * self.proportion_of_param_options),
+                        cv=LeaveOneOut(),
+                        scoring="accuracy",
+                        n_jobs=self._n_jobs,
+                        random_state=rand,
+                        verbose=self.verbose,
+                    )
+                    grid.fit(param_train_to_use, param_train_y)
+
+                best_distance_params = grid.best_params_["distance_params"]
+
+                if self.majority_vote:
+                    acc = 1
+                # once the best parameter option has been estimated on the
+                # training data, perform a final pass with this parameter option
+                # to get the individual predictions with cross_val_predict (
+                # Note: optimisation potentially possible here if a GridSearchCV
+                # was used previously. TO-DO: determine how to extract
+                # predictions for the best param option from GridSearchCV)
+                else:
+                    best_model = KNeighborsTimeSeriesClassifier(
+                        n_neighbors=1,
+                        distance=this_measure,
+                        distance_params=grid.best_params_["distance_params"],
+                        n_jobs=self._n_jobs,
+                    )
+                    preds = cross_val_predict(
+                        best_model, full_train_to_use, y, cv=LeaveOneOut()
+                    )
+                    acc = accuracy_score(y, preds)
+
+            # Common code for both paths
             if self.verbose > 0:
                 print(  # noqa: T201
                     f"Training acc for {self._distance_measures[dm]}: {acc}"
@@ -317,7 +385,7 @@ def _fit(self, X, y):
             best_model = KNeighborsTimeSeriesClassifier(
                 n_neighbors=1,
                 distance=this_measure,
-                distance_params=grid.best_params_["distance_params"],
+                distance_params=best_distance_params,
             )
             best_model.fit(full_train_to_use, y)
             end_build_time = time.time()
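The change is easier to evaluate outside the diff: the old code ran a LeaveOneOut `GridSearchCV`/`RandomizedSearchCV` to pick `distance_params`, then a second LeaveOneOut `cross_val_predict` with the winning parameters to get the training accuracy used as the ensemble weight. When the parameter search already runs on the full training set, one manual loop yields both results, saving a redundant full LeaveOneOut pass per distance measure. Below is a minimal sketch of the idea, using plain scikit-learn KNN on a toy dataset rather than aeon's `KNeighborsTimeSeriesClassifier`; the two-entry parameter grid and the data slice are illustrative stand-ins for the 100 `distance_params` options.

```python
# Minimal sketch of the optimisation: one LeaveOneOut pass per candidate
# parameter set gives both the best parameters and their training accuracy.
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import LeaveOneOut, cross_val_predict
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
X, y = X[::5], y[::5]  # small slice across all classes keeps LOO-CV cheap

candidate_params = [{"p": 1}, {"p": 2}]  # illustrative stand-in grid

# The same LOO predictions yield both the best parameter option and its
# training accuracy (the ensemble weight), so no second cross_val_predict
# pass is needed for the winner afterwards.
best_acc, best_params = -1.0, None
for params in candidate_params:
    model = KNeighborsClassifier(n_neighbors=1, **params)
    preds = cross_val_predict(model, X, y, cv=LeaveOneOut())
    acc = accuracy_score(y, preds)
    if acc > best_acc:
        best_acc, best_params = acc, params

print(f"best params: {best_params}, LOO-CV accuracy: {best_acc:.3f}")
```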
1 change: 1 addition & 0 deletions docs/changelogs/v1.3.md
@@ -35,6 +35,7 @@ September 2025
 
 ### Enhancements
 
+- [ENH] Optimize `ElasticEnsemble` `_fit` to avoid redundant cross-validation ({pr}`3109`) {user}`Nithurshen`
 - [ENH] Improvements to ST transformer and classifier ({pr}`2968`) {user}`MatthewMiddlehurst`
 - [ENH] KNN n_jobs and updated kneighbours method ({pr}`2578`) {user}`chrisholder`
 - [ENH] Refactor signature code ({pr}`2943`) {user}`TonyBagnall`