@@ -105,6 +105,9 @@ def infer_output_devices(self, *inputs, **args):
         self._init_spec(inputs, args)
         return self._output_devices
 
+    def _pre_call(self, *inputs, **args):
+        pass
+
     def _is_backend_initialized(self):
         return self._op_backend is not None
 
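For context, a minimal sketch of the hook pattern this hunk introduces: the base operator gains a no-op `_pre_call` that subclasses can override instead of `__call__`. The diff does not show the base-class `__call__` body, so the assumption below is that it invokes the hook before dispatching; every name other than `_pre_call` is illustrative.

```python
class _OperatorSketch:
    """Illustrative base class for the hook pattern; only `_pre_call`
    is taken from the diff, the rest is assumed."""

    def _pre_call(self, *inputs, **args):
        # Default hook: a no-op. Subclasses (e.g. reader operators)
        # override this to validate state before execution.
        pass

    def __call__(self, ctx=None, *inputs, **args):
        # Assumed flow: run the hook first, then dispatch to the backend.
        self._pre_call(*inputs, **args)
        return self._invoke_backend(ctx, *inputs, **args)

    def _invoke_backend(self, ctx, *inputs, **args):
        # Stand-in for the real backend invocation (illustrative only).
        return inputs
```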
@@ -302,16 +305,14 @@ def __init__(
             self._actual_batch_size, name, device, num_inputs, call_arg_names, **kwargs
         )
 
-    def __call__(self, ctx=None, *inputs, **args):
+    def _pre_call(self, *inputs, **args):
         if self._api_type is None:
             self._api_type = "run"
         elif self._api_type != "run":
             raise RuntimeError(
                 "Cannot mix `samples`, `batches` and `run`/`__call__` on the same reader until the end of the epoch."
             )
 
-        return super()(ctx, *inputs, **args)
-
     def run(self, ctx=None, *inputs, **args):
         if self._api_type is None:
             self._api_type = "run"
@@ -342,7 +343,7 @@ def samples(self, ctx: Optional[_eval_context.EvalContext] = None):
         meta = self._op_backend.GetReaderMeta()
         idx = 0
         while idx < meta["epoch_size_padded"]:
-            outputs = super()(ctx)
+            outputs = super().run(ctx)
             batch_size = len(outputs[0])
             idx += batch_size
             for x in zip(*outputs):
@@ -378,7 +379,7 @@ def batches(self, batch_size=None, ctx: Optional[_eval_context.EvalContext] = No
         meta = self._op_backend.GetReaderMeta()
         idx = 0
         while idx < meta["epoch_size_padded"]:
-            outputs = super()(ctx)
+            outputs = super().run(ctx)
             batch_size = len(outputs[0])
             idx += batch_size
             yield outputs
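For illustration, a standalone sketch of the epoch API-mixing guard that `_pre_call`, `run`, `samples`, and `batches` share: the first API used fixes the mode for the epoch, and switching APIs raises. The class below is a placeholder, not the actual reader operator.

```python
class _ApiTypeGuardSketch:
    """Standalone illustration of the API-mixing guard; class and method
    bodies are simplified placeholders, not the real reader."""

    def __init__(self):
        self._api_type = None

    def _check_api(self, api_name):
        # First use of any API fixes the mode for the rest of the epoch.
        if self._api_type is None:
            self._api_type = api_name
        elif self._api_type != api_name:
            raise RuntimeError(
                "Cannot mix `samples`, `batches` and `run`/`__call__` "
                "on the same reader until the end of the epoch."
            )

    def run(self):
        self._check_api("run")

    def samples(self):
        self._check_api("samples")


guard = _ApiTypeGuardSketch()
guard.run()          # fixes the API type to "run" for this epoch
try:
    guard.samples()  # mixing APIs within one epoch raises
except RuntimeError as e:
    print(e)
```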