Skip to content

Commit 1bf6f1d

Browse files
authored
Merge branch 'main' into neuron-selection
2 parents 4613cf7 + d851484 commit 1bf6f1d

1 file changed

Lines changed: 29 additions & 7 deletions

File tree

experanto/interpolators.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,16 +164,28 @@ def create(root_folder: str, cache_data: bool = False, **kwargs) -> Interpolator
164164
if modality == "sequence":
165165
if meta_data.get("phase_shift_per_signal", False):
166166
return PhaseShiftedSequenceInterpolator(
167-
root_folder, cache_data, **kwargs
167+
root_folder, cache_data=cache_data, **kwargs
168168
)
169169
else:
170-
return SequenceInterpolator(root_folder, cache_data, **kwargs)
170+
return SequenceInterpolator(
171+
root_folder, cache_data=cache_data, **kwargs
172+
)
171173
elif modality == "screen":
172-
return ScreenInterpolator(root_folder, cache_data, **kwargs)
174+
use_stimuli_names = kwargs.pop(
175+
"use_stimuli_names", meta_data.get("use_stimuli_names", False)
176+
)
177+
return ScreenInterpolator(
178+
root_folder,
179+
cache_data=cache_data,
180+
use_stimuli_names=use_stimuli_names,
181+
**kwargs,
182+
)
173183
elif modality == "time_interval":
174-
return TimeIntervalInterpolator(root_folder, cache_data, **kwargs)
184+
return TimeIntervalInterpolator(
185+
root_folder, cache_data=cache_data, **kwargs
186+
)
175187
elif modality == "spikes":
176-
return SpikeInterpolator(root_folder, cache_data, **kwargs)
188+
return SpikeInterpolator(root_folder, cache_data=cache_data, **kwargs)
177189
else:
178190
raise ValueError(
179191
f"There is no interpolator for {modality}. Please use 'sequence', 'screen', 'time_interval' as modality or provide a custom interpolator."
@@ -601,6 +613,8 @@ class ScreenInterpolator(Interpolator):
601613
native image size from metadata.
602614
normalize : bool, default=False
603615
If True, normalizes frames using stored mean/std statistics.
616+
use_stimuli_names : bool, default=False
617+
If True, uses ``stimulus_name`` from metadata to locate data files instead of trial keys.
604618
**kwargs
605619
Additional keyword arguments (ignored).
606620
@@ -625,6 +639,7 @@ def __init__(
625639
rescale: bool = False,
626640
rescale_size: tuple[int, int] | None = None,
627641
normalize: bool = False,
642+
use_stimuli_names: bool = False,
628643
**kwargs,
629644
) -> None:
630645
super().__init__(root_folder)
@@ -634,6 +649,7 @@ def __init__(
634649
self.valid_interval = TimeInterval(self.start_time, self.end_time)
635650
self.rescale = rescale
636651
self.cache_trials = cache_data # Store the cache preference
652+
self.use_stimuli_names = use_stimuli_names
637653
self._parse_trials()
638654

639655
# create mapping from image index to file index
@@ -718,8 +734,14 @@ def _parse_trials(self) -> None:
718734
metadatas, keys = self.read_combined_meta()
719735

720736
for key, metadata in zip(keys, metadatas, strict=True):
721-
data_file_name = self.root_folder / "data" / f"{key}.npy"
722-
# Pass the cache_trials parameter when creating trials
737+
if self.use_stimuli_names:
738+
stimulus_name = metadata.get("stimulus_name")
739+
assert (
740+
stimulus_name is not None
741+
), f"stimulus_name is required in metadata when use_stimuli_names is True, but not found for key: {key}"
742+
data_file_name = self.root_folder / "data" / f"{stimulus_name}.npy"
743+
else:
744+
data_file_name = self.root_folder / "data" / f"{key}.npy"
723745
self.trials.append(
724746
ScreenTrial.create(
725747
data_file_name, metadata, cache_data=self.cache_trials

0 commit comments

Comments
 (0)