1- import random
2-
31import keras .src .layers as layers
42from keras .src .api_export import keras_export
53from keras .src .layers .preprocessing .image_preprocessing .base_image_preprocessing_layer import ( # noqa: E501
@@ -169,9 +167,15 @@ def get_random_transformation(self, data, training=True, seed=None):
169167 augmentation_layer = getattr (self , layer_name )
170168 augmentation_layer .backend .set_backend ("tensorflow" )
171169
170+ layer_idxes = self .backend .random .randint (
171+ (self .num_ops ,),
172+ 0 ,
173+ len (self ._AUGMENT_LAYERS ),
174+ seed = self ._get_seed_generator (self .backend ._backend ),
175+ )
176+
172177 transformation = {}
173- random .shuffle (self ._AUGMENT_LAYERS )
174- for layer_name in self ._AUGMENT_LAYERS [: self .num_ops ]:
178+ for layer_name in self ._AUGMENT_LAYERS :
175179 augmentation_layer = getattr (self , layer_name )
176180 transformation [layer_name ] = (
177181 augmentation_layer .get_random_transformation (
@@ -181,17 +185,25 @@ def get_random_transformation(self, data, training=True, seed=None):
181185 )
182186 )
183187
184- return transformation
188+ return {
189+ "transforms" : transformation ,
190+ "layer_idxes" : layer_idxes ,
191+ }
185192
186193 def transform_images (self , images , transformation , training = True ):
187194 if training :
188195 images = self .backend .cast (images , self .compute_dtype )
189196
190- for layer_name , transformation_value in transformation .items ():
191- augmentation_layer = getattr (self , layer_name )
192- images = augmentation_layer .transform_images (
193- images , transformation_value
194- )
197+ layer_idxes = transformation ["layer_idxes" ]
198+ transforms = transformation ["transforms" ]
199+ for i in range (self .num_ops ):
200+ for idx , (key , value ) in enumerate (transforms .items ()):
201+ augmentation_layer = getattr (self , key )
202+ images = self .backend .numpy .where (
203+ layer_idxes [i ] == idx ,
204+ augmentation_layer .transform_images (images , value ),
205+ images ,
206+ )
195207
196208 images = self .backend .cast (images , self .compute_dtype )
197209 return images
@@ -206,11 +218,29 @@ def transform_bounding_boxes(
206218 training = True ,
207219 ):
208220 if training :
209- for layer_name , transformation_value in transformation .items ():
210- augmentation_layer = getattr (self , layer_name )
211- bounding_boxes = augmentation_layer .transform_bounding_boxes (
212- bounding_boxes , transformation_value , training = training
221+ layer_idxes = transformation ["layer_idxes" ]
222+ transforms = transformation ["transforms" ]
223+ for idx , (key , value ) in enumerate (transforms .items ()):
224+ augmentation_layer = getattr (self , key )
225+
226+ transformed_bounding_box = (
227+ augmentation_layer .transform_bounding_boxes (
228+ bounding_boxes .copy (), value
229+ )
213230 )
231+ for i in range (self .num_ops ):
232+ bounding_boxes ["boxes" ] = self .backend .numpy .where (
233+ layer_idxes [i ] == idx ,
234+ transformed_bounding_box ["boxes" ],
235+ bounding_boxes ["boxes" ],
236+ )
237+
238+ bounding_boxes ["labels" ] = self .backend .numpy .where (
239+ layer_idxes [i ] == idx ,
240+ transformed_bounding_box ["labels" ],
241+ bounding_boxes ["labels" ],
242+ )
243+
214244 return bounding_boxes
215245
216246 def transform_segmentation_masks (
0 commit comments