File tree Expand file tree Collapse file tree 2 files changed +27
-2
lines changed
torch_xla/distributed/spmd Expand file tree Collapse file tree 2 files changed +27
-2
lines changed Original file line number Diff line number Diff line change @@ -212,6 +212,23 @@ def test_spec_invalidation_on_resharding(self):
212212 assert resharded_tensor ._spec is not initial_spec
213213 assert resharded_tensor ._spec .placements [1 ].dim == 1
214214
215+ def test_auto_wrapped_tensor_spec_failure (self ):
216+ """Test that auto-wrapped tensors fail when accessing _spec property.
217+
218+ Auto-wrapped tensors are created through operations that trigger __torch_dispatch__
219+ but don't yet have access to the sharding propagation done through open xla,
220+ causing ._spec to fail.
221+ """
222+ device_count = xr .global_runtime_device_count ()
223+ mesh = DeviceMesh ("xla" , torch .arange (device_count ))
224+ tensor = torch .randn (4 , 4 )
225+ sharded_tensor = distribute_tensor (tensor , mesh , [Shard (0 )])
226+
227+ auto_wrapped = sharded_tensor + sharded_tensor
228+
229+ with self .assertRaises (ValueError ):
230+ _ = auto_wrapped ._spec
231+
215232
216233if __name__ == '__main__' :
217234 test = unittest .main ()
Original file line number Diff line number Diff line change @@ -205,7 +205,11 @@ def _spec(self):
205205 mesh = DeviceMesh ("xla" ,
206206 torch .tensor (device_list ).reshape (self .mesh_shape ))
207207 else :
208- raise ValueError ("mesh_shape must be specified to create DTensorSpec" )
208+ raise ValueError (
209+ "mesh_shape must be specified to create DTensorSpec. "
210+ "If this tensor was created through torch operations, it may be auto-wrapped. "
211+ "Use wrap_as_sharded_tensor() to set mesh_shape before accessing _spec. "
212+ )
209213
210214 # use existing partition_spec
211215 if self .partition_spec is not None :
@@ -220,7 +224,11 @@ def _spec(self):
220224 placements .append (
221225 Shard (tensor_dim ) if tensor_dim is not None else Replicate ())
222226 else :
223- raise ValueError ("partition_spec must be specified to create DTensorSpec" )
227+ raise ValueError (
228+ "partition_spec must be specified to create DTensorSpec. "
229+ "If this tensor was created through torch operations, it may be auto-wrapped. "
230+ "Use wrap_as_sharded_tensor() to set partition_spec before accessing _spec. "
231+ )
224232
225233 # tensor metadata
226234 tensor_meta = TensorMeta (
You can’t perform that action at this time.
0 commit comments