@@ -164,16 +164,28 @@ def create(root_folder: str, cache_data: bool = False, **kwargs) -> Interpolator
164164 if modality == "sequence" :
165165 if meta_data .get ("phase_shift_per_signal" , False ):
166166 return PhaseShiftedSequenceInterpolator (
167- root_folder , cache_data , ** kwargs
167+ root_folder , cache_data = cache_data , ** kwargs
168168 )
169169 else :
170- return SequenceInterpolator (root_folder , cache_data , ** kwargs )
170+ return SequenceInterpolator (
171+ root_folder , cache_data = cache_data , ** kwargs
172+ )
171173 elif modality == "screen" :
172- return ScreenInterpolator (root_folder , cache_data , ** kwargs )
174+ use_stimuli_names = kwargs .pop (
175+ "use_stimuli_names" , meta_data .get ("use_stimuli_names" , False )
176+ )
177+ return ScreenInterpolator (
178+ root_folder ,
179+ cache_data = cache_data ,
180+ use_stimuli_names = use_stimuli_names ,
181+ ** kwargs ,
182+ )
173183 elif modality == "time_interval" :
174- return TimeIntervalInterpolator (root_folder , cache_data , ** kwargs )
184+ return TimeIntervalInterpolator (
185+ root_folder , cache_data = cache_data , ** kwargs
186+ )
175187 elif modality == "spikes" :
176- return SpikeInterpolator (root_folder , cache_data , ** kwargs )
188+ return SpikeInterpolator (root_folder , cache_data = cache_data , ** kwargs )
177189 else :
178190 raise ValueError (
179191 f"There is no interpolator for { modality } . Please use 'sequence', 'screen', 'time_interval' as modality or provide a custom interpolator."
@@ -601,6 +613,8 @@ class ScreenInterpolator(Interpolator):
601613 native image size from metadata.
602614 normalize : bool, default=False
603615 If True, normalizes frames using stored mean/std statistics.
616+ use_stimuli_names : bool, default=False
617+ If True, uses ``stimulus_name`` from metadata to locate data files instead of trial keys.
604618 **kwargs
605619 Additional keyword arguments (ignored).
606620
@@ -625,6 +639,7 @@ def __init__(
625639 rescale : bool = False ,
626640 rescale_size : tuple [int , int ] | None = None ,
627641 normalize : bool = False ,
642+ use_stimuli_names : bool = False ,
628643 ** kwargs ,
629644 ) -> None :
630645 super ().__init__ (root_folder )
@@ -634,6 +649,7 @@ def __init__(
634649 self .valid_interval = TimeInterval (self .start_time , self .end_time )
635650 self .rescale = rescale
636651 self .cache_trials = cache_data # Store the cache preference
652+ self .use_stimuli_names = use_stimuli_names
637653 self ._parse_trials ()
638654
639655 # create mapping from image index to file index
@@ -718,8 +734,14 @@ def _parse_trials(self) -> None:
718734 metadatas , keys = self .read_combined_meta ()
719735
720736 for key , metadata in zip (keys , metadatas , strict = True ):
721- data_file_name = self .root_folder / "data" / f"{ key } .npy"
722- # Pass the cache_trials parameter when creating trials
737+ if self .use_stimuli_names :
738+ stimulus_name = metadata .get ("stimulus_name" )
739+ assert (
740+ stimulus_name is not None
741+ ), f"stimulus_name is required in metadata when use_stimuli_names is True, but not found for key: { key } "
742+ data_file_name = self .root_folder / "data" / f"{ stimulus_name } .npy"
743+ else :
744+ data_file_name = self .root_folder / "data" / f"{ key } .npy"
723745 self .trials .append (
724746 ScreenTrial .create (
725747 data_file_name , metadata , cache_data = self .cache_trials
0 commit comments