Draft
Changes from all commits
40 commits
33c1ff5
A re-written _inc_exc_datasets() that fixes issues and provides more …
toncho11 Nov 10, 2025
e4d7fc6
Add another improvement - handling of the paradigms in moabb.
toncho11 Nov 10, 2025
109ad92
small typo
toncho11 Nov 10, 2025
0ae75e8
A number of fixes and improvements.
toncho11 Nov 10, 2025
3df21c6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2025
fdf8ea7
small docstring update
toncho11 Nov 10, 2025
ce0729e
Merge branch 'improve_fix_inc_exc_datasets_in_benchmark' of https://g…
toncho11 Nov 10, 2025
9a4853c
small comment clarification
toncho11 Nov 10, 2025
9193304
Improved tests
toncho11 Nov 12, 2025
1ca36c7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 12, 2025
7bc64bf
optuna test enabled
toncho11 Nov 12, 2025
6652208
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 12, 2025
d8cf34b
There are 8 tests. This has been a lot of effort.
toncho11 Nov 12, 2025
b7215ab
Merge branch 'develop' into improve_fix_inc_exc_datasets_in_benchmark
toncho11 Nov 13, 2025
8733fe4
updated whats_new file
toncho11 Nov 13, 2025
58a5da3
Improved documentation
toncho11 Nov 13, 2025
210be29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2025
c7c05b6
comments updated
toncho11 Nov 13, 2025
abd52c1
Multiple paradigms in the parameter "paradigms" are handled better now.
toncho11 Nov 13, 2025
786d6d9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2025
ebd3916
small improvements
toncho11 Nov 13, 2025
ba08834
All fake datasets are set to 2 subjects to reduce execution time.
toncho11 Nov 13, 2025
2c7bd0a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2025
3b15ef2
Update LogVar.yml
gcattan Nov 15, 2025
baa51c3
Update SSVEP_CCA.yml
gcattan Nov 15, 2025
af28fba
Update CSP.yml
gcattan Nov 15, 2025
42a06d0
Update test_benchmark.py
gcattan Nov 15, 2025
b0cee37
Update fake.py
gcattan Nov 15, 2025
da811f8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 15, 2025
11fb14b
Update __init__.py
gcattan Nov 15, 2025
0696f6b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 15, 2025
a52b834
Update test_benchmark.py
gcattan Nov 15, 2025
2f0be0c
Update test_benchmark.py
gcattan Nov 15, 2025
f6e5c91
Update test_benchmark.py
gcattan Nov 15, 2025
33921ab
Update test_benchmark.py
gcattan Nov 15, 2025
ba033e6
Update test_benchmark.py
gcattan Nov 15, 2025
ec6820a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 15, 2025
105915e
Update fake.py
gcattan Nov 15, 2025
caf2589
Update __init__.py
gcattan Nov 15, 2025
5db6c8b
Update test_benchmark.py
gcattan Nov 15, 2025
1 change: 1 addition & 0 deletions docs/source/whats_new.rst
@@ -21,6 +21,7 @@ Enhancements
~~~~~~~~~~~~

Bugs
~~~~
- Fixes the management of include/exclude datasets in :func:`moabb.benchmark`, adds additional verifications (:gh:`834` by `Anton Andreev`_)
API changes
234 changes: 188 additions & 46 deletions moabb/benchmark.py
@@ -10,6 +10,8 @@

from moabb import paradigms as moabb_paradigms
from moabb.analysis import analyze
from moabb.datasets.base import BaseDataset
from moabb.datasets.fake import FakeDataset
from moabb.evaluations import (
CrossSessionEvaluation,
CrossSubjectEvaluation,
@@ -54,7 +56,7 @@ def benchmark( # noqa: C901
possible to include or exclude specific datasets and to choose the type of
evaluation.
If particular paradigms are mentioned through select_paradigms, only the pipelines corresponding to those paradigms
If particular paradigms are specified through the paradigms parameter, only the pipelines corresponding to those paradigms
will be run. If no paradigms are mentioned, all pipelines will be run.
To define the include_datasets or exclude_datasets, you could start from the full dataset list,
@@ -110,11 +112,11 @@ def benchmark( # noqa: C901
File path to context.yml file that describes context parameters.
If none, assumes all defaults. Must contain an entry for all
paradigms described in the pipelines.
include_datasets: list of str or Dataset object
include_datasets: list of str (dataset codes) or Dataset objects
Datasets (dataset.code or object) to include in the benchmark run.
By default, all suitable datasets are included. If both include_datasets
and exclude_datasets are specified, raise an error.
exclude_datasets: list of str or Dataset object
exclude_datasets: list of str (dataset codes) or Dataset objects
Datasets to exclude from the benchmark run.
optuna: Enable Optuna for the hyperparameter search
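For reference, a minimal usage sketch of these parameters (the pipeline folder path and the dataset codes below are illustrative assumptions, not values taken from this PR):

# Hedged sketch: restrict the run to the P300 paradigm and two datasets given by code.
# include_datasets and exclude_datasets are mutually exclusive.
from moabb import benchmark

results = benchmark(
    pipelines="./pipelines",
    evaluations=["WithinSession"],
    paradigms=["P300"],
    include_datasets=["BNCI2014-008", "BNCI2014-009"],
)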
@@ -160,35 +162,41 @@ def benchmark( # noqa: C901
with _open_lock(contexts, "r") as cfile:
context_params = yaml.load(cfile.read(), Loader=yaml.FullLoader)

prdgms = generate_paradigms(pipeline_configs, context_params, log)
prdgms_from_pipelines = generate_paradigms(pipeline_configs, context_params, log)

# Filter requested benchmark paradigms vs available in provided pipelines
if paradigms is not None:
prdgms = {p: prdgms[p] for p in paradigms}
prdgms_from_pipelines = {p: prdgms_from_pipelines[p] for p in paradigms}

param_grid = generate_param_grid(pipeline_configs, context_params, log)

log.debug(f"The paradigms being run are {prdgms.keys()}")
print(f"The paradigms being run are {prdgms_from_pipelines.keys()}")

if len(context_params) == 0:
for paradigm in prdgms:
for paradigm in prdgms_from_pipelines:
context_params[paradigm] = {}

# Looping over the evaluations to be done
df_eval = []
for evaluation in evaluations:
eval_results = dict()
for paradigm in prdgms:

for paradigm in prdgms_from_pipelines:
# get the context
log.debug(f"{paradigm}: {context_params[paradigm]}")
p = getattr(moabb_paradigms, paradigm)(**context_params[paradigm])
# List of dataset class instances
datasets = p.datasets
d = _inc_exc_datasets(datasets, include_datasets, exclude_datasets)
log.debug(
f"Datasets considered for {paradigm} paradigm {[dt.code for dt in d]}"
p = _get_paradigm_instance(paradigm, context_params)
# List of dataset class instances, handles FakeDatasets as well
datasets = (
p.datasets
+ [ds for ds in (include_datasets or []) if isinstance(ds, FakeDataset)]
if any(isinstance(ds, FakeDataset) for ds in (include_datasets or []))
else p.datasets
)
d = _inc_exc_datasets(datasets, include_datasets, exclude_datasets)
print(f"Datasets considered for {paradigm} paradigm {[dt.code for dt in d]}")

ppl_with_epochs, ppl_with_array = {}, {}
for pn, pv in prdgms[paradigm].items():
for pn, pv in prdgms_from_pipelines[paradigm].items():
ppl_with_array[pn] = pv

if len(ppl_with_epochs) > 0:
@@ -317,37 +325,171 @@ def _save_results(eval_results, output, plot):
analyze(prdgm_result, str(prdgm_path), plot=plot)


def _inc_exc_datasets(datasets, include_datasets, exclude_datasets):
d = list()
def _inc_exc_datasets(paradigm_datasets, include_datasets=None, exclude_datasets=None):
"""
Filter datasets based on include_datasets and exclude_datasets.
Parameters
----------
paradigm_datasets : list
List of dataset class instances (each with a `.code` attribute).
include_datasets : list[str or Dataset], optional
List of dataset codes or dataset class instances to include.
exclude_datasets : list[str or Dataset], optional
List of dataset codes or dataset class instances to exclude.
Returns
-------
list
Filtered list of dataset class instances.
"""
# --- Safety checks ---
if include_datasets is not None and exclude_datasets is not None:
raise ValueError("Cannot specify both include_datasets and exclude_datasets.")

all_paradigm_codes = [ds.code for ds in paradigm_datasets]
d = list(paradigm_datasets)

# --- Inclusion logic ---
if include_datasets is not None:
# Assert if the inputs are key_codes
if isinstance(include_datasets[0], str):
# Map from key_codes to class instances
datasets_codes = [d.code for d in datasets]
# Get the indices of the matching datasets
for incdat in include_datasets:
if incdat in datasets_codes:
d.append(datasets[datasets_codes.index(incdat)])
else:
# The case where the class instances have been given
# can be passed on directly
d = list(include_datasets)
if exclude_datasets is not None:
raise AttributeError(
"You could not specify both include and exclude datasets"
)
include_codes = _validate_list_per_paradigm(
all_paradigm_codes, include_datasets, "include_datasets"
)
# Keep only included datasets
filtered = [ds for ds in paradigm_datasets if ds.code in include_codes]
return filtered

# --- Exclusion logic ---
if exclude_datasets is not None:
exclude_codes = _validate_list_per_paradigm(
all_paradigm_codes, exclude_datasets, "exclude_datasets"
)
# Remove excluded datasets
filtered = [ds for ds in paradigm_datasets if ds.code not in exclude_codes]
return filtered

elif exclude_datasets is not None:
d = list(datasets)
# Assert if the inputs are not key_codes i.e. expected to be dataset class objects
if not isinstance(exclude_datasets[0], str):
# Convert the input to key_codes
exclude_datasets = [e.code for e in exclude_datasets]

# Map from key_codes to class instances
datasets_codes = [d.code for d in datasets]
for excdat in exclude_datasets:
del d[datasets_codes.index(excdat)]
else:
d = list(datasets)
return d
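To illustrate the intended filtering contract of the rewritten helper, a hedged sketch using stand-in dataset objects (only a .code attribute is needed here; real MOABB dataset instances would be used in practice, and the codes are illustrative). It assumes the helper remains importable from moabb.benchmark as defined in this diff:

from types import SimpleNamespace

from moabb.benchmark import _inc_exc_datasets

# Two stand-in "datasets" belonging to the current paradigm.
paradigm_datasets = [SimpleNamespace(code="BNCI2014-001"), SimpleNamespace(code="Zhou-2016")]

kept = _inc_exc_datasets(paradigm_datasets, include_datasets=["BNCI2014-001"])
print([ds.code for ds in kept])  # expected: ['BNCI2014-001']

kept = _inc_exc_datasets(paradigm_datasets, exclude_datasets=["Zhou-2016"])
print([ds.code for ds in kept])  # expected: ['BNCI2014-001']

# Passing both lists at once now raises ValueError instead of the old AttributeError.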


def _get_paradigm_instance(paradigm_name, context_params=None):
"""
Get a paradigm instance from moabb.paradigms by name (case-insensitive).
Parameters
----------
paradigm_name : str
Name of the paradigm to look up (e.g., 'P300', 'MotorImagery').
context_params : dict, optional
Dictionary mapping paradigm names to parameters.
Will pass context_params[paradigm_name] to the paradigm constructor if available.
Returns
-------
paradigm_instance : moabb.paradigms.BaseParadigm
Instance of the requested paradigm.
Raises
------
ValueError
If the paradigm name is not found in moabb.paradigms.
"""
context_params = context_params or {}

# Find matching class name in moabb.paradigms, case-insensitive
cls_name = next(
(name for name in dir(moabb_paradigms) if name.lower() == paradigm_name.lower()),
None,
)
if cls_name is None:
raise ValueError(
f"Paradigm '{paradigm_name}' not found in moabb.paradigms. "
f"Available paradigms: {[name for name in dir(moabb_paradigms) if not name.startswith('_')]}"
)

# Get the class and instantiate
ParadigmClass = getattr(moabb_paradigms, cls_name)
params = context_params.get(paradigm_name, {})
return ParadigmClass(**params)
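A brief hedged sketch of the case-insensitive lookup (assuming the helper stays module-level in moabb.benchmark as in this diff):

from moabb.benchmark import _get_paradigm_instance

# 'p300' resolves to moabb.paradigms.P300 regardless of case; an unknown
# name raises ValueError listing the available paradigm names.
paradigm = _get_paradigm_instance("p300")
print(type(paradigm).__name__)  # expected: P300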


def _validate_list_per_paradigm(all_paradigm_codes, ds_list, list_name):
"""
Validates a list of include/exclude datasets for a specific paradigm.
Ensures the user-provided list of datasets is valid for use in the benchmark.
Allows dataset lists that contain entries for multiple paradigms, as long as
at least one dataset is compatible with the current paradigm.
Checks:
- The input is a list or tuple.
- The list is not empty.
- All elements are either strings or BaseDataset objects (no mix).
- No duplicates are present.
- At least one dataset matches the current paradigm’s available datasets (all_paradigm_codes).
- Fake datasets (codes starting with "FakeDataset") are always accepted.
Parameters
----------
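all_paradigm_codes : list[str]
Dataset codes available for the current paradigm, used to check which entries are compatible.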
ds_list : list[str or BaseDataset]
The list to validate. Can contain dataset codes (e.g., ["BNCI2014-001"])
or dataset objects (instances of BaseDataset or its subclasses).
list_name : str
The name of the list ("include_datasets" or "exclude_datasets"),
used only for error messages.
Returns
-------
list[str]
A normalized list of dataset codes extracted from the input list.
"""
if not isinstance(ds_list, (list, tuple)):
raise TypeError(f"{list_name} must be a list or tuple.")

# Empty list edge case
if len(ds_list) == 0:
raise ValueError(f"{list_name} cannot be an empty list.")

# Ensure homogeneity: all strings or all dataset objects
all_str = all(isinstance(x, str) for x in ds_list)
all_obj = all(isinstance(x, BaseDataset) for x in ds_list)
if not (all_str or all_obj):
raise TypeError(
f"{list_name} must contain either all strings or all dataset objects, not a mix."
)

# --- Handle case: list of dataset codes (strings) ---
if all_str:
# Check duplicates
if len(ds_list) != len(set(ds_list)):
raise ValueError(f"{list_name} contains duplicate dataset codes.")

# Accept all codes that belong to the current paradigm or are fake datasets
valid = [
x for x in ds_list if x in all_paradigm_codes or x.startswith("FakeDataset")
]
if not valid:
raise ValueError(
f"None of the datasets in {list_name} match the current paradigm’s datasets "
f"({all_paradigm_codes}). Provided: {ds_list}"
)
return valid

# --- Handle case: list of dataset objects ---
codes = [x.code for x in ds_list]
if len(codes) != len(set(codes)):
raise ValueError(f"{list_name} contains duplicate dataset instances.")

# Accept all dataset objects that belong to the current paradigm or are fake datasets
valid = [
x.code
for x in ds_list
if x.code in all_paradigm_codes or x.code.startswith("FakeDataset")
]
if not valid:
raise ValueError(
f"None of the dataset objects in {list_name} match the current paradigm’s datasets "
f"({all_paradigm_codes}). Provided: {[x.code for x in ds_list]}"
)
return valid
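A hedged sketch of the validation contract described above (dataset codes are illustrative, and the helper is assumed importable from moabb.benchmark as in this diff):

from moabb.benchmark import _validate_list_per_paradigm

all_paradigm_codes = ["BNCI2014-001", "Zhou-2016"]

# A matching subset is returned as normalized dataset codes.
print(_validate_list_per_paradigm(all_paradigm_codes, ["BNCI2014-001"], "include_datasets"))
# expected: ['BNCI2014-001']

# Empty lists, mixed str/object lists, duplicates, or lists with no dataset
# belonging to the current paradigm raise TypeError/ValueError with a clear message.
try:
    _validate_list_per_paradigm(all_paradigm_codes, ["NotADataset"], "include_datasets")
except ValueError as err:
    print(err)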