
Commit 19c9727

feat: add Work definition, fix gradient_as_bucket_view option
1 parent bc69a62

10 files changed, +324 -104 lines changed


infini_train/include/nn/parallel/process_group.h

Lines changed: 3 additions & 4 deletions
@@ -11,6 +11,7 @@
 #endif
 
 #include "infini_train/include/nn/parallel/reduce_op_type.h"
+#include "infini_train/include/nn/parallel/work.h"
 
 namespace infini_train {
 class Tensor;
@@ -55,10 +56,8 @@ class ProcessGroup {
 
     std::vector<std::shared_ptr<Tensor>> NcclRecv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank) const;
 
-    // Overlap helper functions
-    void EnqueueAllReduce(cudaEvent_t ready_event, cudaEvent_t done_event, const std::shared_ptr<Tensor> &tensor,
-                          function::ReduceOpType reduce_op) const;
-    void WaitAllReduceDone(cudaEvent_t done_event, const std::shared_ptr<Tensor> &tensor) const;
+    // Async communication functions
+    std::shared_ptr<Work> AllReduceAsync(const std::shared_ptr<Tensor> &tensor, function::ReduceOpType reduce_op) const;
 
 private:
     std::vector<ncclComm_t> comms_;
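
The new `AllReduceAsync` replaces the event-based `EnqueueAllReduce`/`WaitAllReduceDone` pair with a single call that returns a `Work` handle. A minimal caller-side sketch of the intended pattern is below; `pg` and `grad_bucket` are illustrative names, not identifiers from this commit, and the exact timeout semantics are an assumption.

    // Hypothetical usage sketch: enqueue the collective, keep computing, then
    // check or wait on the returned handle from the host if needed.
    std::shared_ptr<Work> work = pg->AllReduceAsync(grad_bucket, function::ReduceOpType::kAvg);

    // ... enqueue more independent work on the compute stream; it overlaps with the collective ...

    if (!work->Wait(std::chrono::milliseconds(5000))) {  // a zero timeout presumably means wait indefinitely
        LOG(ERROR) << "All-reduce did not complete within the timeout";
    }
    CHECK(work->IsSuccess());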

infini_train/include/nn/parallel/reducer.h

Lines changed: 8 additions & 11 deletions
@@ -30,12 +30,15 @@ std::vector<std::vector<size_t>> ComputeBucketAssignmentBySize(const std::vector
                                                                const std::vector<size_t> &tensor_indices = {});
 
 struct ReducerOptions {
+    // Pack all Reducer-related args together
+    // Ref: https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
+
     // Max capacity for each bucket(in MB)
     size_t first_bucket_cap_mb = 128;
     size_t normal_bucket_cap_mb = 512;
 
     // When set true, map param.grad directly to the slice of bucket.flat(same address in memory) instead of memcpy
-    bool gradient_as_bucket_view = false;
+    bool gradient_as_bucket_view = true;
 };
 
 // DDP Reducer that handles gradient bucketing in backward
@@ -50,7 +53,9 @@ class Reducer : public std::enable_shared_from_this<Reducer> {
     */
     explicit Reducer(std::vector<std::shared_ptr<Tensor>> parameters, std::vector<std::vector<size_t>> bucket_indices,
                      const ReducerOptions &opts);
-    ~Reducer();
+
+    // Attach PostAllReduceHooks to params
+    void AttachHooksToParameters();
 
     // Prepare bucket info for next step
     void PrepareForBackward();
@@ -91,7 +96,7 @@ class Reducer : public std::enable_shared_from_this<Reducer> {
 
         // Views into the `gradients` tensor for each individual gradient
         std::vector<std::shared_ptr<Tensor>> bucket_views_in;
-        // NOTE(zbl): reserved for occasions where grads have different stride/layout
+        // TODO(zbl): reserved for occasions where grads have different stride/layout
        std::vector<std::shared_ptr<Tensor>> bucket_views_out;
 
         // Number of gradients left to be computed before the bucket is ready to be reduced
@@ -104,18 +109,10 @@ class Reducer : public std::enable_shared_from_this<Reducer> {
         // If `true`, then this implies that `bucket.variables.size() == 1`.
         // TODO(zbl): support logics for sparse gradient later
         bool expect_sparse_gradient = false;
-
-#ifdef USE_CUDA
-        // Event to mark that AllReduce is completed
-        cudaEvent_t allreduce_done = nullptr;
-        // Event to mark that all tensors' grad in bucket are ready
-        cudaEvent_t bucket_ready = nullptr;
-#endif
     };
 
 private:
     void InitializeBuckets(const std::vector<std::vector<size_t>> &bucket_indices);
-    void AttachHooksToParameters();
 
     // NOTE(zbl): all grads are assumed dense and stored continously in bucket for now
     void MarkVariableReadyDense(size_t variable_index);
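
With this commit `gradient_as_bucket_view` defaults to `true`, so each `param.grad` aliases its slice of the bucket's flat buffer and the post-all-reduce copy-back is skipped. A hedged configuration sketch for opting back out, assuming DDP is constructed as in `distributed_data_parallel.cc` below; `model` is an illustrative name:

    // Sketch only: override the new default.
    ReducerOptions opts;
    opts.first_bucket_cap_mb = 64;         // smaller first bucket starts communication earlier
    opts.normal_bucket_cap_mb = 256;
    opts.gradient_as_bucket_view = false;  // grads keep their own storage; bucket slices are copied back
    auto ddp = std::make_shared<DistributedDataParallel>(model, /*device_id=*/0, opts);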

infini_train/include/nn/parallel/work.h

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+#pragma once
+
+#include <atomic>
+#include <chrono>
+#include <exception>
+#include <memory>
+#include <mutex>
+
+#ifdef USE_CUDA
+#include <cuda_runtime.h>
+#endif
+#ifdef USE_NCCL
+#include <nccl.h>
+#endif
+
+#include "infini_train/include/device.h"
+
+namespace infini_train::nn::parallel {
+
+class Work {
+public:
+    virtual ~Work() = default;
+
+    virtual bool Wait(std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) = 0;
+
+    virtual bool IsCompleted() const = 0;
+    virtual bool IsSuccess() const = 0;
+
+    virtual void Synchronize() const = 0;
+
+    virtual std::exception_ptr exception() const = 0;
+
+    virtual void *ready_event() const = 0;
+    virtual void *done_event() const = 0;
+};
+
+#ifdef USE_NCCL
+class WorkNccl final : public Work {
+public:
+    WorkNccl(const Device *device, ncclComm_t comm);
+    ~WorkNccl() override;
+
+    bool Wait(std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) override;
+
+    bool IsCompleted() const override;
+    bool IsSuccess() const override;
+
+    void Synchronize() const override;
+
+    std::exception_ptr exception() const override { return exception_; };
+
+    void *ready_event() const override { return reinterpret_cast<void *>(ready_event_); };
+    void *done_event() const override { return reinterpret_cast<void *>(done_event_); };
+
+private:
+    bool CheckNcclStatus();
+    void SetException(std::exception_ptr e);
+
+private:
+    Device *device_ = nullptr;
+    cudaEvent_t ready_event_;
+    cudaEvent_t done_event_;
+    ncclComm_t comm_;
+
+    mutable std::mutex mutex_;
+    std::exception_ptr exception_;
+    std::atomic<bool> completed_{false};
+    std::atomic<bool> success_{false};
+};
+#endif
+
+} // namespace infini_train::nn::parallel
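
The corresponding `work.cc` is not part of this excerpt. A minimal sketch of how `IsCompleted`/`Wait` could be built on the `done_event_` recorded by `AllReduceAsync`, assuming standard CUDA event semantics (`cudaEventQuery` returns `cudaSuccess` once the recorded work has finished, `cudaErrorNotReady` otherwise); error propagation through `exception_`/`success_` is omitted:

    // Sketch only; not the implementation shipped in this commit.
    bool WorkNccl::IsCompleted() const {
        if (completed_.load()) {
            return true;
        }
        return cudaEventQuery(done_event_) == cudaSuccess;
    }

    bool WorkNccl::Wait(std::chrono::milliseconds timeout) {
        if (timeout == std::chrono::milliseconds::zero()) {
            // Block the calling host thread until the comm stream passes done_event_.
            return cudaEventSynchronize(done_event_) == cudaSuccess;
        }
        const auto deadline = std::chrono::steady_clock::now() + timeout;
        while (!IsCompleted()) {
            if (std::chrono::steady_clock::now() >= deadline) {
                return false; // timed out
            }
            std::this_thread::yield(); // requires <thread>
        }
        return true;
    }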

infini_train/include/tensor.h

Lines changed: 3 additions & 0 deletions
@@ -82,6 +82,9 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
     Tensor To(const Device *device);
     Tensor To(DataType dtype);
 
+    void CopyFrom(const Tensor &src);
+    void CopyFrom(const std::shared_ptr<Tensor> &src);
+
     // operator overloading
     std::shared_ptr<Tensor> Equals(const std::shared_ptr<Tensor> &other);
     std::shared_ptr<Tensor> Equals(float scalar);

infini_train/src/autograd/accumulate.cc

Lines changed: 5 additions & 2 deletions
@@ -27,13 +27,16 @@ AccumulateGrad::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_output
     if (grad_output) {
         if (grad) {
             if (tensor_->ConsumeGradOverwriteFlag()) {
-                auto new_grad = std::make_shared<Tensor>(*grad_output.get(), 0, grad_output->Dims());
-                tensor_->set_grad(std::move(new_grad));
+                // If the tensor is marked to overwrite its current grad on next grad update
+                // See notes in `infini_train::nn::parallel::Reducer::PrepareForBackward()`
+                // NOTE(zbl): must copy, cannot change grad buffer address
+                grad->CopyFrom(grad_output);
             } else {
                 auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "AccumulateGrad"});
                 kernel.Call<void>(grad_output, learning_rate_, grad);
             }
         } else {
+            // NOTE(zbl): check whether need to do copying instead of slicing
             auto new_grad = std::make_shared<Tensor>(*grad_output.get(), 0, grad_output->Dims());
             tensor_->set_grad(std::move(new_grad));
         }
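
The "must copy, cannot change grad buffer address" note is the crux of the fix: with `gradient_as_bucket_view`, `grad` aliases a slice of the bucket's flat buffer, so rebinding the `shared_ptr` would silently detach the parameter from the bucket. A self-contained illustration with raw buffers (no project types; the real `CopyFrom` presumably dispatches a device copy):

    #include <cassert>
    #include <cstring>
    #include <vector>

    int main() {
        std::vector<float> bucket_flat(8, 0.0f);    // stand-in for bucket.contents
        float *grad_view = bucket_flat.data() + 4;  // stand-in for one bucket_views_in slice
        float incoming[4] = {1.f, 2.f, 3.f, 4.f};   // stand-in for grad_output

        // CopyFrom-style update: write through the aliased storage; the bucket sees it.
        std::memcpy(grad_view, incoming, sizeof(incoming));
        assert(bucket_flat[4] == 1.f);

        // A set_grad(new_tensor)-style update would instead rebind the parameter to a
        // fresh buffer, leaving bucket_flat[4..7] stale for the upcoming all-reduce.
        return 0;
    }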

infini_train/src/nn/parallel/distributed_data_parallel.cc

Lines changed: 3 additions & 3 deletions
@@ -20,8 +20,7 @@ constexpr char kModuleName[] = "module";
 DistributedDataParallel::DistributedDataParallel(std::shared_ptr<nn::Module> module, int device_id,
                                                  const ReducerOptions &opts) {
     for (auto &param : module->Parameters()) {
-        auto device = param->GetDevice();
-        CHECK_EQ(device->Index(), device_id) << "All parameters must be on the same device as the module";
+        CHECK_EQ(param->GetDevice()->Index(), device_id) << "All parameters must be on the same device as the module";
     }
     for (auto &buffer : module->Buffers()) {
         CHECK_EQ(buffer->GetDevice()->Index(), device_id) << "All buffers must be on the same device as the module";
@@ -35,7 +34,8 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr<nn::Module> mod
     std::vector<size_t> bucket_size_limits = {first_cap_bytes, normal_cap_bytes};
     auto bucket_indices = ComputeBucketAssignmentBySize(params, bucket_size_limits);
 
-    reducer_ = std::make_shared<Reducer>(std::move(params), bucket_indices, opts);
+    reducer_ = std::make_shared<Reducer>(params, bucket_indices, opts);
+    reducer_->AttachHooksToParameters();
 }
 
 std::vector<std::shared_ptr<Tensor>>
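
Moving `AttachHooksToParameters()` out of the `Reducer` constructor and calling it here is likely motivated by `Reducer` deriving from `std::enable_shared_from_this` (see `reducer.h` above): `shared_from_this()` is only valid once the object is already owned by a `shared_ptr`, so hooks that capture it cannot be attached from inside the constructor. A self-contained illustration of the two-phase pattern:

    #include <iostream>
    #include <memory>

    struct Hooked : std::enable_shared_from_this<Hooked> {
        void AttachHooks() {
            // Safe here; calling shared_from_this() inside the constructor throws std::bad_weak_ptr.
            auto self = shared_from_this();
            std::cout << "hook captured owner, use_count=" << self.use_count() << "\n";
        }
    };

    int main() {
        auto obj = std::make_shared<Hooked>();
        obj->AttachHooks(); // mirrors reducer_->AttachHooksToParameters() above
        return 0;
    }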

infini_train/src/nn/parallel/process_group.cc

Lines changed: 18 additions & 20 deletions
@@ -335,36 +335,34 @@ std::vector<std::shared_ptr<Tensor>> ProcessGroup::NcclRecv(std::vector<std::sha
     return tensors;
 }
 
-void ProcessGroup::EnqueueAllReduce(cudaEvent_t ready_event, cudaEvent_t done_event,
-                                    const std::shared_ptr<Tensor> &tensor, function::ReduceOpType reduce_op) const {
-    CHECK(ready_event && done_event) << "Events must be created.";
+std::shared_ptr<Work> ProcessGroup::AllReduceAsync(const std::shared_ptr<Tensor> &tensor,
+                                                   function::ReduceOpType reduce_op) const {
+    void *buffer = tensor->DataPtr();
     const auto *device = dynamic_cast<const CudaDevice *>(tensor->GetDevice());
-    CHECK(std::find(devices_.begin(), devices_.end(), device) != devices_.end())
-        << "Device of target Tensor is not in current ProcessGroup";
+    device->SetDevice();
+
+    auto comm = device_comm_map_.at(device);
 
     cudaStream_t compute_stream = device->Stream();
     cudaStream_t comm_stream = device_stream_map_.at(device);
 
-    cudaEventRecord(ready_event, compute_stream);
-    cudaStreamWaitEvent(comm_stream, ready_event, 0);
+    auto work = std::make_shared<WorkNccl>(device, comm);
+
+    cudaEvent_t ready_event = reinterpret_cast<cudaEvent_t>(work->ready_event());
+    cudaEvent_t done_event = reinterpret_cast<cudaEvent_t>(work->done_event());
+
+    CUDA_CHECK(cudaEventRecord(ready_event, compute_stream));
+    CUDA_CHECK(cudaStreamWaitEvent(comm_stream, ready_event, 0));
 
     // Perform NcclAllReduce on comm stream
-    device->SetDevice();
-    NCCL_CHECK(ncclAllReduce(tensor->DataPtr(), tensor->DataPtr(), tensor->NumElements(),
-                             kNcclDtypeMap.at(tensor->Dtype()), kNcclReduceOpMap.at(reduce_op),
-                             device_comm_map_.at(device), comm_stream));
+    NCCL_CHECK(ncclAllReduce(buffer, buffer, tensor->NumElements(), kNcclDtypeMap.at(tensor->Dtype()),
+                             kNcclReduceOpMap.at(reduce_op), comm, comm_stream));
 
-    cudaEventRecord(done_event, comm_stream);
-}
+    CUDA_CHECK(cudaEventRecord(done_event, comm_stream));
+    CUDA_CHECK(cudaStreamWaitEvent(compute_stream, done_event, 0));
 
-void ProcessGroup::WaitAllReduceDone(cudaEvent_t done_event, const std::shared_ptr<Tensor> &tensor) const {
-    CHECK(done_event) << "Events must be created.";
-    const auto *device = dynamic_cast<const CudaDevice *>(tensor->GetDevice());
-    CHECK(std::find(devices_.begin(), devices_.end(), device) != devices_.end())
-        << "Device of target Tensor is not in current ProcessGroup";
-    cudaStreamWaitEvent(device->Stream(), done_event, 0);
+    return std::move(work);
 }
-
 #endif
 
 ProcessGroupFactory *ProcessGroupFactory::Instance() {
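
Besides returning a `Work` handle, `AllReduceAsync` changes the stream ordering: the compute stream now waits on `done_event` at enqueue time, so the separate `WaitAllReduceDone` step disappears and later kernels on the compute stream are ordered after the collective while the host stays free (host-side blocking is still available via `Work::Wait`). A self-contained sketch of the event pattern using only the plain CUDA runtime API (no NCCL):

    #include <cuda_runtime.h>

    // Order a communication stream after the data produced on the compute stream,
    // and order later compute after the communication, with no host-side sync.
    void OrderCommBetweenCompute(cudaStream_t compute_stream, cudaStream_t comm_stream,
                                 cudaEvent_t ready_event, cudaEvent_t done_event) {
        cudaEventRecord(ready_event, compute_stream);       // gradients are ready here
        cudaStreamWaitEvent(comm_stream, ready_event, 0);   // collective starts only after that

        // ... enqueue the collective on comm_stream (ncclAllReduce in the real code) ...

        cudaEventRecord(done_event, comm_stream);           // marks the collective's completion
        cudaStreamWaitEvent(compute_stream, done_event, 0); // later compute waits on it
    }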

infini_train/src/nn/parallel/reducer.cc

Lines changed: 18 additions & 63 deletions
@@ -16,6 +16,7 @@
 #include "infini_train/include/common/cuda/common_cuda.h"
 #include "infini_train/include/device.h"
 #include "infini_train/include/nn/parallel/utils.h"
+#include "infini_train/include/nn/parallel/work.h"
 
 namespace infini_train::nn::parallel {
 namespace {
@@ -178,43 +179,9 @@ Reducer::Reducer(std::vector<std::shared_ptr<Tensor>> parameters, std::vector<st
     : params_(std::move(parameters)), opts_(opts) {
     BuildBuckets(bucket_indices);
     ready_seen_this_iter_.assign(params_.size(), 0);
-    AttachHooksToParameters();
-}
-
-Reducer::~Reducer() {
-#ifdef USE_CUDA
-    for (auto &b : buckets_) {
-        if (!b.contents) {
-            continue;
-        }
-        if (b.contents->GetDevice()->Type() == DeviceType::kCUDA) {
-            if (b.allreduce_done) {
-                CUDA_CHECK(cudaEventDestroy(b.allreduce_done));
-            }
-            if (b.bucket_ready) {
-                CUDA_CHECK(cudaEventDestroy(b.bucket_ready));
-            }
-        }
-    }
-#endif
 }
 
 void Reducer::InitializeBuckets(const std::vector<std::vector<size_t>> &bucket_indices) {
-#ifdef USE_CUDA
-    for (auto &b : buckets_) {
-        if (!b.contents) {
-            continue;
-        }
-        if (b.contents->GetDevice()->Type() == DeviceType::kCUDA) {
-            if (b.allreduce_done) {
-                CUDA_CHECK(cudaEventDestroy(b.allreduce_done));
-            }
-            if (b.bucket_ready) {
-                CUDA_CHECK(cudaEventDestroy(b.bucket_ready));
-            }
-        }
-    }
-#endif
     buckets_.clear();
     locators_.clear();
     next_bucket_ = 0;
@@ -235,16 +202,6 @@ void Reducer::InitializeBucketViews(Bucket &bucket) {
     }
     // Set (out == in) by default when all grads are dense
     bucket.bucket_views_out = bucket.bucket_views_in;
-
-    if (opts_.gradient_as_bucket_view) {
-        for (size_t i = 0; i < bucket.variables.size(); ++i) {
-            auto &v = bucket.variables[i];
-            auto g = v->grad();
-            if (g && g.get() != bucket.bucket_views_in[i].get()) {
-                v->set_grad(bucket.bucket_views_in[i]);
-            }
-        }
-    }
 }
 
 void Reducer::BuildBuckets(const std::vector<std::vector<size_t>> &bucket_indices) {
@@ -280,16 +237,8 @@ void Reducer::BuildBuckets(const std::vector<std::vector<size_t>> &bucket_indice
     auto dev = bucket.variables.front()->GetDevice();
     bucket.contents
         = std::make_shared<Tensor>(std::vector<int64_t>{static_cast<int64_t>(total_elems)}, bucket.dtype, dev);
-    // bucket.contents->Fill(0);
     bucket.pending = bucket.variables.size();
 
-#ifdef USE_CUDA
-    if (bucket.contents->GetDevice()->Type() == DeviceType::kCUDA) {
-        CUDA_CHECK(cudaEventCreateWithFlags(&bucket.allreduce_done, cudaEventDisableTiming));
-        CUDA_CHECK(cudaEventCreateWithFlags(&bucket.bucket_ready, cudaEventDisableTiming));
-    }
-#endif
-
     bucket.variable_indices = bucket_indices[bucket_idx];
     InitializeBucketViews(bucket);
     buckets_.push_back(std::move(bucket));
@@ -368,11 +317,18 @@ void Reducer::PrepareForBackward() {
         auto view = bucket.bucket_views_in[i];
         auto grad = param->grad();
 
-        if (grad == nullptr) {
-            param->MarkGradOverwriteOnNextAccum();
+        // NOTE(zbl): This will affect behaviors in `infini_train::autograd::AccumulateGrad::Backward()`
+        // If ZeroGrad(set_to_none=True), grad is nullptr at this point
+        // If ZeroGrad(set_to_none=False), grad is set to view of bucket.contents (or modified by user)
+        // Either way, we reset grad to view of bucket.contents
+        // Since bucket.contents might not be zeroed, we need to overwrite it on next grad accumulation
+        if (!grad || (grad.get() != view.get())) {
+            if (grad) {
+                LOG(WARNING) << "gradient_as_bucket_view is enabled, but param " << param
+                             << " has a non-view grad tensor. Automatically overwriting it with bucket view.";
+            }
             param->set_grad(view);
-        } else {
-            CHECK_EQ(grad.get(), view.get()) << "Param's gradient should be a slice of bucket's flat buffer.";
+            param->MarkGradOverwriteOnNextAccum();
         }
     }
 }
@@ -456,25 +412,24 @@ void Reducer::FinalizeBucketDense(size_t bucket_index) {
     auto &bucket = buckets_.at(bucket_index);
     auto ddp_pg = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(bucket.device_rank));
 
+    std::shared_ptr<Work> work;
     if (comm_hook_) {
         std::vector<std::shared_ptr<Tensor>> bucket_view{bucket.contents};
         // NOTE(zbl): Custom hook should do in-place operations
         // e.g. comm_hook_(GradBucket{bucket_view})[0];
         // FIXME(zbl): support custom hook later
         LOG(FATAL) << "Custom hook is not supported now";
     } else {
-        ddp_pg->EnqueueAllReduce(bucket.bucket_ready, bucket.allreduce_done, bucket.contents,
-                                 function::ReduceOpType::kAvg);
+        work = ddp_pg->AllReduceAsync(bucket.contents, function::ReduceOpType::kAvg);
    }
 
     if (!opts_.gradient_as_bucket_view) {
         for (size_t i = 0; i < bucket.variables.size(); ++i) {
-            // Directly assgin bucket slice to grad instead of copying
-            // Same behavior as `CopyBucketToGrad(bucket.contents, bucket.variables[i]->grad(), bucket.offsets[i]);`
-            bucket.variables[i]->set_grad(bucket.bucket_views_in[i]);
+            // NOTE(zbl): For better performance, try `bucket.variables[i]->set_grad(bucket.bucket_views_in[i]);`
+            // to directly assign bucket slice to grad instead of copying
+            CopyBucketToGrad(bucket.contents, bucket.variables[i]->grad(), bucket.offsets[i]);
+            // bucket.variables[i]->set_grad(bucket.bucket_views_in[i]);
         }
     }
-
-    ddp_pg->WaitAllReduceDone(bucket.allreduce_done, bucket.contents);
 }
 } // namespace infini_train::nn::parallel
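
With `gradient_as_bucket_view` disabled, `FinalizeBucketDense` now copies each slice of the reduced flat buffer back into the parameter's own grad storage (`CopyBucketToGrad`) instead of re-aliasing grads into the bucket. A self-contained sketch of what such an offset-based copy does conceptually; the real helper works on `Tensor`s and presumably dispatches a device kernel rather than a host memcpy:

    #include <cstddef>
    #include <cstring>

    // Copy one parameter's slice out of the flat, already-all-reduced bucket buffer
    // into that parameter's own gradient storage (offset and count in elements).
    void CopyBucketSliceToGrad(const float *bucket_flat, float *grad,
                               std::size_t offset, std::size_t num_elements) {
        std::memcpy(grad, bucket_flat + offset, num_elements * sizeof(float));
    }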
