Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
66fe819
test: add comprehensive coverage for Experiment class and data genera…
BitForge95 Mar 25, 2026
4314ecd
fix: restore project files accidentally overwritten during rebase
BitForge95 Mar 25, 2026
1c43c9f
fix: restore .readthedocs.yml
BitForge95 Mar 25, 2026
8ef966a
style: auto-format with black and isort
github-actions[bot] Mar 25, 2026
6790470
refactor: remove redundant helpers and use setup_test_experiment
BitForge95 Mar 26, 2026
37d39a4
refactor: use make_modality_config and exact time
BitForge95 Mar 30, 2026
a43feed
style: auto-format with black and isort
github-actions[bot] Mar 30, 2026
3779133
fix: add missing closing import
BitForge95 Apr 11, 2026
82cc45f
style: auto-format with black and isort
github-actions[bot] Apr 11, 2026
d593c6b
test: restore helpful comments, assert messages, and fix zip compatib…
BitForge95 Apr 13, 2026
09fc26d
style: auto-format with black and isort
github-actions[bot] Apr 13, 2026
a3e0db5
chore: restore strict=True as per maintainer preference for Python 3.12+
BitForge95 Apr 13, 2026
0e4cbe9
refactor: remove unused return values from sequence generator
BitForge95 Apr 14, 2026
cd0ef94
test: add missing error messages for valid_range asserts
BitForge95 Apr 14, 2026
22d6bd7
style: auto-format with black and isort
github-actions[bot] Apr 14, 2026
a84e316
test: restore missing comments in DEVICE_TIME_RANGE_CASES
BitForge95 Apr 14, 2026
aeef5b8
fix: restore sequence return values and correct invalid device naming
BitForge95 Apr 14, 2026
84d7d9c
fix: enforce non-NaN rows in mock data and restore return values
BitForge95 Apr 14, 2026
a55183f
fix: move non-NaN enforcement inside contain_nans block
BitForge95 Apr 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 13 additions & 17 deletions experanto/experiment.py
Comment thread
pollytur marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
def _load_devices(self) -> None:
# Populate devices by going through subfolders
# Assumption: blocks are sorted by start time
device_folders = [d for d in self.root_folder.iterdir() if (d.is_dir())]
device_folders = [d for d in self.root_folder.iterdir() if d.is_dir()]

for d in device_folders:
if d.name not in self.modality_config:
Expand All @@ -95,14 +95,14 @@ def _load_devices(self) -> None:
dev = instantiate(
interp_conf, root_folder=d, cache_data=self.cache_data
)

# Check if instantiated object is proper Interpolator
if not isinstance(dev, Interpolator):
raise ValueError(
"Please provide an Interpolator which inherits from experantos Interpolator class."
"Instantiated object must inherit from Interpolator class."
)

elif isinstance(interp_conf, Interpolator):
# Already instantiated Interpolator
dev = interp_conf

else:
Expand Down Expand Up @@ -207,26 +207,22 @@ def interpolate(
dict_keys(['screen', 'responses', 'eye_tracker'])
"""
if device is None:
values = {}
valid = {}
values, valid = {}, {}
for d, interp in self.devices.items():
res = interp.interpolate(times, return_valid=return_valid)
if return_valid:
vals, vlds = res
values[d] = vals
valid[d] = vlds
values[d], valid[d] = vals, vlds
else:
values[d] = res
if return_valid:
return values, valid
else:
return values
return (values, valid) if return_valid else values

elif isinstance(device, str):
assert device in self.devices, f"Unknown device '{device}'"
res = self.devices[device].interpolate(times, return_valid=return_valid)
return res
else:
raise ValueError(f"Unsupported device type: {type(device)}")
if device not in self.devices:
raise KeyError(f"Unknown device '{device}'")
return self.devices[device].interpolate(times, return_valid=return_valid)

raise ValueError(f"Unsupported device type: {type(device)}")

def get_valid_range(self, device_name: str) -> tuple[float, float]:
"""Get the valid time range for a specific device.
Expand All @@ -239,7 +235,7 @@ def get_valid_range(self, device_name: str) -> tuple[float, float]:
Returns
-------
tuple
A tuple ``(start_time, end_time)`` representing the valid
A tuple `(start_time, end_time)` representing the valid
time interval in seconds.
"""
return tuple(self.devices[device_name].valid_interval)
52 changes: 17 additions & 35 deletions tests/create_experiment.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,30 @@
import shutil
from contextlib import contextmanager

import numpy as np
import yaml
from .create_sequence_data import _generate_sequence_data


@contextmanager
def make_sequence_device(
root, name, start, end, sampling_rate=10.0, n_signals=5, override_meta=None
def setup_test_experiment(
Comment thread
pollytur marked this conversation as resolved.
tmp_path,
n_devices=2,
devices_kwargs=None,
default_sampling_rate=1.0,
):
Comment thread
pollytur marked this conversation as resolved.
"""Create a single sequence device folder under root."""
device_root = root / name
try:
(device_root / "meta").mkdir(parents=True, exist_ok=True)

n_samples = (
int((end - start) * sampling_rate) + 1
) # +1 to include both start and end as sample points
timestamps = np.linspace(start, end, n_samples)
data = np.random.rand(n_samples, n_signals)

np.save(device_root / "timestamps.npy", timestamps)
np.save(device_root / "data.npy", data)
devices_kwargs = devices_kwargs or [{}] * n_devices
default_params = {"sampling_rate": default_sampling_rate}

meta = {
"start_time": start,
"end_time": end,
"modality": "sequence",
"sampling_rate": sampling_rate,
"phase_shift_per_signal": False,
"is_mem_mapped": False,
"n_signals": n_signals,
"n_timestamps": n_samples,
"dtype": "float64",
}
if override_meta:
meta.update(override_meta)
with open(device_root / "meta.yml", "w") as f:
yaml.safe_dump(meta, f)

yield device_root
devices_kwargs = [default_params | kwargs for kwargs in devices_kwargs]

try:
tmp_path.mkdir(parents=True, exist_ok=True)
for device_id, device_kwargs in enumerate(devices_kwargs):
device_path = tmp_path / f"device_{device_id}"
_generate_sequence_data(device_path, **device_kwargs)
yield tmp_path
finally:
shutil.rmtree(device_root)
if tmp_path.exists():
shutil.rmtree(tmp_path)


def make_modality_config(*device_names, sampling_rates=None, offsets=None):
Expand Down
126 changes: 81 additions & 45 deletions tests/create_sequence_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,74 +10,110 @@
SEQUENCE_ROOT = Path("tests/sequence_data")


@contextmanager
def create_sequence_data(
def _generate_sequence_data(
sequence_root,
n_signals=10,
shifts_per_signal=False,
use_mem_mapped=False,
start_time=0.0,
t_end=10.0,
sampling_rate=10.0,
contain_nans=False,
):
try:
SEQUENCE_ROOT.mkdir(parents=True, exist_ok=True)
(SEQUENCE_ROOT / "meta").mkdir(parents=True, exist_ok=True)

meta = {
"start_time": 0,
"end_time": t_end,
"modality": "sequence",
"sampling_rate": sampling_rate,
"phase_shift_per_signal": shifts_per_signal,
"is_mem_mapped": use_mem_mapped,
"n_signals": n_signals,
}

timestamps = np.linspace(
meta["start_time"],
meta["end_time"],
int((meta["end_time"] - meta["start_time"]) * meta["sampling_rate"]) + 1,
"""Generates synthetic sequence data folders for testing interpolator logic."""

sequence_root = Path(sequence_root)
Comment thread
BitForge95 marked this conversation as resolved.
sequence_root.mkdir(parents=True, exist_ok=True)
(sequence_root / "meta").mkdir(parents=True, exist_ok=True)

meta = {
"start_time": start_time,
"end_time": t_end,
"modality": "sequence",
"sampling_rate": sampling_rate,
"phase_shift_per_signal": shifts_per_signal,
"is_mem_mapped": use_mem_mapped,
"n_signals": n_signals,
}

# Determine number of samples based on duration and rate
duration = meta["end_time"] - meta["start_time"]
n_samples = int(round(duration * meta["sampling_rate"])) + 1

timestamps = np.linspace(meta["start_time"], meta["end_time"], n_samples)

data = np.random.rand(len(timestamps), n_signals)

if contain_nans:
nan_indices = np.random.choice(
data.size, size=int(0.1 * data.size), replace=False
)
np.save(SEQUENCE_ROOT / "timestamps.npy", timestamps)
meta["n_timestamps"] = len(timestamps)
data.flat[nan_indices] = np.nan
Comment thread
pollytur marked this conversation as resolved.
# ensure each row has at least one non-NaN
if n_signals > 0:
row_all_nan = np.isnan(data).all(axis=1)
data[row_all_nan, 0] = 0.0

data = np.random.rand(len(timestamps), n_signals)
if not use_mem_mapped:
np.save(sequence_root / "data.npy", data)
else:
filename = sequence_root / "data.mem"
fp = np.memmap(filename, dtype=data.dtype, mode="w+", shape=data.shape)
fp[:] = data[:]
fp.flush()
del fp

if contain_nans:
nan_indices = np.random.choice(
data.size, size=int(0.1 * data.size), replace=False
)
data.flat[nan_indices] = np.nan
np.save(sequence_root / "timestamps.npy", timestamps)
meta["n_timestamps"] = len(timestamps)
meta["dtype"] = str(data.dtype)

if not use_mem_mapped:
np.save(SEQUENCE_ROOT / "data.npy", data)
else:
filename = SEQUENCE_ROOT / "data.mem"
# Handle per-signal phase shifts if required by the test case
shifts = None
if shifts_per_signal:
shifts = np.random.rand(n_signals) / meta["sampling_rate"] * 0.9
np.save(sequence_root / "meta" / "phase_shifts.npy", shifts)

fp = np.memmap(filename, dtype=data.dtype, mode="w+", shape=data.shape)
fp[:] = data[:]
fp.flush() # Ensure data is written to disk
del fp
meta["dtype"] = str(data.dtype)
with open(sequence_root / "meta.yml", "w") as f:
yaml.safe_dump(meta, f)

if shifts_per_signal:
shifts = np.random.rand(n_signals) / meta["sampling_rate"] * 0.9
np.save(SEQUENCE_ROOT / "meta" / "phase_shifts.npy", shifts)
return timestamps, data, shifts

with open(SEQUENCE_ROOT / "meta.yml", "w") as f:
yaml.safe_dump(meta, f)

yield timestamps, data, shifts if shifts_per_signal else None
@contextmanager
def create_sequence_data(
n_signals=10,
shifts_per_signal=False,
use_mem_mapped=False,
t_end=10.0,
sampling_rate=10.0,
contain_nans=False,
start_time=0.0,
):
Comment thread
BitForge95 marked this conversation as resolved.
"""Context manager for temporary sequence data creation and cleanup."""
Comment thread
BitForge95 marked this conversation as resolved.
try:
yield _generate_sequence_data(
sequence_root=SEQUENCE_ROOT,
n_signals=n_signals,
shifts_per_signal=shifts_per_signal,
use_mem_mapped=use_mem_mapped,
t_end=t_end,
sampling_rate=sampling_rate,
contain_nans=contain_nans,
start_time=start_time,
)
finally:
shutil.rmtree(SEQUENCE_ROOT)
if SEQUENCE_ROOT.exists():
shutil.rmtree(SEQUENCE_ROOT)


@contextmanager
def sequence_data_and_interpolator(data_kwargs=None, interp_kwargs=None):
data_kwargs = data_kwargs or {}
interp_kwargs = interp_kwargs or {}
with create_sequence_data(**data_kwargs) as (timestamps, data, shifts):
# Restore the helper expected by the rest of the test suite

with closing(
Interpolator.create("tests/sequence_data", **interp_kwargs)
Interpolator.create(str(SEQUENCE_ROOT), **interp_kwargs)
) as seq_interp:
yield timestamps, data, shifts, seq_interp
Loading
Loading