2323import threading
2424import time
2525from collections import deque
26- from typing import Any , Dict , Iterator , Optional , Tuple
26+ from typing import Any , Dict , Iterator , Optional
2727
28- from torch import Tensor as TorchTensor
2928from torch .utils .data import DataLoader
3029
3130
@@ -82,7 +81,7 @@ def __init__(
8281 if self .mode == "replay" :
8382 self ._cache_batches ()
8483
85- def __iter__ (self ) -> Iterator [Tuple [ TorchTensor , TorchTensor ] ]:
84+ def __iter__ (self ) -> Iterator [Any ]:
8685 """Iterate through the dataloader based on the current mode."""
8786 if self .mode == "log" :
8887 return self ._log_mode_iter ()
@@ -109,7 +108,7 @@ def _cache_batches(self):
109108 if len (self .cached_batches ) > 0 :
110109 self .cache_ready = True
111110
112- def _log_mode_iter (self ) -> Iterator [Tuple [ TorchTensor , TorchTensor ] ]:
111+ def _log_mode_iter (self ) -> Iterator [Any ]:
113112 """Log mode: iterate normally while collecting metrics."""
114113 self .start_time = time .time ()
115114 self .batch_times = []
@@ -127,7 +126,7 @@ def _log_mode_iter(self) -> Iterator[Tuple[TorchTensor, TorchTensor]]:
127126
128127 self .end_time = time .time ()
129128
130- def _replay_mode_iter (self ) -> Iterator [Tuple [ TorchTensor , TorchTensor ] ]:
129+ def _replay_mode_iter (self ) -> Iterator [Any ]:
131130 """Replay mode: replay cached batches for ideal performance
132131 simulation."""
133132 if not self .cache_ready or len (self .cached_batches ) == 0 :
@@ -143,24 +142,12 @@ def _replay_mode_iter(self) -> Iterator[Tuple[TorchTensor, TorchTensor]]:
143142 original_length = len (self .dataloader )
144143 batch_count = 0
145144
146- while batch_count < original_length :
147- for batch in self .cached_batches :
148- if batch_count >= original_length :
149- break
150-
151- batch_start = time .time ()
152- batch_count += 1
153- # Record timing before yield to ensure it's captured even if
154- # iteration breaks
155- batch_time = time .time () - batch_start
156- self .batch_times .append (batch_time )
157- # Record metrics during yielding (time to yield = cache access
158- # time)
159- yield batch
160-
161- # Check if we've reached the original length after incrementing
162- if batch_count >= original_length :
163- break
145+ for i in range (original_length ):
146+ batch_start = time .time ()
147+ batch = self .cached_batches [i % len (self .cached_batches )]
148+ batch_time = time .time () - batch_start
149+ self .batch_times .append (batch_time )
150+ yield batch
164151
165152 self .end_time = time .time ()
166153
0 commit comments