@@ -276,7 +276,7 @@ def value(self):
276276 return self ._maybe_autocast (self ._value )
277277
278278 def assign (self , value ):
279- value = self ._convert_to_tensor (value , dtype = self .dtype )
279+ value = self ._convert_to_tensor (value , dtype = self ._dtype )
280280 if not shape_equal (value .shape , self .shape ):
281281 raise ValueError (
282282 "The shape of the target variable and "
@@ -599,7 +599,6 @@ def standardize_shape(shape):
599599 # `tf.TensorShape` may contain `Dimension` objects.
600600 # We need to convert the items in it to either int or `None`
601601 shape = shape .as_list ()
602- shape = tuple (shape )
603602
604603 if config .backend () == "jax" :
605604 # Replace `_DimExpr` (dimension expression) with None
@@ -609,25 +608,37 @@ def standardize_shape(shape):
609608 None if jax_export .is_symbolic_dim (d ) else d for d in shape
610609 )
611610
612- if config .backend () == "torch" :
613- # `shape` might be `torch.Size`. We need to convert the items in it to
614- # either int or `None`
615- shape = tuple (map (lambda x : int (x ) if x is not None else None , shape ))
616-
617- for e in shape :
618- if e is None :
611+ # Handle dimensions that are not ints and not None, verify they're >= 0.
612+ standardized_shape = []
613+ for d in shape :
614+ if d is None :
615+ standardized_shape .append (d )
619616 continue
620- if not is_int_dtype (type (e )):
617+
618+ # Reject these even if they can be cast to int successfully.
619+ if isinstance (d , (str , float )):
621620 raise ValueError (
622621 f"Cannot convert '{ shape } ' to a shape. "
623- f"Found invalid entry '{ e } ' of type '{ type (e )} '. "
622+ f"Found invalid dimension '{ d } ' of type '{ type (d )} '. "
624623 )
625- if e < 0 :
624+
625+ try :
626+ # Cast numpy scalars, tf constant tensors, etc.
627+ d = int (d )
628+ except Exception as e :
629+ raise ValueError (
630+ f"Cannot convert '{ shape } ' to a shape. "
631+ f"Found invalid dimension '{ d } ' of type '{ type (d )} '. "
632+ ) from e
633+ if d < 0 :
626634 raise ValueError (
627635 f"Cannot convert '{ shape } ' to a shape. "
628636 "Negative dimensions are not allowed."
629637 )
630- return shape
638+ standardized_shape .append (d )
639+
640+ # This also turns subclasses of `tuple` (e.g. `torch.Size`) to plain tuple.
641+ return tuple (standardized_shape )
631642
632643
633644def shape_equal (a_shape , b_shape ):
0 commit comments