Skip to content

Commit 5f44361

Browse files
committed
Add pre-call callback.
Signed-off-by: Michal Zientkiewicz <[email protected]>
1 parent c045dd6 commit 5f44361

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

dali/python/nvidia/dali/experimental/dali2/_op_builder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def build_call_function(schema, op_class):
200200
header = f"__call__({', '.join(['self'] + inputs + call_args)})"
201201

202202
def call(self, *raw_args, batch_size=None, **raw_kwargs):
203+
self._pre_call(*raw_args, **raw_kwargs)
203204
is_batch = batch_size is not None
204205
if batch_size is None:
205206
for i, x in enumerate(list(raw_args) + list(raw_kwargs.values())):

dali/python/nvidia/dali/experimental/dali2/ops.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ def infer_output_devices(self, *inputs, **args):
105105
self._init_spec(inputs, args)
106106
return self._output_devices
107107

108+
def _pre_call(self, *inputs, **args):
109+
pass
110+
108111
def _is_backend_initialized(self):
109112
return self._op_backend is not None
110113

@@ -302,16 +305,14 @@ def __init__(
302305
self._actual_batch_size, name, device, num_inputs, call_arg_names, **kwargs
303306
)
304307

305-
def __call__(self, ctx=None, *inputs, **args):
308+
def _pre_call(self, *inputs, **args):
306309
if self._api_type is None:
307310
self._api_type = "run"
308311
elif self._api_type != "run":
309312
raise RuntimeError(
310313
"Cannot mix `samples`, `batches` and `run`/`__call__` on the same reader until the end of the epoch."
311314
)
312315

313-
return super()(ctx, *inputs, **args)
314-
315316
def run(self, ctx=None, *inputs, **args):
316317
if self._api_type is None:
317318
self._api_type = "run"
@@ -342,7 +343,7 @@ def samples(self, ctx: Optional[_eval_context.EvalContext] = None):
342343
meta = self._op_backend.GetReaderMeta()
343344
idx = 0
344345
while idx < meta["epoch_size_padded"]:
345-
outputs = super()(ctx)
346+
outputs = super().run(ctx)
346347
batch_size = len(outputs[0])
347348
idx += batch_size
348349
for x in zip(*outputs):
@@ -378,7 +379,7 @@ def batches(self, batch_size=None, ctx: Optional[_eval_context.EvalContext] = No
378379
meta = self._op_backend.GetReaderMeta()
379380
idx = 0
380381
while idx < meta["epoch_size_padded"]:
381-
outputs = super()(ctx)
382+
outputs = super().run(ctx)
382383
batch_size = len(outputs[0])
383384
idx += batch_size
384385
yield outputs

0 commit comments

Comments
 (0)