import sys
import unittest
import torch
import numpy as np

from torch.distributed.tensor import DeviceMesh
from torch.distributed._tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate, Shard
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd.xla_sharded_tensor import XLAShardedTensor
import test_xla_sharding_base


class DTensorXLAFromLocalConversionTest(test_xla_sharding_base.XlaShardingTest):
  """
  Test suite for the automatic conversion of regular tensors to XLAShardedTensor
  in DTensor.from_local() when using an XLA device mesh.
  """

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
  def test_basic_conversion(self):
    """Test basic conversion of a regular tensor to XLAShardedTensor."""
    world_size = xr.global_runtime_device_count()

    # Create a regular tensor (not on the XLA device)
    tensor = torch.randn(100_000, 88)
    tensor_cpu = tensor.cpu()  # Keep a CPU copy for comparison

    # Create a 1-D DeviceMesh spanning all available XLA devices
    device_mesh = DeviceMesh("xla", list(range(world_size)))

    # Use DTensor.from_local with the regular tensor
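    # With no explicit placements, from_local defaults to Replicate()
    # on every mesh dimension, so each device holds the full tensor.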
    dt = DTensor.from_local(tensor, device_mesh=device_mesh)

    # Verify the tensor was converted correctly
    self.assertEqual(dt.shape, tensor.shape)

    # Check the value of the tensor
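    # check_device=False is needed because dt.global_tensor lives on the
    # XLA device while the reference copy is on CPU; assert_close then
    # compares the values on CPU instead of requiring matching devices.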
    torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False)

    # Verify operations work
    result = dt + 1.0
    self.assertEqual(result.shape, tensor.shape)

    print("Basic conversion successful")

  def test_conversion_with_placements(self):
    """Test conversion with explicit placements."""
    world_size = xr.global_runtime_device_count()

    # Create a regular tensor (not on the XLA device)
    tensor = torch.randn(100_000, 88)
    tensor_cpu = tensor.cpu()  # Keep a CPU copy for comparison

    # Create a DeviceMesh
    device_mesh = DeviceMesh("xla", list(range(world_size)))

    # Use DTensor.from_local with explicit placements
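    # Replicate() places a full, identical copy of the tensor on every
    # device in the mesh.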
    dt = DTensor.from_local(
        tensor,
        device_mesh=device_mesh,
        placements=[Replicate()]
    )

    # Verify the tensor was converted correctly
    self.assertEqual(dt.shape, tensor.shape)

    # Check the value of the tensor
    torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False)

    # Verify operations work
    result = dt + 1.0
    self.assertEqual(result.shape, tensor.shape)

    print("Conversion with placements successful")

  def test_conversion_with_sharding(self):
    """Test conversion with a sharding placement."""
    world_size = xr.global_runtime_device_count()
    if world_size < 2:
      self.skipTest("Need at least 2 devices for sharding test")

    # Create a tensor whose dim 0 is divisible by world_size
    tensor = torch.randn(100_000, 88)
    tensor_cpu = tensor.cpu()  # Keep a CPU copy for comparison

    # Create a DeviceMesh
    device_mesh = DeviceMesh("xla", list(range(world_size)))

    # Use DTensor.from_local with a sharding placement
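    # Shard(0) splits the tensor along dim 0, giving each mesh device one
    # contiguous slice; the logical (global) shape stays unchanged.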
    dt = DTensor.from_local(
        tensor,
        device_mesh=device_mesh,
        placements=[Shard(0)]
    )

    # Verify the tensor was converted correctly
    self.assertEqual(dt.shape, tensor.shape)

    # Check the value of the tensor
    torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False)

    # Verify operations work
    result = dt + 1.0
    self.assertEqual(result.shape, tensor.shape)

    print("Conversion with sharding successful")

  def test_conversion_with_different_dtypes(self):
    """Test conversion with different dtypes."""
    world_size = xr.global_runtime_device_count()
    device_mesh = DeviceMesh("xla", list(range(world_size)))

    # Test with different dtypes
    for dtype in [torch.float16, torch.float32, torch.int32, torch.int64]:
      # Create a tensor with a specific dtype
      tensor = torch.ones(100_000, 88, dtype=dtype)
      tensor_cpu = tensor.cpu()  # Keep a CPU copy for comparison

      # Use DTensor.from_local with the tensor
      dt = DTensor.from_local(tensor, device_mesh=device_mesh)

      # Verify the dtype is preserved
      self.assertEqual(dt.dtype, dtype)

      # Check the value of the tensor
      torch.testing.assert_close(dt.global_tensor, tensor_cpu, check_device=False)

      # Verify operations work
      if dtype.is_floating_point:
        result = dt + 1.0
      else:
        result = dt + 1

      self.assertEqual(result.shape, tensor.shape)
      self.assertEqual(result.dtype, dtype)

      print(f"Conversion with {dtype} successful")


if __name__ == "__main__":
  result = unittest.main(exit=False)
  sys.exit(0 if result.result.wasSuccessful() else 1)