Skip to content

Commit c02dc83

Browse files
d4l3k authored and meta-codesync[bot] committed
torchcomms/dummy: enable using dummy backend outside of tests (#75)
Summary: Pull Request resolved: #75 This enables using the dummy backend outside of tests. This allows for mocking out the distributed parts of a model for local performance testing. This also does some small cleanups/registration to make the Dummy backend more similar to other backends. Meta: This required significant changes to how we package the backend .so files to avoid duplicate library linking when building with BUCK. https://docs.google.com/document/d/1A1_djQlNcTznXm72G7R_MELW-j0S05pz4WTFSpeDdts/edit?tab=t.0#heading=h.7ws522vzwljf Reviewed By: fduwjj Differential Revision: D88218584 Privacy Context Container: L1397144 fbshipit-source-id: 464d207c38e589f0c9cf39a17f101a15af09f162
1 parent ff7effb commit c02dc83

File tree

10 files changed

+313
-304
lines changed

10 files changed

+313
-304
lines changed

comms/torchcomms/TorchCommBackend.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
#include <comms/torchcomms/TorchCommUtils.hpp>
1212
#include <comms/torchcomms/TorchCommWindow.hpp>
1313
#include <comms/torchcomms/TorchWork.hpp>
14-
#include <torch/csrc/distributed/c10d/Store.hpp> // @manual=//caffe2:torch-cpp-cpu
15-
#include <torch/csrc/distributed/c10d/Work.hpp> // @manual=//caffe2:torch-cpp-cpu
1614
#include <memory>
1715
#include <vector>
1816

comms/torchcomms/TorchCommBatch.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
#include <ATen/ATen.h>
66
#include <c10/core/Device.h>
77
#include <c10/util/intrusive_ptr.h>
8-
#include <torch/csrc/distributed/c10d/Store.hpp> // @manual=//caffe2:torch-cpp-cpu
9-
#include <torch/csrc/distributed/c10d/Work.hpp> // @manual=//caffe2:torch-cpp-cpu
108
#include "comms/torchcomms/TorchCommOptions.hpp"
119

1210
namespace torch {
Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
3+
#include <comms/torchcomms/TorchCommDummy.hpp>
4+
#include <comms/torchcomms/TorchCommFactory.hpp>
5+
#include <comms/torchcomms/TorchWork.hpp>
6+
#include <torch/csrc/distributed/c10d/Store.hpp> // @manual=//caffe2:torch-cpp-cpu
7+
8+
namespace torch {
9+
namespace comms {
10+
11+
// Dummy TorchWork implementation for testing
12+
class DummyTorchWork : public TorchWork {
13+
public:
14+
bool isCompleted() override {
15+
return true;
16+
}
17+
18+
void wait() override {}
19+
};
20+
21+
class DummyTorchCommWindow : public TorchCommWindow {
22+
public:
23+
void allocate(
24+
const size_t window_size,
25+
bool cpu_buf,
26+
const size_t signal_size = 256) override {
27+
(void)cpu_buf;
28+
(void)signal_size;
29+
win_size_ = window_size;
30+
}
31+
c10::intrusive_ptr<TorchWork> put(
32+
const at::Tensor& data,
33+
int dstRank,
34+
size_t targetDisp,
35+
bool asyncOp) override {
36+
(void)data;
37+
(void)dstRank;
38+
(void)targetDisp;
39+
(void)asyncOp;
40+
return c10::make_intrusive<DummyTorchWork>();
41+
}
42+
at::Tensor getTensor(
43+
int rank,
44+
at::IntArrayRef sizes,
45+
at::ScalarType dtype,
46+
int64_t storageOffset) override {
47+
(void)rank;
48+
(void)sizes;
49+
(void)dtype;
50+
(void)storageOffset;
51+
return at::Tensor();
52+
}
53+
c10::intrusive_ptr<TorchWork> signal(
54+
size_t signalDisp,
55+
uint64_t signalVal,
56+
int dstRank,
57+
bool asyncOp) override {
58+
(void)signalDisp;
59+
(void)signalVal;
60+
(void)dstRank;
61+
(void)asyncOp;
62+
return c10::make_intrusive<DummyTorchWork>();
63+
}
64+
virtual c10::intrusive_ptr<TorchWork> waitSignal(
65+
size_t signalDisp,
66+
uint64_t cmpVal,
67+
SignalCmpOp cmpOp,
68+
bool asyncOp) override {
69+
(void)signalDisp;
70+
(void)cmpVal;
71+
(void)cmpOp;
72+
(void)asyncOp;
73+
return c10::make_intrusive<DummyTorchWork>();
74+
}
75+
};
76+
77+
// Construct an uninitialized dummy comm: a single "process" (rank 0 of
// size 1), pinned to CPU until init() supplies the real device.
TorchCommDummy::TorchCommDummy()
    : initialized_(false),
      device_(at::kCPU),
      rank_(0),
      size_(1) {}
79+
80+
// Record the configuration handed to us. No real communicator is created;
// the fields are only stored so the accessors below can report them.
void TorchCommDummy::init(
    at::Device device,
    const std::string& name,
    const CommOptions& options) {
  name_ = name;
  device_ = device;
  options_ = options;
  initialized_ = true;
}
89+
90+
// Tear-down is trivial for the dummy backend: just clear the flag set by
// init(); there are no resources to release.
void TorchCommDummy::finalize() {
  initialized_ = false;
}
93+
94+
int TorchCommDummy::getRank() const {
95+
return rank_;
96+
}
97+
98+
int TorchCommDummy::getSize() const {
99+
return size_;
100+
}
101+
102+
std::string_view TorchCommDummy::getCommName() const {
103+
return name_;
104+
}
105+
106+
std::string_view TorchCommDummy::getBackendName() const {
107+
return kBackendName;
108+
}
109+
110+
// Point-to-point send: a no-op; returns an already-completed work handle.
// Parameter names are commented out to silence -Wunused-parameter.
c10::intrusive_ptr<TorchWork> TorchCommDummy::send(
    const at::Tensor& /*tensor*/,
    int /*dst*/,
    bool /*async_op*/,
    const SendOptions& /*options*/) {
  return c10::make_intrusive<DummyTorchWork>();
}
117+
118+
// Point-to-point receive: a no-op — the tensor is left untouched.
// Parameter names are commented out to silence -Wunused-parameter.
c10::intrusive_ptr<TorchWork> TorchCommDummy::recv(
    at::Tensor& /*tensor*/,
    int /*src*/,
    bool /*async_op*/,
    const RecvOptions& /*options*/) {
  return c10::make_intrusive<DummyTorchWork>();
}
125+
126+
// Batched P2P issue: ignores the op list and completes immediately.
// Parameter names are commented out to silence -Wunused-parameter.
c10::intrusive_ptr<TorchWork> TorchCommDummy::batch_op_issue(
    const std::vector<BatchSendRecv::P2POp>& /*ops*/,
    bool /*async_op*/,
    const BatchP2POptions& /*options*/) {
  return c10::make_intrusive<DummyTorchWork>();
}
132+
133+
// Broadcast: no-op (single-rank semantics — the tensor is already "sent").
// Parameter names are commented out to silence -Wunused-parameter.
c10::intrusive_ptr<TorchWork> TorchCommDummy::broadcast(
    at::Tensor& /*tensor*/,
    int /*root*/,
    bool /*async_op*/,
    const BroadcastOptions& /*options*/) {
  return c10::make_intrusive<DummyTorchWork>();
}
140+
141+
// All-reduce: no-op — the tensor is left unchanged rather than reduced.
// Parameter names are commented out to silence -Wunused-parameter.
c10::intrusive_ptr<TorchWork> TorchCommDummy::all_reduce(
    at::Tensor& /*tensor*/,
    const ReduceOp& /*op*/,
    bool /*async_op*/,
    const AllReduceOptions& /*options*/) {
  return c10::make_intrusive<DummyTorchWork>();
}
148+
149+
// Reduce-to-root: no-op; returns an already-completed work handle.
// Parameter names are commented out to silence -Wunused-parameter.
c10::intrusive_ptr<TorchWork> TorchCommDummy::reduce(
    const at::Tensor& /*tensor*/,
    int /*root*/,
    const ReduceOp& /*op*/,
    bool /*async_op*/,
    const ReduceOptions& /*options*/) {
  return c10::make_intrusive<DummyTorchWork>();
}
157+
158+
// All-gather: no-op — the output list is not populated.
// Parameter names are commented out to silence -Wunused-parameter.
c10::intrusive_ptr<TorchWork> TorchCommDummy::all_gather(
    const std::vector<at::Tensor>& /*tensor_list*/,
    const at::Tensor& /*tensor*/,
    bool /*async_op*/,
    const AllGatherOptions& /*options*/) {
  return c10::make_intrusive<DummyTorchWork>();
}
165+
166+
// Variable-size all-gather: no-op, same contract as all_gather above.
// Parameter names are commented out to silence -Wunused-parameter.
c10::intrusive_ptr<TorchWork> TorchCommDummy::all_gather_v(
    const std::vector<at::Tensor>& /*tensor_list*/,
    const at::Tensor& /*tensor*/,
    bool /*async_op*/,
    const AllGatherOptions& /*options*/) {
  return c10::make_intrusive<DummyTorchWork>();
}
173+
174+
// Flat-output all-gather: no-op — `output` is left untouched.
// Parameter names are commented out to silence -Wunused-parameter.
c10::intrusive_ptr<TorchWork> TorchCommDummy::all_gather_single(
    at::Tensor& /*output*/,
    const at::Tensor& /*input*/,
    bool /*async_op*/,
    const AllGatherSingleOptions& /*options*/) {
  return c10::make_intrusive<DummyTorchWork>();
}
181+
182+
// Reduce-scatter: no-op; returns an already-completed work handle.
// Parameter names are commented out to silence -Wunused-parameter.
c10::intrusive_ptr<TorchWork> TorchCommDummy::reduce_scatter(
    at::Tensor& /*output*/,
    const std::vector<at::Tensor>& /*input_list*/,
    const ReduceOp& /*op*/,
    bool /*async_op*/,
    const ReduceScatterOptions& /*options*/) {
  return c10::make_intrusive<DummyTorchWork>();
}
190+
191+
// Variable-size reduce-scatter: no-op, same contract as reduce_scatter.
// Parameter names are commented out to silence -Wunused-parameter.
c10::intrusive_ptr<TorchWork> TorchCommDummy::reduce_scatter_v(
    at::Tensor& /*output*/,
    const std::vector<at::Tensor>& /*input_list*/,
    const ReduceOp& /*op*/,
    bool /*async_op*/,
    const ReduceScatterOptions& /*options*/) {
  return c10::make_intrusive<DummyTorchWork>();
}
199+
200+
// Flat-input reduce-scatter: no-op — `output` is left untouched.
// Parameter names are commented out to silence -Wunused-parameter.
c10::intrusive_ptr<TorchWork> TorchCommDummy::reduce_scatter_single(
    at::Tensor& /*output*/,
    const at::Tensor& /*input*/,
    const ReduceOp& /*op*/,
    bool /*async_op*/,
    const ReduceScatterSingleOptions& /*options*/) {
  return c10::make_intrusive<DummyTorchWork>();
}
208+
209+
// All-to-all (flat tensors): no-op; completes immediately.
// Parameter names are commented out to silence -Wunused-parameter.
c10::intrusive_ptr<TorchWork> TorchCommDummy::all_to_all_single(
    at::Tensor& /*output*/,
    const at::Tensor& /*input*/,
    bool /*async_op*/,
    const AllToAllSingleOptions& /*options*/) {
  return c10::make_intrusive<DummyTorchWork>();
}
216+
217+
// Variable-split all-to-all: no-op; split sizes are ignored.
// Parameter names are commented out to silence -Wunused-parameter.
c10::intrusive_ptr<TorchWork> TorchCommDummy::all_to_all_v_single(
    at::Tensor& /*output*/,
    const at::Tensor& /*input*/,
    const std::vector<uint64_t>& /*output_split_sizes*/,
    const std::vector<uint64_t>& /*input_split_sizes*/,
    bool /*async_op*/,
    const AllToAllvSingleOptions& /*options*/) {
  return c10::make_intrusive<DummyTorchWork>();
}
226+
227+
// Tensor-list all-to-all: no-op — output tensors are not populated.
// Parameter names are commented out to silence -Wunused-parameter.
c10::intrusive_ptr<TorchWork> TorchCommDummy::all_to_all(
    const std::vector<at::Tensor>& /*output_tensor_list*/,
    const std::vector<at::Tensor>& /*input_tensor_list*/,
    bool /*async_op*/,
    const AllToAllOptions& /*options*/) {
  return c10::make_intrusive<DummyTorchWork>();
}
234+
235+
// Barrier: trivially satisfied for a single-rank dummy comm.
// Parameter names are commented out to silence -Wunused-parameter.
c10::intrusive_ptr<TorchWork> TorchCommDummy::barrier(
    bool /*async_op*/,
    const BarrierOptions& /*options*/) {
  return c10::make_intrusive<DummyTorchWork>();
}
240+
241+
// Scatter from root: no-op — `output_tensor` is left untouched.
// Parameter names are commented out to silence -Wunused-parameter.
c10::intrusive_ptr<TorchWork> TorchCommDummy::scatter(
    at::Tensor& /*output_tensor*/,
    const std::vector<at::Tensor>& /*input_tensor_list*/,
    int /*root*/,
    bool /*async_op*/,
    const ScatterOptions& /*options*/) {
  return c10::make_intrusive<DummyTorchWork>();
}
249+
250+
// Gather to root: no-op — the output list is not populated.
// Parameter names are commented out to silence -Wunused-parameter.
c10::intrusive_ptr<TorchWork> TorchCommDummy::gather(
    const std::vector<at::Tensor>& /*output_tensor_list*/,
    const at::Tensor& /*input_tensor*/,
    int /*root*/,
    bool /*async_op*/,
    const GatherOptions& /*options*/) {
  return c10::make_intrusive<DummyTorchWork>();
}
258+
259+
std::shared_ptr<TorchCommWindow> TorchCommDummy::window_allocate(
260+
const size_t window_size,
261+
bool cpu_buf,
262+
const size_t signal_size) {
263+
auto win = std::make_shared<DummyTorchCommWindow>();
264+
win->allocate(window_size, cpu_buf, signal_size);
265+
return win;
266+
}
267+
268+
std::shared_ptr<TorchCommBackend> TorchCommDummy::split(
269+
const std::vector<int>& ranks,
270+
const std::string& name,
271+
const CommOptions& options) {
272+
(void)ranks;
273+
(void)name;
274+
(void)options;
275+
return std::make_shared<TorchCommDummy>();
276+
}
277+
278+
// Options recorded by init(); default-constructed before init() is called.
const CommOptions& TorchCommDummy::getOptions() const {
  return options_;
}
281+
282+
// Device recorded by init(); at::kCPU until init() runs.
const at::Device& TorchCommDummy::getDevice() const {
  return device_;
}
285+
286+
namespace {
287+
class DummyRegistration {
288+
public:
289+
DummyRegistration() {
290+
TorchCommFactory::get().register_backend(
291+
"dummy", []() { return std::make_shared<TorchCommDummy>(); });
292+
}
293+
};
294+
295+
static DummyRegistration registration{};
296+
} // namespace
297+
298+
} // namespace comms
299+
} // namespace torch

comms/torchcomms/tests/unit/cpp/DummyTorchCommBackend.hpp renamed to comms/torchcomms/TorchCommDummy.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
namespace torch {
1010
namespace comms {
1111

12-
class DummyTorchCommBackend : public TorchCommBackend {
12+
class TorchCommDummy : public TorchCommBackend {
1313
public:
1414
static constexpr std::string_view kBackendName = "dummy";
1515

16-
DummyTorchCommBackend();
17-
~DummyTorchCommBackend() override = default;
16+
TorchCommDummy();
17+
~TorchCommDummy() override = default;
1818

1919
// Initialize the communication backend
2020
void init(

comms/torchcomms/TorchCommOptions.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#include <c10/util/intrusive_ptr.h>
88
#include <comms/torchcomms/TorchCommTypes.hpp>
99
#include <torch/csrc/distributed/c10d/Store.hpp> // @manual=//caffe2:torch-cpp-cpu
10-
#include <torch/csrc/distributed/c10d/Work.hpp> // @manual=//caffe2:torch-cpp-cpu
1110
#include <chrono>
1211
#include <string>
1312
#include <unordered_map>

comms/torchcomms/TorchCommTypes.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
#include <ATen/ATen.h>
66
#include <c10/core/Device.h>
77
#include <c10/util/intrusive_ptr.h>
8-
#include <torch/csrc/distributed/c10d/Store.hpp> // @manual=//caffe2:torch-cpp-cpu
9-
#include <torch/csrc/distributed/c10d/Work.hpp> // @manual=//caffe2:torch-cpp-cpu
108
#include <chrono>
119
#include <variant>
1210

comms/torchcomms/TorchCommUtils.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44

55
#include <string>
66

7-
#include <torch/csrc/distributed/c10d/Store.hpp> // @manual=//caffe2:torch-cpp-cpu
8-
97
namespace torch {
108
namespace comms {
119

0 commit comments

Comments
 (0)