
Commit ff53b06

More fixes

Signed-off-by: Janusz Lisiecki <[email protected]>
1 parent 41896c6


dali/python/nvidia/dali/plugin/pytorch/loader_evaluator/loader.py

Lines changed: 10 additions & 23 deletions
@@ -23,9 +23,8 @@
 import threading
 import time
 from collections import deque
-from typing import Any, Dict, Iterator, Optional, Tuple
+from typing import Any, Dict, Iterator, Optional
 
-from torch import Tensor as TorchTensor
 from torch.utils.data import DataLoader
 
 
@@ -82,7 +81,7 @@ def __init__(
         if self.mode == "replay":
             self._cache_batches()
 
-    def __iter__(self) -> Iterator[Tuple[TorchTensor, TorchTensor]]:
+    def __iter__(self) -> Iterator[Any]:
         """Iterate through the dataloader based on the current mode."""
         if self.mode == "log":
             return self._log_mode_iter()
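
Note on the annotation change: this hunk and the two below relax the iterator return type from Iterator[Tuple[TorchTensor, TorchTensor]] to Iterator[Any]. A wrapped DataLoader is not guaranteed to yield (input, target) tensor pairs; depending on the dataset and collate function it can just as well produce dicts, lists, or nested structures. A minimal sketch of why Iterator[Any] is the honest type, using an illustrative standalone pass-through function rather than the actual loader class:

from typing import Any, Iterator

from torch.utils.data import DataLoader, Dataset


class DictDataset(Dataset):
    """Illustrative dataset whose samples are dicts, not (data, label) tuples."""

    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx: int) -> dict:
        return {"data": idx, "label": idx % 2}


def passthrough(loader: DataLoader) -> Iterator[Any]:
    # The yielded structure is whatever the underlying collate function
    # produced (here a dict of batched tensors), so Any is the only
    # annotation that holds in general.
    for batch in loader:
        yield batch


for batch in passthrough(DataLoader(DictDataset(), batch_size=2)):
    print(batch)  # e.g. {'data': tensor([0, 1]), 'label': tensor([0, 1])}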
@@ -109,7 +108,7 @@ def _cache_batches(self):
         if len(self.cached_batches) > 0:
             self.cache_ready = True
 
-    def _log_mode_iter(self) -> Iterator[Tuple[TorchTensor, TorchTensor]]:
+    def _log_mode_iter(self) -> Iterator[Any]:
         """Log mode: iterate normally while collecting metrics."""
         self.start_time = time.time()
         self.batch_times = []
@@ -127,7 +126,7 @@ def _log_mode_iter(self) -> Iterator[Tuple[TorchTensor, TorchTensor]]:
 
         self.end_time = time.time()
 
-    def _replay_mode_iter(self) -> Iterator[Tuple[TorchTensor, TorchTensor]]:
+    def _replay_mode_iter(self) -> Iterator[Any]:
         """Replay mode: replay cached batches for ideal performance
         simulation."""
         if not self.cache_ready or len(self.cached_batches) == 0:
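
For context, only fragments of _log_mode_iter are visible in these hunks: it records self.start_time, fills self.batch_times, and sets self.end_time when iteration finishes. A rough standalone sketch of that logging pattern, with illustrative names and not the file's actual implementation:

import time


def log_mode_iter(loader):
    # Pass batches through unchanged while collecting timing metrics:
    # one entry per batch plus the total wall-clock time of the pass.
    batch_times = []
    start_time = time.time()
    prev = start_time
    for batch in loader:
        batch_times.append(time.time() - prev)
        yield batch
        prev = time.time()
    end_time = time.time()
    print(f"{len(batch_times)} batches in {end_time - start_time:.3f} s")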
@@ -143,24 +142,12 @@ def _replay_mode_iter(self) -> Iterator[Tuple[TorchTensor, TorchTensor]]:
         original_length = len(self.dataloader)
         batch_count = 0
 
-        while batch_count < original_length:
-            for batch in self.cached_batches:
-                if batch_count >= original_length:
-                    break
-
-                batch_start = time.time()
-                batch_count += 1
-                # Record timing before yield to ensure it's captured even if
-                # iteration breaks
-                batch_time = time.time() - batch_start
-                self.batch_times.append(batch_time)
-                # Record metrics during yielding (time to yield = cache access
-                # time)
-                yield batch
-
-            # Check if we've reached the original length after incrementing
-            if batch_count >= original_length:
-                break
+        for i in range(original_length):
+            batch_start = time.time()
+            batch = self.cached_batches[i % len(self.cached_batches)]
+            batch_time = time.time() - batch_start
+            self.batch_times.append(batch_time)
+            yield batch
 
         self.end_time = time.time()
 
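The rewritten replay loop drops the nested while/for construction, which needed three separate batch_count checks, in favor of a single for loop over the original epoch length that cycles through the cache by modulo indexing. A minimal self-contained sketch of the new logic, with plain list items standing in for cached batches:

import time


def replay_mode_iter(cached_batches, original_length):
    # Yield exactly original_length batches, wrapping around the cache
    # with modulo indexing and timing each cache access.
    batch_times = []
    for i in range(original_length):
        batch_start = time.time()
        batch = cached_batches[i % len(cached_batches)]
        batch_times.append(time.time() - batch_start)
        yield batch


# Three cached batches replayed to an epoch length of five:
print(list(replay_mode_iter(["b0", "b1", "b2"], 5)))
# prints ['b0', 'b1', 'b2', 'b0', 'b1']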