@@ -73,9 +73,6 @@ def __init__(self, datapipe: IterDataPipe, key_value_fn: Optional[Callable] = No
7373 self ._depleted = False
7474
7575 def _load_map (self ):
76- if self ._map is None :
77- self ._map = {}
78- self ._itr = iter (self .datapipe )
7976 while not self ._depleted :
8077 try :
8178 self ._load_next_item ()
@@ -84,10 +81,7 @@ def _load_map(self):
8481
8582 def __getitem__ (self , index ):
8683 try :
87- if self ._map is None :
88- self ._map = {}
89- self ._itr = iter (self .datapipe )
90- else :
84+ if self ._map is not None :
9185 return self ._map [index ]
9286 except KeyError :
9387 pass
@@ -101,7 +95,10 @@ def __getitem__(self, index):
10195 raise IndexError (f"Index { index } is invalid for IterToMapConverter." )
10296
10397 def _load_next_item (self ):
104- elem = next (self ._itr )
98+ if self ._map is None :
99+ self ._map = {}
100+ self ._itr = iter (self .datapipe )
101+ elem = next (self ._itr ) # type: ignore[arg-type]
105102 inp = elem if self .key_value_fn is None else self .key_value_fn (elem )
106103 try :
107104 length = len (inp )
@@ -135,14 +132,10 @@ def __getstate__(self):
135132 dill_key_value_fn = dill .dumps (self .key_value_fn )
136133 else :
137134 dill_key_value_fn = self .key_value_fn
138- return (
139- self .datapipe ,
140- dill_key_value_fn ,
141- self ._map ,
142- )
135+ return (self .datapipe , dill_key_value_fn , self ._map , self ._itr , self ._depleted )
143136
144137 def __setstate__ (self , state ):
145- (self .datapipe , dill_key_value_fn , self ._map ) = state
138+ (self .datapipe , dill_key_value_fn , self ._map , self . _itr , self . _depleted ) = state
146139 if DILL_AVAILABLE :
147140 self .key_value_fn = dill .loads (dill_key_value_fn ) # type: ignore[assignment]
148141 else :
0 commit comments