diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index f39597faf..7970796ed 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -21,6 +21,7 @@ Enhancements ~~~~~~~~~~~~ Bugs +- Fixes the management of include/exclude datasets in :func:`moabb.benchmark`, adds additional verifications (:gh:`834` by ``Anton Andreev`_) ~~~~ API changes diff --git a/moabb/benchmark.py b/moabb/benchmark.py index 2d3324f0a..6b84ec074 100644 --- a/moabb/benchmark.py +++ b/moabb/benchmark.py @@ -10,6 +10,8 @@ from moabb import paradigms as moabb_paradigms from moabb.analysis import analyze +from moabb.datasets.base import BaseDataset +from moabb.datasets.fake import FakeDataset from moabb.evaluations import ( CrossSessionEvaluation, CrossSubjectEvaluation, @@ -54,7 +56,7 @@ def benchmark( # noqa: C901 possible to include or exclude specific datasets and to choose the type of evaluation. - If particular paradigms are mentioned through select_paradigms, only the pipelines corresponding to those paradigms + If particular paradigms are mentioned through parameter paradigms, only the pipelines corresponding to those paradigms will be run. If no paradigms are mentioned, all pipelines will be run. To define the include_datasets or exclude_dataset, you could start from the full dataset list, @@ -110,11 +112,11 @@ def benchmark( # noqa: C901 File path to context.yml file that describes context parameters. If none, assumes all defaults. Must contain an entry for all paradigms described in the pipelines. - include_datasets: list of str or Dataset object + include_datasets: list of str Dataset codes or Dataset objects Datasets (dataset.code or object) to include in the benchmark run. By default, all suitable datasets are included. If both include_datasets and exclude_datasets are specified, raise an error. - exclude_datasets: list of str or Dataset object + exclude_datasets: list of str Dataset codes or Dataset objects Datasets to exclude from the benchmark run optuna: Enable Optuna for the hyperparameter search @@ -160,35 +162,41 @@ def benchmark( # noqa: C901 with _open_lock(contexts, "r") as cfile: context_params = yaml.load(cfile.read(), Loader=yaml.FullLoader) - prdgms = generate_paradigms(pipeline_configs, context_params, log) + prdgms_from_pipelines = generate_paradigms(pipeline_configs, context_params, log) + + # Filter requested benchmark paradigms vs available in provided pipelines if paradigms is not None: - prdgms = {p: prdgms[p] for p in paradigms} + prdgms_from_pipelines = {p: prdgms_from_pipelines[p] for p in paradigms} param_grid = generate_param_grid(pipeline_configs, context_params, log) - log.debug(f"The paradigms being run are {prdgms.keys()}") + print(f"The paradigms being run are {prdgms_from_pipelines.keys()}") if len(context_params) == 0: - for paradigm in prdgms: + for paradigm in prdgms_from_pipelines: context_params[paradigm] = {} # Looping over the evaluations to be done df_eval = [] for evaluation in evaluations: eval_results = dict() - for paradigm in prdgms: + + for paradigm in prdgms_from_pipelines: # get the context log.debug(f"{paradigm}: {context_params[paradigm]}") - p = getattr(moabb_paradigms, paradigm)(**context_params[paradigm]) - # List of dataset class instances - datasets = p.datasets - d = _inc_exc_datasets(datasets, include_datasets, exclude_datasets) - log.debug( - f"Datasets considered for {paradigm} paradigm {[dt.code for dt in d]}" + p = _get_paradigm_instance(paradigm, context_params) + # List of dataset class instances, handles FakeDatasets as well + datasets = ( + p.datasets + + [ds for ds in (include_datasets or []) if isinstance(ds, FakeDataset)] + if any(isinstance(ds, FakeDataset) for ds in (include_datasets or [])) + else p.datasets ) + d = _inc_exc_datasets(datasets, include_datasets, exclude_datasets) + print(f"Datasets considered for {paradigm} paradigm {[dt.code for dt in d]}") ppl_with_epochs, ppl_with_array = {}, {} - for pn, pv in prdgms[paradigm].items(): + for pn, pv in prdgms_from_pipelines[paradigm].items(): ppl_with_array[pn] = pv if len(ppl_with_epochs) > 0: @@ -317,37 +325,171 @@ def _save_results(eval_results, output, plot): analyze(prdgm_result, str(prdgm_path), plot=plot) -def _inc_exc_datasets(datasets, include_datasets, exclude_datasets): - d = list() +def _inc_exc_datasets(paradigm_datasets, include_datasets=None, exclude_datasets=None): + """ + Filter datasets based on include_datasets and exclude_datasets. + + Parameters + ---------- + datasets : list + List of dataset class instances (each with a `.code` attribute). + include_datasets : list[str or Dataset], optional + List of dataset codes or dataset class instances to include. + exclude_datasets : list[str or Dataset], optional + List of dataset codes or dataset class instances to exclude. + + Returns + ------- + list + Filtered list of dataset class instances. + + """ + # --- Safety checks --- + if include_datasets is not None and exclude_datasets is not None: + raise ValueError("Cannot specify both include_datasets and exclude_datasets.") + + all_paradigm_codes = [ds.code for ds in paradigm_datasets] + d = list(paradigm_datasets) + + # --- Inclusion logic --- if include_datasets is not None: - # Assert if the inputs are key_codes - if isinstance(include_datasets[0], str): - # Map from key_codes to class instances - datasets_codes = [d.code for d in datasets] - # Get the indices of the matching datasets - for incdat in include_datasets: - if incdat in datasets_codes: - d.append(datasets[datasets_codes.index(incdat)]) - else: - # The case where the class instances have been given - # can be passed on directly - d = list(include_datasets) - if exclude_datasets is not None: - raise AttributeError( - "You could not specify both include and exclude datasets" - ) + include_codes = _validate_list_per_paradigm( + all_paradigm_codes, include_datasets, "include_datasets" + ) + # Keep only included datasets + filtered = [ds for ds in paradigm_datasets if ds.code in include_codes] + return filtered + + # --- Exclusion logic --- + if exclude_datasets is not None: + exclude_codes = _validate_list_per_paradigm( + all_paradigm_codes, exclude_datasets, "exclude_datasets" + ) + # Remove excluded datasets + filtered = [ds for ds in paradigm_datasets if ds.code not in exclude_codes] + return filtered - elif exclude_datasets is not None: - d = list(datasets) - # Assert if the inputs are not key_codes i.e. expected to be dataset class objects - if not isinstance(exclude_datasets[0], str): - # Convert the input to key_codes - exclude_datasets = [e.code for e in exclude_datasets] - - # Map from key_codes to class instances - datasets_codes = [d.code for d in datasets] - for excdat in exclude_datasets: - del d[datasets_codes.index(excdat)] - else: - d = list(datasets) return d + + +def _get_paradigm_instance(paradigm_name, context_params=None): + """ + Get a paradigm instance from moabb.paradigms by name (case-insensitive). + + Parameters + ---------- + paradigm_name : str + Name of the paradigm to look up (e.g., 'P300', 'MotorImagery'). + context_params : dict, optional + Dictionary mapping paradigm names to parameters. + Will pass context_params[paradigm_name] to the paradigm constructor if available. + + Returns + ------- + paradigm_instance : moabb.paradigms.BaseParadigm + Instance of the requested paradigm. + + Raises + ------ + ValueError + If the paradigm name is not found in moabb.paradigms. + """ + context_params = context_params or {} + + # Find matching class name in moabb.paradigms, case-insensitive + cls_name = next( + (name for name in dir(moabb_paradigms) if name.lower() == paradigm_name.lower()), + None, + ) + if cls_name is None: + raise ValueError( + f"Paradigm '{paradigm_name}' not found in moabb.paradigms. " + f"Available paradigms: {[name for name in dir(moabb_paradigms) if not name.startswith('_')]}" + ) + + # Get the class and instantiate + ParadigmClass = getattr(moabb_paradigms, cls_name) + params = context_params.get(paradigm_name, {}) + return ParadigmClass(**params) + + +def _validate_list_per_paradigm(all_paradigm_codes, ds_list, list_name): + """ + Validates a list of include/exclude datasets for specific paradigm. + + Ensures the user-provided list of datasets is valid for use in the benchmark. + Allows dataset lists that contain entries for multiple paradigms, as long as + at least one dataset is compatible with the current paradigm. + + Checks: + - The input is a list or tuple. + - The list is not empty. + - All elements are either strings or BaseDataset objects (no mix). + - No duplicates are present. + - At least one dataset matches the current paradigm’s available datasets (all_paradigm_codes). + - Fake datasets (codes starting with "FakeDataset") are always accepted. + + Parameters + ---------- + ds_list : list[str or BaseDataset] + The list to validate. Can contain dataset codes (e.g., ["BNCI2014-001"]) + or dataset objects (instances of BaseDataset or its subclasses). + list_name : str + The name of the list ("include_datasets" or "exclude_datasets"), + used only for error messages. + + Returns + ------- + list[str] + A normalized list of dataset codes extracted from the input list. + + """ + if not isinstance(ds_list, (list, tuple)): + raise TypeError(f"{list_name} must be a list or tuple.") + + # Empty list edge case + if len(ds_list) == 0: + raise ValueError(f"{list_name} cannot be an empty list.") + + # Ensure homogeneity: all strings or all dataset objects + all_str = all(isinstance(x, str) for x in ds_list) + all_obj = all(isinstance(x, BaseDataset) for x in ds_list) + if not (all_str or all_obj): + raise TypeError( + f"{list_name} must contain either all strings or all dataset objects, not a mix." + ) + + # --- Handle case: list of dataset codes (strings) --- + if all_str: + # Check duplicates + if len(ds_list) != len(set(ds_list)): + raise ValueError(f"{list_name} contains duplicate dataset codes.") + + # Accept all codes that belong to the current paradigm or are fake datasets + valid = [ + x for x in ds_list if x in all_paradigm_codes or x.startswith("FakeDataset") + ] + if not valid: + raise ValueError( + f"None of the datasets in {list_name} match the current paradigm’s datasets " + f"({all_paradigm_codes}). Provided: {ds_list}" + ) + return valid + + # --- Handle case: list of dataset objects --- + codes = [x.code for x in ds_list] + if len(codes) != len(set(codes)): + raise ValueError(f"{list_name} contains duplicate dataset instances.") + + # Accept all dataset objects that belong to the current paradigm or are fake datasets + valid = [ + x.code + for x in ds_list + if x.code in all_paradigm_codes or x.code.startswith("FakeDataset") + ] + if not valid: + raise ValueError( + f"None of the dataset objects in {list_name} match the current paradigm’s datasets " + f"({all_paradigm_codes}). Provided: {[x.code for x in ds_list]}" + ) + return valid diff --git a/moabb/tests/test_benchmark.py b/moabb/tests/test_benchmark.py index 5425c1b3a..2c4b01963 100644 --- a/moabb/tests/test_benchmark.py +++ b/moabb/tests/test_benchmark.py @@ -6,6 +6,7 @@ from moabb import benchmark from moabb.datasets.fake import FakeDataset from moabb.evaluations.base import optuna_available +from moabb.paradigms import FakeImageryParadigm, FakeP300Paradigm class TestBenchmark: @@ -22,64 +23,91 @@ def test_benchmark_strdataset(self): res = benchmark( pipelines=str(self.pp_dir), evaluations=["WithinSession"], + paradigms=["FakeP300Paradigm", "FakeImageryParadigm"], include_datasets=[ - "FakeDataset-imagery-10-2--60-60--120-120--lefthand-righthand--c3-cz-c4", "FakeDataset-p300-10-2--60-60--120-120--target-nontarget--c3-cz-c4", - "FakeDataset-ssvep-10-2--60-60--120-120--13-15--c3-cz-c4", - "FakeDataset-cvep-10-2--60-60--120-120--10-00--c3-cz-c4", + "FakeDataset-imagery-10-2--60-60--120-120--lefthand-righthand--c3-cz-c4", ], overwrite=True, ) - assert len(res) == 80 + assert len(res) == 60 def test_benchmark_objdataset(self): res = benchmark( pipelines=str(self.pp_dir), evaluations=["WithinSession"], include_datasets=[ - FakeDataset(["left_hand", "right_hand"], paradigm="imagery"), - FakeDataset(["Target", "NonTarget"], paradigm="p300"), - FakeDataset(["13", "15"], paradigm="ssvep"), - FakeDataset(["1.0", "0.0"], paradigm="cvep"), + FakeDataset( + ["left_hand", "right_hand"], paradigm="imagery", n_subjects=2 + ), + FakeDataset(["Target", "NonTarget"], paradigm="p300", n_subjects=2), + FakeDataset(["13", "15"], paradigm="ssvep", n_subjects=2), + FakeDataset(["1.0", "0.0"], paradigm="cvep", n_subjects=2), ], overwrite=True, ) - assert len(res) == 80 + assert len(res) == 16 def test_nodataset(self): with pytest.raises(ValueError): benchmark( pipelines=str(self.pp_dir), - exclude_datasets=["FakeDataset"], + exclude_datasets=["NonExistingDatasetCode"], overwrite=True, ) def test_selectparadigm(self): + ds_imagery = FakeImageryParadigm().datasets[0] + ds_p300 = FakeP300Paradigm().datasets[0] res = benchmark( pipelines=str(self.pp_dir), evaluations=["WithinSession"], paradigms=["FakeImageryParadigm"], + include_datasets=[ds_imagery, ds_p300], overwrite=True, ) - assert len(res) == 40 + assert len(res) == 120 def test_include_exclude(self): - with pytest.raises(AttributeError): + with pytest.raises(ValueError): + benchmark( + pipelines=str(self.pp_dir), + include_datasets=["Dataset1"], + exclude_datasets=["Dataset2"], + overwrite=True, + ) + + def test_include_unique(self): + with pytest.raises(ValueError): + benchmark( + pipelines=str(self.pp_dir), + include_datasets=["Dataset1", "Dataset1"], + overwrite=True, + ) + + def test_include_two_types(self): + with pytest.raises(TypeError): benchmark( pipelines=str(self.pp_dir), - include_datasets=["FakeDataset"], - exclude_datasets=["AnotherDataset"], + include_datasets=[ + "Dataset1", + FakeDataset(["left_hand", "right_hand"], paradigm="imagery"), + ], overwrite=True, ) def test_optuna(self): if not optuna_available: pytest.skip("Optuna is not installed") + ds = FakeImageryParadigm().datasets[0] res = benchmark( pipelines=str(self.pp_dir), evaluations=["WithinSession"], paradigms=["FakeImageryParadigm"], + include_datasets=[ + ds, + ], overwrite=True, optuna=True, ) - assert len(res) == 40 + assert len(res) == 120