Add neuron selection support to SequenceInterpolator and SpikeInterpolator #126
Vrittigyl wants to merge 2 commits into sensorium-competition:main
Conversation
Pull request overview
Adds neuron-subset selection to interpolators so callers can load/process only specific neurons by biological IDs (meta/unit_ids.npy) or by direct column/neuron indexes, reducing memory usage for large datasets (Issue #124).
Changes:
- Extend SequenceInterpolator to accept neuron_ids/indexes and filter loaded data and normalization stats accordingly.
- Extend PhaseShiftedSequenceInterpolator to filter phase_shifts consistently with neuron selection.
- Extend SpikeInterpolator to accept neuron_ids/indexes and rebuild the spikes/indices arrays for the selected neurons.
```diff
@@ -230,20 +290,34 @@ def __init__(
         else:
             self._data = np.load(self.root_folder / "data.npy")
```
IMO, if the data gets cached, we should only cache the indexed neurons. The downside is that if someone later alters self.indexes on the instance manually, it would still rely on the old data. But I think I still prefer this. What do you think @pollytur?
If changing indexes later is a real use case, it could be handled by a function that reloads the data appropriately, but I am not sure we need that right now.
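If that use case ever materializes, a minimal sketch of the idea (hypothetical class and names, not code from this PR) could route index changes through a setter that reloads from disk and re-derives the cached subset, so the cache never goes stale:

```python
import os
import tempfile

import numpy as np


class CachedSelectionSketch:
    """Hypothetical sketch: changing indexes goes through set_indexes(),
    which reloads the data and caches only the selected neurons."""

    def __init__(self, data_path, indexes=None):
        self._data_path = data_path
        self.set_indexes(indexes)

    def set_indexes(self, indexes):
        self._indexes = None if indexes is None else list(indexes)
        data = np.load(self._data_path)  # reload so the cache cannot go stale
        self._data = data if self._indexes is None else data[:, self._indexes]


# demo with a throwaway file
path = os.path.join(tempfile.mkdtemp(), "data.npy")
np.save(path, np.arange(12, dtype=np.float32).reshape(3, 4))
sketch = CachedSelectionSketch(path, indexes=[0, 2])
```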
I've changed the logic so caching happens after applying neuron_indices, so only the selected subset is loaded into memory. This should keep memory usage efficient for large data.
I can revert the change if needed.
force-pushed from c070196 to ff9f0b3
```python
        else:
            self._data = np.load(self.root_folder / "data.npy")

        # Apply indexing BEFORE caching
```
Hi @pollytur,
as far as I know, indexing a memmap with lists or arrays loads a copy into RAM (see https://stackoverflow.com/questions/18614927/how-to-slice-memmap-efficiently or https://stackoverflow.com/questions/78426050/how-to-index-a-numpy-memmap-without-creating-an-in-memory-copy). Only regular contiguous slicing creates a view, and I don't think neuron ids will always be contiguous.
Have you thought about this? We might need to force caching to True when neurons are indexed, or find a workaround (I haven't investigated it yet).
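The behavior described above is easy to demonstrate with a small self-contained check (the temp file and variable names are illustrative): a contiguous slice of a memmap stays file-backed, while list/array indexing materializes an in-memory copy.

```python
import os
import tempfile

import numpy as np

# Write a small array to disk and reopen it memory-mapped.
path = os.path.join(tempfile.mkdtemp(), "data.npy")
np.save(path, np.arange(12, dtype=np.float32).reshape(4, 3))
mm = np.load(path, mmap_mode="r")

view = mm[1:3]        # contiguous slice: still a memmap, no copy
copy = mm[:, [0, 2]]  # list indexing: materializes a plain ndarray in RAM
```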
Great catch, no, I have not thought about it, to be honest.
We probably want to investigate the workaround. Reordering the neurons as needed and saving them as a temporary memmap file is the first thought, but that is insanely memory-inefficient, since the neuronal responses are also the heaviest part of the dataset from a memory perspective...
Or maybe, instead of directly indexing, we could iteratively fetch only the required columns (or fetch in chunks) and optionally cache them.
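One way to sketch that idea (an illustrative helper, not the PR's code): walk the memmap in row chunks and copy only the wanted columns, so peak extra memory is one chunk plus the output array.

```python
import numpy as np


def fetch_selected_columns(arr, col_indices, chunk_rows=4096):
    """Gather only the selected columns from `arr` (e.g. a np.memmap),
    reading row chunks so the full array is never materialized at once."""
    col_indices = np.asarray(col_indices)
    out = np.empty((arr.shape[0], col_indices.size), dtype=arr.dtype)
    for start in range(0, arr.shape[0], chunk_rows):
        stop = min(start + chunk_rows, arr.shape[0])
        # The contiguous row slice stays a view; the column gather only
        # copies chunk_rows * len(col_indices) elements at a time.
        out[start:stop] = arr[start:stop][:, col_indices]
    return out
```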
@Vrittigyl could you please provide an implementation of what you proposed and show that it still caches the data if cache_data == True and doesn't cache it otherwise, also when we select only certain neurons? Thanks!
I've implemented this and covered it with test functions.
When cache_data=False, _data stays a memmap and only the required neuron columns are accessed during interpolation. When cache_data=True, only the selected neurons are loaded into memory.
I verified this by checking data types and shapes and comparing outputs.
```python
    def _resolve_indices(self, neuron_ids, neuron_indices):
        if neuron_ids is None and neuron_indices is None:
            return None

        if neuron_ids is not None:
            unit_ids = np.load(self.root_folder / "meta/unit_ids.npy")
            ids_to_indexes = []

            for nid in neuron_ids:
                match = np.where(unit_ids == nid)[0]
                if len(match) == 0:
                    raise ValueError(f"Neuron id {nid} not found")
                ids_to_indexes.append(int(match[0]))

            if neuron_indices is None:
                return ids_to_indexes

            if set(ids_to_indexes) != set(neuron_indices):
                raise ValueError(
                    "neuron_ids and neuron_indices refer to different neurons"
                )

            warnings.warn(
                "Both neuron_ids and neuron_indices provided; using neuron_indices",
                stacklevel=2,
            )

        return self._validate_indices(neuron_indices)

    def _validate_indices(self, neuron_indices):
        try:
            indexes_seq = list(neuron_indices)
        except TypeError as exc:
            raise TypeError("neuron_indices must be iterable") from exc

        if not all(isinstance(i, (int, np.integer)) for i in indexes_seq):
            raise TypeError("neuron_indices must contain integers")

        if indexes_seq:
            if min(indexes_seq) < 0 or max(indexes_seq) >= self.n_signals:
                raise ValueError("neuron_indices out of bounds")

        if len(set(indexes_seq)) != len(indexes_seq):
            raise ValueError("neuron_indices contain duplicates")

        return indexes_seq
```
Sorry, but I don't like putting these functions into the Interpolator abstract class, as they only make sense for neuron-related interpolators. Can we move them outside of any class and just pass all needed params into the function (since we can't use self anymore)? I would then load unit_ids before calling _resolve_indices and pass unit_ids directly instead of self.root_folder.
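A minimal sketch of that refactor, using the module-level signature resolve_neuron_indices(neuron_ids, neuron_indices, unit_ids, n_signals) that appears later in the PR; the body below is an illustrative reconstruction, not the PR's exact code:

```python
import numpy as np


def resolve_neuron_indices(neuron_ids, neuron_indices, unit_ids, n_signals):
    """Module-level variant: unit_ids and n_signals are passed in by the
    caller instead of being read from self."""
    if neuron_ids is None and neuron_indices is None:
        return None
    if neuron_ids is not None:
        resolved = []
        for nid in neuron_ids:
            match = np.where(unit_ids == nid)[0]
            if len(match) == 0:
                raise ValueError(f"Neuron id {nid} not found")
            resolved.append(int(match[0]))
        if neuron_indices is None:
            return resolved
        if set(resolved) != set(neuron_indices):
            raise ValueError(
                "neuron_ids and neuron_indices refer to different neurons"
            )
    indexes = list(neuron_indices)
    if indexes and (min(indexes) < 0 or max(indexes) >= n_signals):
        raise ValueError("neuron_indices out of bounds")
    return indexes
```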
Will surely fix it, sir.
force-pushed from 1bf6f1d to 772ba38
@Vrittigyl could you please take a look at why the tests are failing now?
force-pushed from 03fe2b7 to a7db3bc
_data was expected to be sliced after neuron selection, but now _data stays full to support memmap.
@Vrittigyl please merge the current main into your branch. I resolved the conflict, but it's not a full merge, so the ruff CI/CD check fails because of it. Please ping me after it's done and I will take a look.
force-pushed from fee5ed9 to 66b46df
force-pushed from b2a1376 to f6a99a0
@pollytur I've merged the latest main into my branch, resolved the conflicts, and fixed the CI issues. Please take a look!
@Vrittigyl sorry for the back and forth. Also, you have been editing the same files, so merging is quite crucial here.
```diff
             if len(chunk) == 0:
                 continue
-            video_array = np.stack(chunk, axis=0)
+            video_array = np.stack(list(chunk), axis=0)
```
I changed this because the previous version caused pyright failures with np.stack when chunk wasn't a proper sequence. Converting it with list(chunk) ensures it works consistently.
```python
        "shifts_per_signal": True,
    }
) as (timestamps, data, shift, seq_interp):
    assert shift is not None
```
please add an error message in case the assert fails, ideally an f-string including the relevant values
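The requested pattern might look like the following (variable names taken from the surrounding test; the message text is illustrative):

```python
def check_shift(shift, seq_interp=None):
    # Assert with an f-string message carrying the offending values,
    # so a test failure immediately shows what went wrong.
    assert shift is not None, (
        f"expected per-signal phase shifts, got shift={shift!r} "
        f"for interpolator {seq_interp!r}"
    )
    return shift
```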
```python
    },
    interp_kwargs={"keep_nans": keep_nans},
) as (timestamps, data, shift, seq_interp):
    assert shift is not None
```
please add an error message in case the assert fails, ideally an f-string including the relevant values
```python
    },
    interp_kwargs={"keep_nans": keep_nans},
) as (_, _, phase_shifts, seq_interp):
    assert phase_shifts is not None
```
please add an error message in case the assert fails, ideally an f-string including the relevant values
```python
    interp_kwargs={"cache_data": True},
) as (gt_spikes, interp):

    assert isinstance(interp, SpikeInterpolator)
```
please add an error message in case the assert fails, ideally an f-string including the relevant values
```python
    },
    interp_kwargs={"cache_data": False},
) as (gt_spikes, interp):
    assert isinstance(interp, SpikeInterpolator)
```
please add an error message in case the assert fails, ideally an f-string including the relevant values
```python
    },
    interp_kwargs={"cache_data": True},
) as (gt_spikes, interp):
    assert isinstance(interp, SpikeInterpolator)
```
please add an error message in case the assert fails, ideally an f-string including the relevant values
```python
    if isinstance(signal, tuple):
        signal = signal[0]
```
why do we need this here and everywhere else in this file?
```python
logger = logging.getLogger(__name__)


def resolve_neuron_indices(neuron_ids, neuron_indices, unit_ids, n_signals):
```
imho resolve_neuron_indices should go to utils.py.
@reneburghardt what would you say?
```python
    return validate_neuron_indices(neuron_indices, n_signals)


def validate_neuron_indices(neuron_indices, n_signals):
```
imho validate_neuron_indices should go to utils.py.
@reneburghardt what would you say?
I am also tempted to add something like a select_channels function to the SequenceInterpolator class, because in principle we might also want to select only certain channels from the eye tracker as well. A particular case of this function would call the utils function for the neurons (to resolve the unit_ids meta to indexes and validate them).
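A sketch of that select_channels idea (the function name, parameters, and placement are suggestions, not code from this PR): a generic column selector where, for neurons, id_map would be the array from meta/unit_ids.npy.

```python
import numpy as np


def select_channels(data, channel_indices=None, channel_ids=None, id_map=None):
    """Generic channel selection: works for neurons (id_map = unit_ids)
    and for e.g. eye-tracker channels alike. Returns data restricted to
    the chosen columns; no selection returns data unchanged."""
    if channel_indices is None and channel_ids is None:
        return data
    if channel_ids is not None:
        if id_map is None:
            raise ValueError("channel_ids given but no id_map to resolve them")
        channel_indices = []
        for cid in channel_ids:
            match = np.where(id_map == cid)[0]
            if len(match) == 0:
                raise ValueError(f"Channel id {cid} not found")
            channel_indices.append(int(match[0]))
    return data[:, list(channel_indices)]
```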
Description

This PR adds support for selecting specific neurons in both SequenceInterpolator and SpikeInterpolator. Users can now pass either:

- neuron_ids: biological neuron IDs mapped using meta/unit_ids.npy
- indexes: direct neuron indexes

If both are provided, a ValueError is raised if they refer to different neurons.

Changes

- SequenceInterpolator: accepts neuron_ids or indexes.
- PhaseShiftedSequenceInterpolator: filters phase_shifts to match the selected neurons when neuron selection is used.
- SpikeInterpolator: accepts neuron_ids or indexes.

This allows loading and processing only a subset of neurons, reducing memory usage and improving flexibility when working with imaging datasets.

Closes #124