@@ -14,6 +14,10 @@ Key Features:
- Multiple transport backends (RDMA, regular TCP) for optimal performance
- Flexible storage volume management and sharding strategies

+ Note: Although this may change in the future, TorchStore currently only supports multi-process/multi-node jobs launched with Monarch (see the Usage section below for a complete example).
+ For more information on Monarch, see https://github.com/meta-pytorch/monarch?tab=readme-ov-file#monarch-
+
+
> ⚠️ **Early Development Warning** TorchStore is currently in an experimental
> stage. You should expect bugs, incomplete features, and APIs that may change
> in future versions. The project welcomes bugfixes, but to make sure things are
@@ -51,8 +55,13 @@ pip install -e .

# Install development dependencies
pip install -e '.[dev]'
+
+ # NOTE: It's common to run into libtorch loading issues. A good workaround is to export:
+ # export LD_LIBRARY_PATH="$CONDA_PREFIX/lib:${LD_LIBRARY_PATH:-}"
```

+
+
### Regular Installation

To install the package directly from the repository:
@@ -67,78 +76,110 @@ Once installed, you can import it in your Python code:
import torchstore
```

- Note: Setup currently assumes you have a working conda environment with both torch & monarch (this is currently a todo).
-
## Usage

```python
- import torch
import asyncio
+
+ import torch
+
+ from monarch.actor import Actor, current_rank, endpoint
+
import torchstore as ts
+ from torchstore.utils import spawn_actors
+
+
+ WORLD_SIZE = 4
+
+
+ # In Monarch, Actors are how we represent multi-process/multi-node applications. For additional details, see:
+ # https://github.com/meta-pytorch/monarch?tab=readme-ov-file#monarch-
+ class ExampleActor(Actor):
+     def __init__(self, world_size=WORLD_SIZE):
+         self.rank = current_rank().rank
+         self.world_size = world_size
+
+     @endpoint
+     async def store_tensor(self):
+         t = torch.tensor([self.rank])
+         await ts.put(f"{self.rank}_tensor", t)
+
+     @endpoint
+     async def print_tensor(self):
+         other_rank = (self.rank + 1) % self.world_size
+         t = await ts.get(f"{other_rank}_tensor")
+         print(f"Rank=[{self.rank}] Fetched {t} from {other_rank=}")
+

async def main():

    # Create a store instance
    await ts.initialize()

-     # Store a tensor
-     await ts.put("my_tensor", torch.randn(3, 4))
+     actors = await spawn_actors(WORLD_SIZE, ExampleActor, "example_actors")

-     # Retrieve a tensor
-     tensor = await ts.get("my_tensor")
+     # Calls "store_tensor" on each actor instance
+     await actors.store_tensor.call()
+     await actors.print_tensor.call()

-
- if __name__ == "__main__":
+ if __name__ == "__main__":
    asyncio.run(main())

+ # Expected output
+ # [0] [2] Rank=[2] Fetched tensor([3]) from other_rank=3
+ # [0] [0] Rank=[0] Fetched tensor([1]) from other_rank=1
+ # [0] [3] Rank=[3] Fetched tensor([0]) from other_rank=0
+ # [0] [1] Rank=[1] Fetched tensor([2]) from other_rank=2
+
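# (Ordering note: the four actors run concurrently, so these lines, and the bracketed
# prefixes Monarch prepends when forwarding each actor's output, may appear in a
# different order from run to run.)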
```

### Resharding Support with DTensor

- ```python
- import torchstore as ts
- from torch.distributed._tensor import distribute_tensor, Replicate, Shard
- from torch.distributed.device_mesh import init_device_mesh
-
- async def place_dtensor_in_store():
-     device_mesh = init_device_mesh("cpu", (4,))
-     tensor = torch.arange(4)
-     dtensor = distribute_tensor(tensor, device_mesh, placements=[Shard(1)])
-
-     # Store a tensor
-     await ts.put("my_tensor", dtensor)
+ TorchStore makes it easy to fetch arbitrary slices of any distributed tensor (DTensor).
+ For a full DTensor example, see [examples/dtensor.py](https://github.com/meta-pytorch/torchstore/blob/main/example/dtensor.py)


- async def fetch_dtensor_from_store():
-     # You can now fetch arbitrary shards of this tensor from any rank e.g.
-     device_mesh = init_device_mesh("cpu", (2, 2))
-     tensor = torch.rand(4)
-     dtensor = distribute_tensor(
-         tensor,
-         device_mesh,
-         placements=[Replicate(), Shard(0)]
-     )
-
-     # This line copies the previously stored dtensor into local memory.
-     await ts.get("my_tensor", dtensor)
-
- def run_in_parallel(func):
-     # just for demonstrative purposes
-     return func
+ ```python

- if __name__ == "__main__":
-     ts.initialize()
-     run_in_parallel(place_dtensor_in_store)
-     run_in_parallel(fetch_dtensor_from_store)
-     ts.shutdown()
+ class DTensorActor(Actor):
+     """
+     Example pseudo-code for an Actor utilizing DTensor support.
+
+     Full actor definition in [examples/dtensor.py](https://github.com/meta-pytorch/torchstore/blob/main/example/dtensor.py)
+     """
+
+     @endpoint
+     async def do_put(self):
+         # Typical DTensor boilerplate
+         self.initialize_distributed()
+         device_mesh = init_device_mesh("cpu", self.mesh_shape)
+         tensor = self.original_tensor.to("cpu")
+         dtensor = distribute_tensor(tensor, device_mesh, placements=self.placements)
+
+         print(f"Calling put with {dtensor=}")
+         # This will place only the local shard into TorchStore
+         await ts.put(self.shared_key, dtensor)
+
+     @endpoint
+     async def do_get(self):
+         # Typical DTensor boilerplate
+         self.initialize_distributed()
+         device_mesh = init_device_mesh("cpu", self.mesh_shape)
+         tensor = self.original_tensor.to("cpu")
+         dtensor = distribute_tensor(tensor, device_mesh, placements=self.placements)
+
+         # TorchStore uses the metadata of the local dtensor to fetch only the tensor
+         # data that belongs to the local shard.
+         fetched_tensor = await ts.get(self.shared_key, dtensor)
+         print(fetched_tensor)

# Check out tests/test_resharding.py for more e2e examples of resharding DTensors.
```
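To make the resharding concrete, here is a minimal sketch (not from the repository; the function names, key, and tensor sizes are illustrative) that reuses the imports, mesh shapes, and placements from the earlier version of this example shown above. It assumes it runs on every rank of a Monarch actor mesh with `torch.distributed` initialized and the store already initialized:

```python
import torch
import torchstore as ts
from torch.distributed._tensor import distribute_tensor, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh


async def save_with_one_sharding():
    # Writer ranks: a 1-D mesh of 4, tensor sharded along dim 0.
    mesh = init_device_mesh("cpu", (4,))
    dtensor = distribute_tensor(torch.arange(8), mesh, placements=[Shard(0)])
    await ts.put("resharded_tensor", dtensor)  # each rank stores only its local shard


async def load_with_another_sharding():
    # Reader ranks: a different 2x2 mesh and placements; TorchStore fetches only
    # the slice of the stored tensor that this rank's local shard needs.
    mesh = init_device_mesh("cpu", (2, 2))
    dtensor = distribute_tensor(torch.zeros(8, dtype=torch.int64), mesh,
                                placements=[Replicate(), Shard(0)])
    return await ts.get("resharded_tensor", dtensor)
```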
137178
# Testing

Pytest is used for testing. For an example of how to run tests (and get logs), see:
- `TORCHSTORE_LOG_LEVEL=DEBUG pytest -vs --log-cli-level=DEBUG tests/test_models.py::test_main`
+ `TORCHSTORE_LOG_LEVEL=DEBUG pytest -vs --log-cli-level=DEBUG tests/test_models.py::test_basic`
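The same pattern works for the resharding tests referenced in the example above, e.g. `TORCHSTORE_LOG_LEVEL=DEBUG pytest -vs --log-cli-level=DEBUG tests/test_resharding.py`.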

## License
