diff --git a/experanto/experiment.py b/experanto/experiment.py index fa739c9..9ac1215 100644 --- a/experanto/experiment.py +++ b/experanto/experiment.py @@ -76,7 +76,7 @@ def __init__( def _load_devices(self) -> None: # Populate devices by going through subfolders # Assumption: blocks are sorted by start time - device_folders = [d for d in self.root_folder.iterdir() if (d.is_dir())] + device_folders = [d for d in self.root_folder.iterdir() if d.is_dir()] for d in device_folders: if d.name not in self.modality_config: @@ -95,14 +95,14 @@ def _load_devices(self) -> None: dev = instantiate( interp_conf, root_folder=d, cache_data=self.cache_data ) + # Check if instantiated object is proper Interpolator if not isinstance(dev, Interpolator): raise ValueError( - "Please provide an Interpolator which inherits from experantos Interpolator class." + "Instantiated object must inherit from Interpolator class." ) elif isinstance(interp_conf, Interpolator): - # Already instantiated Interpolator dev = interp_conf else: @@ -207,26 +207,22 @@ def interpolate( dict_keys(['screen', 'responses', 'eye_tracker']) """ if device is None: - values = {} - valid = {} + values, valid = {}, {} for d, interp in self.devices.items(): res = interp.interpolate(times, return_valid=return_valid) if return_valid: vals, vlds = res - values[d] = vals - valid[d] = vlds + values[d], valid[d] = vals, vlds else: values[d] = res - if return_valid: - return values, valid - else: - return values + return (values, valid) if return_valid else values + elif isinstance(device, str): - assert device in self.devices, f"Unknown device '{device}'" - res = self.devices[device].interpolate(times, return_valid=return_valid) - return res - else: - raise ValueError(f"Unsupported device type: {type(device)}") + if device not in self.devices: + raise KeyError(f"Unknown device '{device}'") + return self.devices[device].interpolate(times, return_valid=return_valid) + + raise ValueError(f"Unsupported device type: {type(device)}") def get_valid_range(self, device_name: str) -> tuple[float, float]: """Get the valid time range for a specific device. @@ -239,7 +235,7 @@ def get_valid_range(self, device_name: str) -> tuple[float, float]: Returns ------- tuple - A tuple ``(start_time, end_time)`` representing the valid + A tuple `(start_time, end_time)` representing the valid time interval in seconds. """ return tuple(self.devices[device_name].valid_interval) diff --git a/tests/create_experiment.py b/tests/create_experiment.py index 2a97c07..1d2413c 100644 --- a/tests/create_experiment.py +++ b/tests/create_experiment.py @@ -1,48 +1,30 @@ import shutil from contextlib import contextmanager -import numpy as np -import yaml +from .create_sequence_data import _generate_sequence_data @contextmanager -def make_sequence_device( - root, name, start, end, sampling_rate=10.0, n_signals=5, override_meta=None +def setup_test_experiment( + tmp_path, + n_devices=2, + devices_kwargs=None, + default_sampling_rate=1.0, ): - """Create a single sequence device folder under root.""" - device_root = root / name - try: - (device_root / "meta").mkdir(parents=True, exist_ok=True) - - n_samples = ( - int((end - start) * sampling_rate) + 1 - ) # +1 to include both start and end as sample points - timestamps = np.linspace(start, end, n_samples) - data = np.random.rand(n_samples, n_signals) - - np.save(device_root / "timestamps.npy", timestamps) - np.save(device_root / "data.npy", data) + devices_kwargs = devices_kwargs or [{}] * n_devices + default_params = {"sampling_rate": default_sampling_rate} - meta = { - "start_time": start, - "end_time": end, - "modality": "sequence", - "sampling_rate": sampling_rate, - "phase_shift_per_signal": False, - "is_mem_mapped": False, - "n_signals": n_signals, - "n_timestamps": n_samples, - "dtype": "float64", - } - if override_meta: - meta.update(override_meta) - with open(device_root / "meta.yml", "w") as f: - yaml.safe_dump(meta, f) - - yield device_root + devices_kwargs = [default_params | kwargs for kwargs in devices_kwargs] + try: + tmp_path.mkdir(parents=True, exist_ok=True) + for device_id, device_kwargs in enumerate(devices_kwargs): + device_path = tmp_path / f"device_{device_id}" + _generate_sequence_data(device_path, **device_kwargs) + yield tmp_path finally: - shutil.rmtree(device_root) + if tmp_path.exists(): + shutil.rmtree(tmp_path) def make_modality_config(*device_names, sampling_rates=None, offsets=None): diff --git a/tests/create_sequence_data.py b/tests/create_sequence_data.py index 0b70c8d..9e6ad88 100644 --- a/tests/create_sequence_data.py +++ b/tests/create_sequence_data.py @@ -10,66 +10,100 @@ SEQUENCE_ROOT = Path("tests/sequence_data") -@contextmanager -def create_sequence_data( +def _generate_sequence_data( + sequence_root, n_signals=10, shifts_per_signal=False, use_mem_mapped=False, + start_time=0.0, t_end=10.0, sampling_rate=10.0, contain_nans=False, ): - try: - SEQUENCE_ROOT.mkdir(parents=True, exist_ok=True) - (SEQUENCE_ROOT / "meta").mkdir(parents=True, exist_ok=True) - - meta = { - "start_time": 0, - "end_time": t_end, - "modality": "sequence", - "sampling_rate": sampling_rate, - "phase_shift_per_signal": shifts_per_signal, - "is_mem_mapped": use_mem_mapped, - "n_signals": n_signals, - } - - timestamps = np.linspace( - meta["start_time"], - meta["end_time"], - int((meta["end_time"] - meta["start_time"]) * meta["sampling_rate"]) + 1, + """Generates synthetic sequence data folders for testing interpolator logic.""" + + sequence_root = Path(sequence_root) + sequence_root.mkdir(parents=True, exist_ok=True) + (sequence_root / "meta").mkdir(parents=True, exist_ok=True) + + meta = { + "start_time": start_time, + "end_time": t_end, + "modality": "sequence", + "sampling_rate": sampling_rate, + "phase_shift_per_signal": shifts_per_signal, + "is_mem_mapped": use_mem_mapped, + "n_signals": n_signals, + } + + # Determine number of samples based on duration and rate + duration = meta["end_time"] - meta["start_time"] + n_samples = int(round(duration * meta["sampling_rate"])) + 1 + + timestamps = np.linspace(meta["start_time"], meta["end_time"], n_samples) + + data = np.random.rand(len(timestamps), n_signals) + + if contain_nans: + nan_indices = np.random.choice( + data.size, size=int(0.1 * data.size), replace=False ) - np.save(SEQUENCE_ROOT / "timestamps.npy", timestamps) - meta["n_timestamps"] = len(timestamps) + data.flat[nan_indices] = np.nan + # ensure each row has at least one non-NaN + if n_signals > 0: + row_all_nan = np.isnan(data).all(axis=1) + data[row_all_nan, 0] = 0.0 - data = np.random.rand(len(timestamps), n_signals) + if not use_mem_mapped: + np.save(sequence_root / "data.npy", data) + else: + filename = sequence_root / "data.mem" + fp = np.memmap(filename, dtype=data.dtype, mode="w+", shape=data.shape) + fp[:] = data[:] + fp.flush() + del fp - if contain_nans: - nan_indices = np.random.choice( - data.size, size=int(0.1 * data.size), replace=False - ) - data.flat[nan_indices] = np.nan + np.save(sequence_root / "timestamps.npy", timestamps) + meta["n_timestamps"] = len(timestamps) + meta["dtype"] = str(data.dtype) - if not use_mem_mapped: - np.save(SEQUENCE_ROOT / "data.npy", data) - else: - filename = SEQUENCE_ROOT / "data.mem" + # Handle per-signal phase shifts if required by the test case + shifts = None + if shifts_per_signal: + shifts = np.random.rand(n_signals) / meta["sampling_rate"] * 0.9 + np.save(sequence_root / "meta" / "phase_shifts.npy", shifts) - fp = np.memmap(filename, dtype=data.dtype, mode="w+", shape=data.shape) - fp[:] = data[:] - fp.flush() # Ensure data is written to disk - del fp - meta["dtype"] = str(data.dtype) + with open(sequence_root / "meta.yml", "w") as f: + yaml.safe_dump(meta, f) - if shifts_per_signal: - shifts = np.random.rand(n_signals) / meta["sampling_rate"] * 0.9 - np.save(SEQUENCE_ROOT / "meta" / "phase_shifts.npy", shifts) + return timestamps, data, shifts - with open(SEQUENCE_ROOT / "meta.yml", "w") as f: - yaml.safe_dump(meta, f) - yield timestamps, data, shifts if shifts_per_signal else None +@contextmanager +def create_sequence_data( + n_signals=10, + shifts_per_signal=False, + use_mem_mapped=False, + t_end=10.0, + sampling_rate=10.0, + contain_nans=False, + start_time=0.0, +): + """Context manager for temporary sequence data creation and cleanup.""" + try: + yield _generate_sequence_data( + sequence_root=SEQUENCE_ROOT, + n_signals=n_signals, + shifts_per_signal=shifts_per_signal, + use_mem_mapped=use_mem_mapped, + t_end=t_end, + sampling_rate=sampling_rate, + contain_nans=contain_nans, + start_time=start_time, + ) finally: - shutil.rmtree(SEQUENCE_ROOT) + if SEQUENCE_ROOT.exists(): + shutil.rmtree(SEQUENCE_ROOT) @contextmanager @@ -77,7 +111,9 @@ def sequence_data_and_interpolator(data_kwargs=None, interp_kwargs=None): data_kwargs = data_kwargs or {} interp_kwargs = interp_kwargs or {} with create_sequence_data(**data_kwargs) as (timestamps, data, shifts): + # Restore the helper expected by the rest of the test suite + with closing( - Interpolator.create("tests/sequence_data", **interp_kwargs) + Interpolator.create(str(SEQUENCE_ROOT), **interp_kwargs) ) as seq_interp: yield timestamps, data, shifts, seq_interp diff --git a/tests/test_experiment.py b/tests/test_experiment.py index 72d5f2a..c3e2628 100644 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -1,12 +1,21 @@ import logging -from contextlib import ExitStack +from unittest.mock import MagicMock import numpy as np import pytest +import yaml +from hypothesis import HealthCheck, given, settings +from hypothesis import strategies as st from experanto.experiment import Experiment +from experanto.interpolators import Interpolator -from .create_experiment import make_modality_config, make_sequence_device +from .create_experiment import ( + make_modality_config, + setup_test_experiment, +) + +# --- Test Data and Mocks --- DEVICE_TIME_RANGE_CASES = [ # Single device: start and end should match that device's range @@ -32,7 +41,7 @@ "large_time_stamps", ] -# Inverted range is intentionally separate from INVALID_META_CASES — +# Inverted range is intentionally separate from INVALID_META_CASES - # None/NaN/inf are caught per-device before being added to self.devices, # whereas start > end is only caught after all devices are loaded. INVALID_META_CASES = [ @@ -60,7 +69,200 @@ ] -# Test for union of device time ranges +class DummyInterpolator(Interpolator): + """Small concrete interpolator used for testing Experiment routing logic.""" + + def __init__(self): + self.start_time = 0.0 + self.end_time = 100.0 + self.valid_interval = (self.start_time, self.end_time) + self.interpolate = MagicMock(return_value=np.array([1, 2, 3])) + + +@pytest.fixture +def mock_interpolator(): + """Shared interpolator instance to isolate Experiment logic from interpolation math.""" + return DummyInterpolator() + + +# --- Tests --- + + +def test_experiment_initialization_and_device_loading(tmp_path, mock_interpolator): + """Verify that only devices defined in modality_config are initialized.""" + (tmp_path / "screen").mkdir() + (tmp_path / "eye_tracker").mkdir() + (tmp_path / "ignored_device").mkdir() + config = { + "screen": {"interpolation": mock_interpolator}, + "eye_tracker": {"interpolation": mock_interpolator}, + } + exp = Experiment(root_folder=str(tmp_path), modality_config=config) + assert ( + "screen" in exp.devices + ), f"Expected 'screen' in experiment devices but exp.devices = {exp.devices}" + assert ( + "eye_tracker" in exp.devices + ), f"Expected 'eye_tracker' in experiment devices but exp.devices = {exp.devices}" + assert ( + "ignored_device" not in exp.devices + ), f"Expected 'ignored_device' to be excluded from experiment devices but exp.devices = {exp.devices}" + assert set(exp.device_names) == { + "screen", + "eye_tracker", + }, f"Expected {{'screen', 'eye_tracker'}}, got {set(exp.device_names)}" + + +def test_experiment_interpolate_routing(tmp_path, mock_interpolator): + """Check if Experiment correctly delegates calls to underlying interpolators.""" + (tmp_path / "screen").mkdir() + config = {"screen": {"interpolation": mock_interpolator}} + exp = Experiment(root_folder=str(tmp_path), modality_config=config) + test_times = np.array([10.0, 20.0]) + res = exp.interpolate(test_times, device="screen") + mock_interpolator.interpolate.assert_called_once_with( + test_times, return_valid=False + ) + np.testing.assert_array_equal( + res, + np.array([1, 2, 3]), + err_msg=f"Interpolated result does not match mock output. Got {res}", + ) + res_dict = exp.interpolate(test_times, device=None) + assert isinstance( + res_dict, dict + ), f"Expected dict return for device=None, got {type(res_dict)}" + np.testing.assert_array_equal( + res_dict["screen"], + np.array([1, 2, 3]), + err_msg=f"Interpolated dict result does not match mock output. Got {res_dict['screen']}", + ) + + +@pytest.mark.parametrize( + "device_name, start_t, end_t", + [("device_0", 0.0, 10.0), ("device_1", 0.0, 20.0), ("device_2", 5.0, 15.0)], +) +def test_get_valid_range_all_devices(tmp_path, device_name, start_t, end_t): + """Integration test for valid_interval propagation from disk to object.""" + with setup_test_experiment( + tmp_path, + n_devices=3, + devices_kwargs=[ + {"t_end": 10.0}, + {"t_end": 20.0}, + {"start_time": 5.0, "t_end": 15.0}, + ], + ) as experiment_path: + config = make_modality_config("device_0", "device_1", "device_2") + config["device_2"] = { + "sampling_rate": 1.0, + "chunk_size": 40, + "interpolation": {"interpolation_mode": "nearest_neighbor"}, + } + experiment = Experiment( + root_folder=str(experiment_path), modality_config=config + ) + valid_range = experiment.get_valid_range(device_name) + assert valid_range == ( + start_t, + end_t, + ), f"Expected valid range {(start_t, end_t)} for {device_name}, got {valid_range}" + + +def test_get_valid_range_raises_for_invalid_device(tmp_path): + with setup_test_experiment(tmp_path) as experiment_path: + experiment = Experiment( + root_folder=str(experiment_path), + modality_config=make_modality_config("device_0", "device_1"), + ) + with pytest.raises(KeyError): + experiment.get_valid_range("device_does_not_exist") + + +def test_experiment_with_non_zero_start_time(tmp_path): + """Test boundary conditions for data not starting at t=0.""" + start_offset, duration = 1.5, 10.0 + with setup_test_experiment( + tmp_path, + n_devices=1, + devices_kwargs=[{"t_end": start_offset + duration, "start_time": start_offset}], + ) as experiment_path: + experiment = Experiment( + root_folder=str(experiment_path), + modality_config=make_modality_config("device_0"), + ) + res = experiment.interpolate(np.array([start_offset + 1.0]), device="device_0") + assert res is not None + with pytest.warns(UserWarning, match="no valid times queried"): + experiment.interpolate(np.array([start_offset - 1.0]), device="device_0") + + +@given( + start_offset=st.floats(min_value=0.0, max_value=100.0), + sampling_rate=st.floats(min_value=0.1, max_value=100.0), + duration=st.floats(min_value=0.0, max_value=100.0), +) +@settings(deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_experiment_numeric_precision_offset( + tmp_path, start_offset, sampling_rate, duration +): + """Stress test using non-integer rates and offsets to catch float drift.""" + with setup_test_experiment( + tmp_path, + n_devices=1, + devices_kwargs=[ + { + "start_time": start_offset, + "t_end": start_offset + duration, + "sampling_rate": sampling_rate, + } + ], + ) as experiment_path: + experiment = Experiment( + root_folder=str(experiment_path), + modality_config=make_modality_config("device_0"), + ) + valid_range = experiment.get_valid_range("device_0") + + assert valid_range[0] == pytest.approx( + start_offset + ), f"Expected valid_range[0] to be approx {start_offset}, got {valid_range[0]}" + + assert valid_range[1] == pytest.approx( + start_offset + duration + ), f"Expected valid_range[1] to be approx {start_offset + duration}, got {valid_range[1]}" + + res = experiment.interpolate(np.array([start_offset]), device="device_0") + assert res is not None + + +@pytest.mark.parametrize("return_valid", [False, True]) +@pytest.mark.parametrize("device", [None, "device_0"]) +def test_experiment_multi_device_interpolation(tmp_path, return_valid, device): + """Check data consistency when interpolating across multiple modalities.""" + with setup_test_experiment(tmp_path, n_devices=2) as experiment_path: + exp = Experiment( + root_folder=str(experiment_path), + modality_config=make_modality_config("device_0", "device_1"), + ) + times = np.array([1.0, 2.0]) + results = exp.interpolate(times, device=device, return_valid=return_valid) + if return_valid: + data, valid_idx = results + assert ( + data["device_0"].shape == (2, 10) + if device is None + else data.shape == (2, 10) + ) + else: + assert ( + results["device_0"].shape == (2, 10) + if device is None + else results.shape == (2, 10) + ) + + @pytest.mark.parametrize("n_signals", [5, 20]) @pytest.mark.parametrize( "device_ranges, expected_start, expected_end", @@ -70,30 +272,27 @@ def test_experiment_start_end_time_reflects_union( tmp_path, device_ranges, expected_start, expected_end, n_signals ): - """ - Experiment.start_time and end_time should reflect the union of all - device time ranges — earliest start and latest end across all devices. - """ - device_names = [f"device_{i}" for i in range(len(device_ranges))] - - with ExitStack() as stack: - for name, (start, end) in zip(device_names, device_ranges, strict=True): - stack.enter_context( - make_sequence_device( - tmp_path, - name, - start=start, - end=end, - n_signals=n_signals, - sampling_rate=float(np.random.randint(5, 30)), - ) - ) + """Experiment.start_time and end_time should reflect the union of all device time ranges.""" + devices_kwargs = [ + { + "start_time": start, + "t_end": end, + "n_signals": n_signals, + "sampling_rate": float(np.random.randint(5, 30)), + } + for start, end in device_ranges + ] + + with setup_test_experiment( + tmp_path, n_devices=len(device_ranges), devices_kwargs=devices_kwargs + ) as experiment_path: + # Dynamically build the config using make_modality_config + device_names = [f"device_{i}" for i in range(len(device_ranges))] + offsets = [float(np.random.rand()) for _ in device_ranges] + config = make_modality_config(*device_names, offsets=offsets) experiment = Experiment( - root_folder=tmp_path, - modality_config=make_modality_config( - *device_names, offsets=[float(np.random.rand()) for _ in device_names] - ), + root_folder=str(experiment_path), modality_config=config ) assert experiment.start_time == ( @@ -104,49 +303,25 @@ def test_experiment_start_end_time_reflects_union( ), f"Expected end_time={expected_end}, got {experiment.end_time}" -# Safety check @pytest.mark.parametrize("override_meta", INVALID_META_CASES, ids=INVALID_META_IDS) def test_experiment_invalid_metadata(tmp_path, override_meta): - """ - Experiment should raise an error when initialized with invalid metadata. - Covers cases where start_time or end_time is None, NaN, or infinite. - """ - with make_sequence_device( - tmp_path, - "device_0", - start=0.0, - end=10.0, - override_meta=override_meta, - ): - with pytest.raises( - ValueError, match="Experiment time range could not be determined" - ): - Experiment( - root_folder=tmp_path, - modality_config=make_modality_config("device_0"), - ) + with setup_test_experiment( + tmp_path, n_devices=1, devices_kwargs=[{"start_time": 0.0, "t_end": 10.0}] + ) as experiment_path: + # Explicitly corrupt the generated metadata file + meta_file = experiment_path / "device_0" / "meta.yml" + with open(meta_file) as f: + meta = yaml.safe_load(f) + meta.update(override_meta) + with open(meta_file, "w") as f: + yaml.safe_dump(meta, f) + config = make_modality_config("device_0") -def test_experiment_inverted_time_range_raises(tmp_path): - """ - Experiment should raise ValueError when start_time > end_time. - This is a separate guard from invalid metadata (None/NaN/inf) because it - only becomes apparent after all devices are loaded and the overall time range is computed. - """ - with make_sequence_device( - tmp_path, - "device_0", - start=0.0, - end=10.0, - override_meta={"start_time": 5.0, "end_time": 2.0}, - ): with pytest.raises( ValueError, match="Experiment time range could not be determined" ): - Experiment( - root_folder=tmp_path, - modality_config=make_modality_config("device_0"), - ) + Experiment(root_folder=str(experiment_path), modality_config=config) @pytest.mark.parametrize("override_meta", INVALID_META_CASES, ids=INVALID_META_IDS) @@ -161,39 +336,45 @@ def test_experiment_skips_invalid_devices(tmp_path, override_meta, caplog): duration_val = np.random.lognormal(mean=0.0, sigma=1.0) end_val = start_val + duration_val + # Generate random values for the invalid device as well start_nonval = np.random.lognormal(mean=0.0, sigma=1.0) duration_nonval = np.random.lognormal(mean=0.0, sigma=1.0) end_nonval = start_nonval + duration_nonval - with ExitStack() as stack: - # Valid device with proper metadata - stack.enter_context( - make_sequence_device( - tmp_path, - "valid_device", - start=start_val, - end=end_val, - ) - ) - # Invalid device with missing start_time and end_time - stack.enter_context( - make_sequence_device( - tmp_path, - "invalid_device", - start=start_nonval, - end=end_nonval, - override_meta=override_meta, - ) - ) + devices_kwargs = [ + {"start_time": start_val, "t_end": end_val}, # valid device + {"start_time": start_nonval, "t_end": end_nonval}, # invalid device + ] + + with setup_test_experiment( + tmp_path, n_devices=2, devices_kwargs=devices_kwargs + ) as experiment_path: + # Rename the folders to match what the old test expected + (experiment_path / "device_0").rename(experiment_path / "valid_device") + (experiment_path / "device_1").rename(experiment_path / "invalid_device") + + # Explicitly corrupt the metadata file for the invalid device + meta_file = experiment_path / "invalid_device" / "meta.yml" + with open(meta_file) as f: + meta = yaml.safe_load(f) + meta.update(override_meta) + with open(meta_file, "w") as f: + yaml.safe_dump(meta, f) + + config = make_modality_config("valid_device", "invalid_device") with caplog.at_level(logging.WARNING, logger="experanto.experiment"): experiment = Experiment( - root_folder=tmp_path, - modality_config=make_modality_config("valid_device", "invalid_device"), + root_folder=str(experiment_path), modality_config=config ) - assert "valid_device" in experiment.devices - assert "invalid_device" not in experiment.devices + assert ( + "valid_device" in experiment.devices + ), f"Expected 'valid_device' to be in experiment.devices but experiment.devices = {experiment.devices}" + + assert ( + "invalid_device" not in experiment.devices + ), f"Expected 'invalid_device' to be skipped in experiment.devices, instead experiment.devices = {experiment.devices}" assert experiment.start_time == ( start_val @@ -201,6 +382,7 @@ def test_experiment_skips_invalid_devices(tmp_path, override_meta, caplog): assert experiment.end_time == ( end_val ), f"Expected end_time={end_val}, got {experiment.end_time}" + assert any( "invalid_device" in message for message in caplog.messages - ), "Expected warning about invalid_device was skipped" + ), f"Expected warning about invalid_device was skipped. caplog.messages = {caplog.messages}"