| 
2 | 2 | from pathlib import Path  | 
3 | 3 | from typing import Any  | 
4 | 4 | from typing import Callable  | 
 | 5 | +from typing import Iterable  | 
5 | 6 | from typing import Iterator  | 
6 | 7 | from typing import Literal  | 
7 | 8 | from typing import Optional  | 
@@ -78,6 +79,12 @@ class BigWigDataset:  | 
78 | 79 |             GPU. More threads means that more IO can take place while the GPU is busy doing  | 
79 | 80 |             calculations (decompressing or neural network training for example). More threads  | 
80 | 81 |             also means a higher GPU memory usage. Default: 4  | 
 | 82 | +        custom_position_sampler: if set, this sampler will be used instead of the default  | 
 | 83 | +            position sampler (which samples randomly and uniform from regions of interest)  | 
 | 84 | +            This should be an iterable of tuples (chromosome, center).  | 
 | 85 | +        custom_track_sampler: if specified, this sampler will be used to sample tracks. When not  | 
 | 86 | +            specified, each batch simply contains all tracks, or a randomly sellected subset of  | 
 | 87 | +            tracks in case sub_sample_tracks is set. Should be Iterable batches of track indices.  | 
81 | 88 |         return_batch_objects: if True, the batches will be returned as instances of  | 
82 | 89 |             bigwig_loader.batch.Batch  | 
83 | 90 |     """  | 
@@ -107,6 +114,8 @@ def __init__(  | 
107 | 114 |         repeat_same_positions: bool = False,  | 
108 | 115 |         sub_sample_tracks: Optional[int] = None,  | 
109 | 116 |         n_threads: int = 4,  | 
 | 117 | +        custom_position_sampler: Optional[Iterable[tuple[str, int]]] = None,  | 
 | 118 | +        custom_track_sampler: Optional[Iterable[list[int]]] = None,  | 
110 | 119 |         return_batch_objects: bool = False,  | 
111 | 120 |     ):  | 
112 | 121 |         super().__init__()  | 
@@ -152,32 +161,34 @@ def __init__(  | 
152 | 161 |         self._sub_sample_tracks = sub_sample_tracks  | 
153 | 162 |         self._n_threads = n_threads  | 
154 | 163 |         self._return_batch_objects = return_batch_objects  | 
155 |  | - | 
156 |  | -    def _create_dataloader(self) -> StreamedDataloader:  | 
157 |  | -        position_sampler = RandomPositionSampler(  | 
 | 164 | +        self._position_sampler = custom_position_sampler or RandomPositionSampler(  | 
158 | 165 |             regions_of_interest=self.regions_of_interest,  | 
159 | 166 |             buffer_size=self._position_sampler_buffer_size,  | 
160 | 167 |             repeat_same=self._repeat_same_positions,  | 
161 | 168 |         )  | 
 | 169 | +        if custom_track_sampler is not None:  | 
 | 170 | +            self._track_sampler: Optional[Iterable[list[int]]] = custom_track_sampler  | 
 | 171 | +        elif sub_sample_tracks is not None:  | 
 | 172 | +            self._track_sampler = TrackSampler(  | 
 | 173 | +                total_number_of_tracks=len(self.bigwig_collection),  | 
 | 174 | +                sample_size=sub_sample_tracks,  | 
 | 175 | +            )  | 
 | 176 | +        else:  | 
 | 177 | +            self._track_sampler = None  | 
162 | 178 | 
 
  | 
 | 179 | +    def _create_dataloader(self) -> StreamedDataloader:  | 
163 | 180 |         sequence_sampler = GenomicSequenceSampler(  | 
164 | 181 |             reference_genome_path=self.reference_genome_path,  | 
165 | 182 |             sequence_length=self.sequence_length,  | 
166 |  | -            position_sampler=position_sampler,  | 
 | 183 | +            position_sampler=self._position_sampler,  | 
167 | 184 |             maximum_unknown_bases_fraction=self.maximum_unknown_bases_fraction,  | 
168 | 185 |         )  | 
169 |  | -        track_sampler = None  | 
170 |  | -        if self._sub_sample_tracks is not None:  | 
171 |  | -            track_sampler = TrackSampler(  | 
172 |  | -                total_number_of_tracks=len(self.bigwig_collection),  | 
173 |  | -                sample_size=self._sub_sample_tracks,  | 
174 |  | -            )  | 
175 | 186 | 
 
  | 
176 | 187 |         query_batch_generator = QueryBatchGenerator(  | 
177 | 188 |             genomic_location_sampler=sequence_sampler,  | 
178 | 189 |             center_bin_to_predict=self.center_bin_to_predict,  | 
179 | 190 |             batch_size=self.super_batch_size,  | 
180 |  | -            track_sampler=track_sampler,  | 
 | 191 | +            track_sampler=self._track_sampler,  | 
181 | 192 |         )  | 
182 | 193 | 
 
  | 
183 | 194 |         return StreamedDataloader(  | 
 | 
0 commit comments