From b3c05124cf441f58763e071c8ee0f64d0f22cbbf Mon Sep 17 00:00:00 2001 From: bolunz Date: Wed, 12 Nov 2025 14:20:21 +0800 Subject: [PATCH 1/5] feat: add DDP gradient bucketing, support compute/comm dual stream --- .../nn/parallel/distributed_data_parallel.h | 7 +- .../include/nn/parallel/process_group.h | 10 + infini_train/include/nn/parallel/reducer.h | 152 ++++++ infini_train/include/tensor.h | 4 + infini_train/src/autograd/accumulate.cc | 9 +- .../nn/parallel/distributed_data_parallel.cc | 21 +- infini_train/src/nn/parallel/process_group.cc | 57 +++ infini_train/src/nn/parallel/reducer.cc | 480 ++++++++++++++++++ infini_train/src/tensor.cc | 8 + 9 files changed, 738 insertions(+), 10 deletions(-) create mode 100644 infini_train/include/nn/parallel/reducer.h create mode 100644 infini_train/src/nn/parallel/reducer.cc diff --git a/infini_train/include/nn/parallel/distributed_data_parallel.h b/infini_train/include/nn/parallel/distributed_data_parallel.h index 5809dc00..df214156 100644 --- a/infini_train/include/nn/parallel/distributed_data_parallel.h +++ b/infini_train/include/nn/parallel/distributed_data_parallel.h @@ -3,6 +3,7 @@ #include #include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/reducer.h" namespace infini_train { class Tensor; @@ -13,9 +14,13 @@ namespace infini_train::nn::parallel { class DistributedDataParallel : public nn::Module { public: - DistributedDataParallel(std::shared_ptr module, int device_id); + DistributedDataParallel(std::shared_ptr module, int device_id, + const ReducerOptions &opts = ReducerOptions{}); std::vector> Forward(const std::vector> &input_tensors) override; + +private: + std::shared_ptr reducer_; }; } // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index 5919e054..3bcf9c2a 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -28,8 +28,11 @@ class ProcessGroup { public: explicit ProcessGroup(const std::vector &device_indices); + ~ProcessGroup(); + int GetGroupRank(int thread_rank) const; + // Communication operations void AllReduce(const std::shared_ptr &tensor, function::ReduceOpType reduce_op) const; void AllGather(const std::shared_ptr &output, const std::shared_ptr &input) const; @@ -52,11 +55,18 @@ class ProcessGroup { std::vector> NcclRecv(std::vector> tensors, int src_rank) const; + // Overlap helper functions + void EnqueueAllReduce(cudaEvent_t ready_event, cudaEvent_t done_event, const std::shared_ptr &tensor, + function::ReduceOpType reduce_op) const; + void WaitAllReduceDone(cudaEvent_t done_event, const std::shared_ptr &tensor) const; + private: std::vector comms_; + std::vector comm_streams_; std::vector devices_; std::unordered_map device_comm_map_; + std::unordered_map device_stream_map_; std::unordered_map thread_group_rank_map_; // thread_rank : group_rank int comm_size_ = 0; diff --git a/infini_train/include/nn/parallel/reducer.h b/infini_train/include/nn/parallel/reducer.h new file mode 100644 index 00000000..4685094a --- /dev/null +++ b/infini_train/include/nn/parallel/reducer.h @@ -0,0 +1,152 @@ +#pragma once + +#include +#include +#include + +#include "infini_train/include/autograd/function_hook.h" +#include "infini_train/include/nn/parallel/parallel_functional.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::parallel { + +// GradBucket passes bucket contents tensor to DDP communication 
hook. +// ref: https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/comm.hpp +class GradBucket { +public: + explicit GradBucket(const std::vector> &tensors) : tensors_(tensors) {} + const std::vector> &getTensors() const { return tensors_; } + +private: + std::vector> tensors_; +}; + +// Compute bucket assignment according to the size of each tensors and bucket capacity. +// Returns the indices of tensors in the corrsponding bucket, i.e. output[bucket_i] = {tensor_j, tensor_k, ...} +// The index of tensors[idx] assigned to bucket(j and k above) is tensor_indices[idx]. +// When tensor_indices is empty, the index of tensors[idx] assigned to bucket(j and k above) is idx itself. +std::vector> ComputeBucketAssignmentBySize(const std::vector> &tensors, + const std::vector &bucket_size_limits, + const std::vector &tensor_indices = {}); + +struct ReducerOptions { + // Max capacity for each bucket(in MB) + size_t first_bucket_cap_mb = 128; + size_t normal_bucket_cap_mb = 512; + + // When set true, map param.grad directly to the slice of bucket.flat(same address in memory) instead of memcpy + bool gradient_as_bucket_view = false; +}; + +// DDP Reducer that handles gradient bucketing in backward +// ref: https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/reducer.hpp +class Reducer : public std::enable_shared_from_this { +public: + /** @brief Constructor of Reducer + * + * @param parameters A list of parameters for this process's single model replica + * @param bucket_indices The bucket assignment for this reducer + * @param opts Other options, see definition of ReducerOptions + */ + explicit Reducer(std::vector> parameters, std::vector> bucket_indices, + const ReducerOptions &opts); + ~Reducer(); + + // Prepare bucket info for next step + void PrepareForBackward(); + + // For custom DDP hook to overwrite the default AllReduce. T + // This can be used for algorithms like Gradient Compression/GossipGrad. + // Hook is registered using `Reducer::RegisterCommHook()`. 
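For reference, a minimal sketch of how this bucket-assignment API is meant to be driven (a hedged example: the model and the capacity values are placeholders; ReducerOptions, ComputeBucketAssignmentBySize and Reducer come from this header, and nn::Module with its Parameters() accessor from the existing module.h):

#include <memory>
#include <vector>

#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/reducer.h"

using namespace infini_train::nn::parallel;

std::shared_ptr<Reducer> BuildReducerFor(const std::shared_ptr<infini_train::nn::Module> &model) {
    ReducerOptions opts;
    opts.first_bucket_cap_mb = 128;   // keep the first bucket small so its all-reduce launches early
    opts.normal_bucket_cap_mb = 512;  // larger steady-state buckets amortize per-collective overhead

    auto params = model->Parameters();
    // Capacities are passed in bytes; a bucket never mixes devices or dtypes.
    std::vector<size_t> limits = {opts.first_bucket_cap_mb * 1024ULL * 1024ULL,
                                  opts.normal_bucket_cap_mb * 1024ULL * 1024ULL};
    // bucket_indices[i] holds the indices (into `params`) of the tensors placed in bucket i,
    // walked in reverse parameter order by default to approximate backward readiness order.
    auto bucket_indices = ComputeBucketAssignmentBySize(params, limits);

    // The Reducer flattens each bucket into one contiguous buffer and registers
    // post-accumulate-grad hooks on every parameter.
    return std::make_shared<Reducer>(params, bucket_indices, opts);
}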
+ // TODO(zbl): Leave the placeholder for the moment + void RegisterCommHook(std::shared_ptr hook); + + // Return every tensor in bucket's flat buffer + std::vector>> GetBucketTensors() const; + +private: + // A variable locator locates a particular variable in the reducer's buckets + struct VariableLocator { + // Index of the bucket containing the variable in the `buckets_` vector + size_t bucket_index = 0; + // Index of the variable in the bucket + size_t intra_bucket_index = 0; + }; + + // Bucket used in DDP backward + struct Bucket { + // Gradients of the bucket flattened into a 1-dimensional tensor + std::shared_ptr contents; + DataType dtype; + int device_rank = 0; + + // Variables whose gradients are held in this bucket + std::vector> variables; + + // Per-variable offset/length into the flattened `gradients` tensor and + // the corresponding `GradBucket` instance for communication hooks + // In terms of element count, not bytes + std::vector offsets; + std::vector lengths; + + // Views into the `gradients` tensor for each individual gradient + std::vector> bucket_views_in; + // NOTE(zbl): reserved for occasions where grads have different stride/layout + std::vector> bucket_views_out; + + // Number of gradients left to be computed before the bucket is ready to be reduced + size_t pending; + + // Global indices of participating variables in the bucket + std::vector variable_indices; + + // If this bucket should expect a single sparse gradient + // If `true`, then this implies that `bucket.variables.size() == 1`. + // TODO(zbl): support logics for sparse gradient later + bool expect_sparse_gradient = false; + +#ifdef USE_CUDA + // Event to mark that AllReduce is completed + cudaEvent_t allreduce_done = nullptr; + // Event to mark that all tensors' grad in bucket are ready + cudaEvent_t bucket_ready = nullptr; +#endif + }; + +private: + void InitializeBuckets(const std::vector> &bucket_indices); + void AttachHooksToParameters(); + + // NOTE(zbl): all grads are assumed dense and stored continously in bucket for now + void MarkVariableReadyDense(size_t variable_index); + void MarkBucketReady(size_t bucket_index); + void FinalizeBucketDense(size_t bucket_index); + + void BuildBuckets(const std::vector> &bucket_indices); + void InitializeBucketViews(Bucket &bucket); + void RebuildBuckets(); + +private: + mutable std::mutex mutex_; + std::vector> params_; + std::vector buckets_; + std::vector locators_; + + std::atomic buckets_finished_{0}; + std::shared_ptr comm_hook_ = nullptr; + ReducerOptions opts_; + + // Next bucket to be reduced + // This is to make sure that all-reduce of buckets be launched in the order we expect + size_t next_bucket_ = 0; + // To record the order of params getting ready on first step + std::vector grad_ready_order_indices_; + // To record whether each param is ready on first step + std::vector ready_seen_this_iter_; + // Whether to rebuild buckets on next train step + bool need_rebuild_ = false; + // Whether to buckets have already been rebuilt on the second step + bool has_rebuilt_bucket_ = false; +}; + +} // namespace infini_train::nn::parallel diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index 9e3fbdb4..4932bce4 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -208,6 +208,8 @@ class Tensor : public std::enable_shared_from_this { void set_output_idx(int output_idx); void ZeroGrad(bool set_to_none = true); + void MarkGradOverwriteOnNextAccum(); + bool ConsumeGradOverwriteFlag(); void 
Backward(std::shared_ptr gradient = nullptr, bool retain_graph = false, bool create_graph = false) const; @@ -229,6 +231,8 @@ class Tensor : public std::enable_shared_from_this { // a strong reference to the accumulator to manage its lifetime. std::shared_ptr grad_accumulator_ = nullptr; std::shared_ptr post_accumulate_grad_hook_ = nullptr; + + bool grad_overwrite_once_ = false; }; std::shared_ptr operator==(const std::shared_ptr &t, float scalar); diff --git a/infini_train/src/autograd/accumulate.cc b/infini_train/src/autograd/accumulate.cc index a635cb38..4e3204f6 100644 --- a/infini_train/src/autograd/accumulate.cc +++ b/infini_train/src/autograd/accumulate.cc @@ -26,8 +26,13 @@ AccumulateGrad::Backward(const std::vector> &grad_output if (grad_output) { if (grad) { - auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "AccumulateGrad"}); - kernel.Call(grad_output, learning_rate_, grad); + if (tensor_->ConsumeGradOverwriteFlag()) { + auto new_grad = std::make_shared(*grad_output.get(), 0, grad_output->Dims()); + tensor_->set_grad(std::move(new_grad)); + } else { + auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "AccumulateGrad"}); + kernel.Call(grad_output, learning_rate_, grad); + } } else { auto new_grad = std::make_shared(*grad_output.get(), 0, grad_output->Dims()); tensor_->set_grad(std::move(new_grad)); diff --git a/infini_train/src/nn/parallel/distributed_data_parallel.cc b/infini_train/src/nn/parallel/distributed_data_parallel.cc index 26ca62c5..c6212770 100644 --- a/infini_train/src/nn/parallel/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/distributed_data_parallel.cc @@ -17,25 +17,32 @@ namespace { constexpr char kModuleName[] = "module"; } // namespace -DistributedDataParallel::DistributedDataParallel(std::shared_ptr module, int device_id) { +DistributedDataParallel::DistributedDataParallel(std::shared_ptr module, int device_id, + const ReducerOptions &opts) { for (auto ¶m : module->Parameters()) { auto device = param->GetDevice(); CHECK_EQ(device->Index(), device_id) << "All parameters must be on the same device as the module"; - - auto ddp_pg - = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(device->rank().thread_rank())); - auto hook = std::make_unique(function::ReduceOpType::kAvg, - ddp_pg); - param->RegisterPostAccumulateGradHook(std::move(hook)); } for (auto &buffer : module->Buffers()) { CHECK_EQ(buffer->GetDevice()->Index(), device_id) << "All buffers must be on the same device as the module"; } modules_[kModuleName] = std::move(module); + + // Bucket Assignment + auto params = modules_[kModuleName]->Parameters(); + const size_t first_cap_bytes = opts.first_bucket_cap_mb * 1024ULL * 1024ULL; + const size_t normal_cap_bytes = opts.normal_bucket_cap_mb * 1024ULL * 1024ULL; + std::vector bucket_size_limits = {first_cap_bytes, normal_cap_bytes}; + auto bucket_indices = ComputeBucketAssignmentBySize(params, bucket_size_limits); + + reducer_ = std::make_shared(std::move(params), bucket_indices, opts); } std::vector> DistributedDataParallel::Forward(const std::vector> &input_tensors) { + if (reducer_) { + reducer_->PrepareForBackward(); + } return modules_[kModuleName]->Forward(input_tensors); } diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index 3271331c..64de2df1 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -1,5 +1,6 @@ #include "infini_train/include/nn/parallel/process_group.h" 
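To make the new constructor and PrepareForBackward() wiring concrete, here is a hedged sketch of one DDP training step; the model, input and device id are placeholders, and only DistributedDataParallel, ReducerOptions, Tensor and their headers are taken from this patch:

#include <memory>

#include "infini_train/include/nn/parallel/distributed_data_parallel.h"
#include "infini_train/include/tensor.h"

using namespace infini_train::nn::parallel;

std::shared_ptr<infini_train::Tensor> DdpStep(std::shared_ptr<infini_train::nn::Module> model,
                                              const std::shared_ptr<infini_train::Tensor> &input,
                                              int device_id) {
    ReducerOptions opts;
    opts.gradient_as_bucket_view = false;  // grads are memcpy'd into the flat bucket buffer

    // Wrapping the module builds the bucket assignment and attaches the per-parameter hooks
    // that copy each ready grad into its bucket and launch the bucket's all-reduce.
    auto ddp = std::make_shared<DistributedDataParallel>(std::move(model), device_id, opts);

    // Forward() first calls Reducer::PrepareForBackward() to reset per-bucket bookkeeping,
    // then delegates to the wrapped module.
    auto output = ddp->Forward({input})[0];

    // output->Backward() would then trigger the bucketed, averaged all-reduces as the
    // gradients become ready, overlapping them with the remaining backward compute.
    return output;
}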
+#include #include #include @@ -44,11 +45,36 @@ ProcessGroup::ProcessGroup(const std::vector &device_indices) : comm_size_( comms_.resize(comm_size_); NCCL_CHECK(ncclCommInitAll(comms_.data(), comm_size_, device_indices.data())); + comm_streams_.resize(comm_size_); + int current_device = -1; + CUDA_CHECK(cudaGetDevice(¤t_device)); + for (int i = 0; i < comm_size_; ++i) { auto device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, device_indices[i]); devices_.push_back(device); device_comm_map_[device] = comms_[i]; thread_group_rank_map_[device->rank().thread_rank()] = i; + + device->SetDevice(); + int low, high; + cudaDeviceGetStreamPriorityRange(&low, &high); + cudaStreamCreateWithPriority(&comm_streams_[i], cudaStreamNonBlocking, high); + device_stream_map_[device] = comm_streams_[i]; + } + + CUDA_CHECK(cudaSetDevice(current_device)); +} + +ProcessGroup::~ProcessGroup() { + for (auto &s : comm_streams_) { + if (s) { + cudaStreamDestroy(s); + } + } + for (auto &c : comms_) { + if (c) { + ncclCommDestroy(c); + } } } @@ -308,6 +334,37 @@ std::vector> ProcessGroup::NcclRecv(std::vector &tensor, function::ReduceOpType reduce_op) const { + CHECK(ready_event && done_event) << "Events must be created."; + const auto *device = dynamic_cast(tensor->GetDevice()); + CHECK(std::find(devices_.begin(), devices_.end(), device) != devices_.end()) + << "Device of target Tensor is not in current ProcessGroup"; + + cudaStream_t compute_stream = device->Stream(); + cudaStream_t comm_stream = device_stream_map_.at(device); + + cudaEventRecord(ready_event, compute_stream); + cudaStreamWaitEvent(comm_stream, ready_event, 0); + + // Perform NcclAllReduce on comm stream + device->SetDevice(); + NCCL_CHECK(ncclAllReduce(tensor->DataPtr(), tensor->DataPtr(), tensor->NumElements(), + kNcclDtypeMap.at(tensor->Dtype()), kNcclReduceOpMap.at(reduce_op), + device_comm_map_.at(device), comm_stream)); + + cudaEventRecord(done_event, comm_stream); +} + +void ProcessGroup::WaitAllReduceDone(cudaEvent_t done_event, const std::shared_ptr &tensor) const { + CHECK(done_event) << "Events must be created."; + const auto *device = dynamic_cast(tensor->GetDevice()); + CHECK(std::find(devices_.begin(), devices_.end(), device) != devices_.end()) + << "Device of target Tensor is not in current ProcessGroup"; + cudaStreamWaitEvent(device->Stream(), done_event, 0); +} + #endif ProcessGroupFactory *ProcessGroupFactory::Instance() { diff --git a/infini_train/src/nn/parallel/reducer.cc b/infini_train/src/nn/parallel/reducer.cc new file mode 100644 index 00000000..4800c90a --- /dev/null +++ b/infini_train/src/nn/parallel/reducer.cc @@ -0,0 +1,480 @@ +#include "infini_train/include/nn/parallel/reducer.h" + +#include +#include +#include +#include +#include + +#ifdef USE_CUDA +#include +#endif + +#include "glog/logging.h" + +#include "infini_train/include/autograd/function_hook.h" +#include "infini_train/include/common/cuda/common_cuda.h" +#include "infini_train/include/device.h" +#include "infini_train/include/nn/parallel/utils.h" + +namespace infini_train::nn::parallel { +namespace { +void CopyGradToBucket(const std::shared_ptr &grad, const std::shared_ptr &flat, size_t dst_elem_offset, + void *stream = nullptr) { + CHECK(grad && flat); + const size_t element_size_in_bytes = kDataTypeToSize.at(grad->Dtype()); + const size_t bytes = grad->NumElements() * element_size_in_bytes; + char *dst = static_cast(flat->DataPtr()) + dst_elem_offset * element_size_in_bytes; + const void *src = grad->DataPtr(); + + const auto dev_type = 
grad->GetDevice()->Type(); + if (dev_type == DeviceType::kCPU) { + std::memcpy(dst, src, bytes); + return; + } +#ifdef USE_CUDA + if (dev_type == DeviceType::kCUDA) { + auto *cuda_dev = dynamic_cast(flat->GetDevice()); + CHECK(cuda_dev); + cuda_dev->SetDevice(); + auto comm_stream = stream ? reinterpret_cast(stream) : cuda_dev->Stream(); + cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, comm_stream); + return; + } +#endif + LOG(FATAL) << "Unsupported device type in CopyGradToBucket"; +} + +void CopyBucketToGrad(const std::shared_ptr &flat, const std::shared_ptr &grad, size_t src_elem_offset, + void *stream = nullptr) { + CHECK(grad && flat); + const size_t element_size_in_bytes = kDataTypeToSize.at(grad->Dtype()); + const size_t bytes = grad->NumElements() * element_size_in_bytes; + const char *src = static_cast(flat->DataPtr()) + src_elem_offset * element_size_in_bytes; + void *dst = grad->DataPtr(); + + const auto dev_type = grad->GetDevice()->Type(); + if (dev_type == DeviceType::kCPU) { + std::memcpy(dst, src, bytes); + return; + } +#ifdef USE_CUDA + if (dev_type == DeviceType::kCUDA) { + auto *cuda_dev = dynamic_cast(flat->GetDevice()); + CHECK(cuda_dev); + cuda_dev->SetDevice(); + auto comm_stream = stream ? reinterpret_cast(stream) : cuda_dev->Stream(); + cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, comm_stream); + return; + } +#endif + LOG(FATAL) << "Unsupported device type in CopyBucketToGrad"; +} + +std::shared_ptr MakeGradView(const std::shared_ptr &contents, size_t offset_elems, + const std::vector &dims) { + // Return a view of contents (same chunk of memory) + auto view = std::make_shared(*contents, offset_elems * kDataTypeToSize.at(contents->Dtype()), dims); + return view; +} +} // namespace + +std::vector> ComputeBucketAssignmentBySize(const std::vector> &tensors, + const std::vector &bucket_size_limits, + const std::vector &tensor_indices) { + + CHECK(!tensors.empty()); + CHECK(!bucket_size_limits.empty()); + // By default, tensors are bucketed in reverse order, closer to the order that grad goes ready in backward + auto ReverseOrder = [](size_t n) { + std::vector idx(n); + std::iota(idx.begin(), idx.end(), 0); + std::reverse(idx.begin(), idx.end()); + return idx; + }; + std::vector order = tensor_indices.empty() ? 
ReverseOrder(tensors.size()) : tensor_indices; + + // Group tensors by device/dtype, make sure that device and dtype is the same in a single bucket + struct Key { + int dev; + DataType dtype; + bool operator==(const Key &o) const { return dev == o.dev && dtype == o.dtype; } + }; + struct KeyHash { + size_t operator()(const Key &k) const { + return (std::hash()(k.dev) << 1) ^ std::hash()(static_cast(k.dtype)); + } + }; + auto key_of = [&](size_t i) -> Key { return Key{tensors[i]->GetDevice()->Index(), tensors[i]->Dtype()}; }; + + // Maintain the current state of each bucket + struct State { + std::vector current_tensors; // Indices of tensors in the bucket + size_t current_bytes = 0; // Total bytes used by the bucket + size_t limit_idx = 0; // The index of bucket_size_limits used by the bucket + }; + + std::unordered_map states; + std::vector key_order; + // NOTE(zbl): Assume combinations of (device, dtype) <= 8 + states.reserve(8); + + std::vector> buckets_all; + buckets_all.reserve(tensors.size()); + + auto advance_limit = [&](State &s) { + // Iterate along bucket_size_limits till the last one everytime a bucket is completed + if (s.limit_idx + 1 < bucket_size_limits.size()) { + ++s.limit_idx; + } + }; + + auto current_cap = [&](const State &s) -> size_t { return bucket_size_limits[s.limit_idx]; }; + + auto flush_current_bucket = [&](State &s) { + if (!s.current_tensors.empty()) { + buckets_all.push_back(std::move(s.current_tensors)); + s.current_tensors.clear(); + s.current_bytes = 0; + advance_limit(s); + } + }; + + for (size_t idx_in_order : order) { + CHECK_LT(idx_in_order, tensors.size()); + const auto &tensor = tensors[idx_in_order]; + CHECK(tensor); + + const Key k = key_of(idx_in_order); + auto it = states.find(k); + if (it == states.end()) { + it = states.emplace(k, State{}).first; + key_order.push_back(k); + } + auto &state = it->second; + + const size_t element_size_in_bytes = kDataTypeToSize.at(tensor->Dtype()); + const size_t bytes = tensor->NumElements() * element_size_in_bytes; + const size_t cap = current_cap(state); + + // Assign current tensor to current bucket first + state.current_tensors.push_back(idx_in_order); + state.current_bytes += bytes; + + // If current bucket is out of capacity, then flush and move on to the next bucket + if (state.current_bytes >= cap) { + flush_current_bucket(state); + } + } + + // Flush the last bucket of each group manually + for (auto &key : key_order) { flush_current_bucket(states[key]); } + + return buckets_all; +} + +Reducer::Reducer(std::vector> parameters, std::vector> bucket_indices, + const ReducerOptions &opts) + : params_(std::move(parameters)), opts_(opts) { + BuildBuckets(bucket_indices); + ready_seen_this_iter_.assign(params_.size(), 0); + AttachHooksToParameters(); +} + +Reducer::~Reducer() { +#ifdef USE_CUDA + for (auto &b : buckets_) { + if (!b.contents) { + continue; + } + if (b.contents->GetDevice()->Type() == DeviceType::kCUDA) { + if (b.allreduce_done) { + CUDA_CHECK(cudaEventDestroy(b.allreduce_done)); + } + if (b.bucket_ready) { + CUDA_CHECK(cudaEventDestroy(b.bucket_ready)); + } + } + } +#endif +} + +void Reducer::InitializeBuckets(const std::vector> &bucket_indices) { +#ifdef USE_CUDA + for (auto &b : buckets_) { + if (!b.contents) { + continue; + } + if (b.contents->GetDevice()->Type() == DeviceType::kCUDA) { + if (b.allreduce_done) { + CUDA_CHECK(cudaEventDestroy(b.allreduce_done)); + } + if (b.bucket_ready) { + CUDA_CHECK(cudaEventDestroy(b.bucket_ready)); + } + } + } +#endif + buckets_.clear(); + 
locators_.clear(); + next_bucket_ = 0; + BuildBuckets(bucket_indices); +} + +void Reducer::InitializeBucketViews(Bucket &bucket) { + bucket.bucket_views_in.clear(); + bucket.bucket_views_out.clear(); + bucket.bucket_views_in.reserve(bucket.variables.size()); + bucket.bucket_views_out.reserve(bucket.variables.size()); + + for (size_t i = 0; i < bucket.variables.size(); ++i) { + const auto &v = bucket.variables[i]; + const size_t offset_elems = bucket.offsets[i]; + auto view_in = MakeGradView(bucket.contents, offset_elems, v->Dims()); + bucket.bucket_views_in.push_back(view_in); + } + // Set (out == in) by default when all grads are dense + bucket.bucket_views_out = bucket.bucket_views_in; + + if (opts_.gradient_as_bucket_view) { + for (size_t i = 0; i < bucket.variables.size(); ++i) { + auto &v = bucket.variables[i]; + auto g = v->grad(); + if (g && g.get() != bucket.bucket_views_in[i].get()) { + v->set_grad(bucket.bucket_views_in[i]); + } + } + } +} + +void Reducer::BuildBuckets(const std::vector> &bucket_indices) { + locators_.resize(params_.size()); + buckets_.clear(); + buckets_.reserve(bucket_indices.size()); + + for (size_t bucket_idx = 0; bucket_idx < bucket_indices.size(); ++bucket_idx) { + Bucket bucket; + + CHECK(!bucket_indices[bucket_idx].empty()); + const auto &first_param = params_[bucket_indices[bucket_idx][0]]; + bucket.dtype = first_param->Dtype(); + bucket.device_rank = first_param->GetDevice()->rank().thread_rank(); + + size_t total_elems = 0; + + for (auto param_idx : bucket_indices[bucket_idx]) { + const auto ¶m = params_.at(param_idx); + CHECK(param); + CHECK(param->GetDevice() == first_param->GetDevice()) << "Bucket cannot span devices"; + CHECK(param->Dtype() == first_param->Dtype()) << "Bucket cannot span dtypes"; + + bucket.variables.push_back(param); + bucket.offsets.push_back(total_elems); + bucket.lengths.push_back(param->NumElements()); + total_elems += param->NumElements(); + + locators_[param_idx] = {bucket_idx, bucket.variables.size() - 1}; + } + + // Assgin 1D (flat) contents + auto dev = bucket.variables.front()->GetDevice(); + bucket.contents + = std::make_shared(std::vector{static_cast(total_elems)}, bucket.dtype, dev); + // bucket.contents->Fill(0); + bucket.pending = bucket.variables.size(); + +#ifdef USE_CUDA + if (bucket.contents->GetDevice()->Type() == DeviceType::kCUDA) { + CUDA_CHECK(cudaEventCreateWithFlags(&bucket.allreduce_done, cudaEventDisableTiming)); + CUDA_CHECK(cudaEventCreateWithFlags(&bucket.bucket_ready, cudaEventDisableTiming)); + } +#endif + + bucket.variable_indices = bucket_indices[bucket_idx]; + InitializeBucketViews(bucket); + buckets_.push_back(std::move(bucket)); + } +} + +void Reducer::RebuildBuckets() { + // NOTE(zbl): Assume mutex is on when entering this function + // If no order is recorded then skip the rebuild + if (grad_ready_order_indices_.empty()) { + return; + } + + // full_order = real ready order + missing index + std::vector seen(params_.size(), 0); + for (auto idx : grad_ready_order_indices_) { + if (idx < params_.size()) { + seen[idx] = 1; + } + } + std::vector full_order = grad_ready_order_indices_; + full_order.reserve(params_.size()); + for (size_t i = 0; i < params_.size(); ++i) { + if (!seen[i]) { + full_order.push_back(i); + } + } + + std::vector> tensors_in_order; + tensors_in_order.reserve(full_order.size()); + for (auto global_idx : full_order) { + CHECK_LT(global_idx, params_.size()); + tensors_in_order.push_back(params_[global_idx]); + } + + const size_t first_cap_bytes = opts_.first_bucket_cap_mb * 
1024ULL * 1024ULL; + const size_t normal_cap_bytes = opts_.normal_bucket_cap_mb * 1024ULL * 1024ULL; + std::vector bucket_size_limits = {first_cap_bytes, normal_cap_bytes}; + auto new_bucket_indices = ComputeBucketAssignmentBySize(tensors_in_order, bucket_size_limits, full_order); + + InitializeBuckets(new_bucket_indices); +} + +std::vector>> Reducer::GetBucketTensors() const { + std::lock_guard g(mutex_); + std::vector>> out; + out.reserve(buckets_.size()); + for (auto const &b : buckets_) { out.push_back({b.contents}); } + return out; +} + +void Reducer::RegisterCommHook(std::shared_ptr hook) { + std::lock_guard g(mutex_); + comm_hook_ = std::move(hook); +} + +void Reducer::PrepareForBackward() { + std::lock_guard g(mutex_); + buckets_finished_.store(0, std::memory_order_relaxed); + + if (need_rebuild_ && !has_rebuilt_bucket_) { + RebuildBuckets(); + has_rebuilt_bucket_ = true; + need_rebuild_ = false; + } + next_bucket_ = 0; + grad_ready_order_indices_.clear(); + ready_seen_this_iter_.assign(params_.size(), 0); + + for (auto &bucket : buckets_) { + bucket.pending = bucket.variables.size(); + if (opts_.gradient_as_bucket_view) { + for (size_t i = 0; i < bucket.variables.size(); ++i) { + // Tie each param.grad to slice of contents + const auto ¶m = bucket.variables[i]; + auto view = bucket.bucket_views_in[i]; + auto grad = param->grad(); + + if (grad == nullptr) { + param->MarkGradOverwriteOnNextAccum(); + param->set_grad(view); + } else { + CHECK_EQ(grad.get(), view.get()) << "Param's gradient should be a slice of bucket's flat buffer."; + } + } + } + } +} + +void Reducer::AttachHooksToParameters() { + for (size_t param_idx = 0; param_idx < params_.size(); ++param_idx) { + class BucketHook final : public autograd::PostAccumulateGradHook { + public: + BucketHook(std::weak_ptr reducer, size_t var_index) + : reducer_(std::move(reducer)), var_index_(var_index) {} + + void operator()(const std::shared_ptr &) override { + if (auto r = reducer_.lock()) { + r->MarkVariableReadyDense(var_index_); + } + } + + private: + std::weak_ptr reducer_; + size_t var_index_; + }; + + auto hook = std::make_unique(weak_from_this(), param_idx); + params_[param_idx]->RegisterPostAccumulateGradHook(std::move(hook)); + } +} + +void Reducer::MarkVariableReadyDense(size_t variable_index) { + std::lock_guard g(mutex_); + const auto loc = locators_.at(variable_index); + auto &bucket = buckets_.at(loc.bucket_index); + + // Record real order of bucket being ready + if (!has_rebuilt_bucket_ && variable_index < ready_seen_this_iter_.size() + && !ready_seen_this_iter_[variable_index]) { + grad_ready_order_indices_.push_back(variable_index); + ready_seen_this_iter_[variable_index] = 1; + } + + if (!opts_.gradient_as_bucket_view) { + auto grad = bucket.variables[loc.intra_bucket_index]->grad(); + CHECK(grad && grad->Dtype() == bucket.dtype && grad->GetDevice() == bucket.contents->GetDevice()); + CopyGradToBucket(grad, bucket.contents, bucket.offsets[loc.intra_bucket_index]); + } + + CHECK(bucket.pending > 0); + bucket.pending -= 1; + + bool should_launch_next = (bucket.pending == 0); + + if (should_launch_next) { + MarkBucketReady(loc.bucket_index); + } +} + +void Reducer::MarkBucketReady(size_t bucket_index) { + // NOTE(zbl): Assume mutex is on when entering this function + if (bucket_index > next_bucket_) { + // Only when bucket_index == next_bucket_ will we launch all-reduce + // bucket_index > next_bucket_ means that there are still buckets before that are not ready + return; + } + // From next_bucket_, launch ready 
buckets(pending==0) in turn + while (next_bucket_ < buckets_.size() && buckets_[next_bucket_].pending == 0) { + auto &bucket = buckets_[next_bucket_]; + FinalizeBucketDense(next_bucket_); + ++next_bucket_; + } + + // If all buckets are ready, then try to rebuild them in real order + if (next_bucket_ == buckets_.size() && !has_rebuilt_bucket_) { + if (!grad_ready_order_indices_.empty()) { + need_rebuild_ = true; + } + } +} + +void Reducer::FinalizeBucketDense(size_t bucket_index) { + auto &bucket = buckets_.at(bucket_index); + auto ddp_pg = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(bucket.device_rank)); + + if (comm_hook_) { + std::vector> bucket_view{bucket.contents}; + // NOTE(zbl): Custom hook should do in-place operations + // e.g. comm_hook_(GradBucket{bucket_view})[0]; + // FIXME(zbl): support custom hook later + LOG(FATAL) << "Custom hook is not supported now"; + } else { + ddp_pg->EnqueueAllReduce(bucket.bucket_ready, bucket.allreduce_done, bucket.contents, + function::ReduceOpType::kAvg); + } + + if (!opts_.gradient_as_bucket_view) { + for (size_t i = 0; i < bucket.variables.size(); ++i) { + // Directly assgin bucket slice to grad instead of copying + // Same behavior as `CopyBucketToGrad(bucket.contents, bucket.variables[i]->grad(), bucket.offsets[i]);` + bucket.variables[i]->set_grad(bucket.bucket_views_in[i]); + } + } + + ddp_pg->WaitAllReduceDone(bucket.allreduce_done, bucket.contents); +} +} // namespace infini_train::nn::parallel diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index abfd560d..db888009 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -500,6 +500,14 @@ void Tensor::ZeroGrad(bool set_to_none) { } } +void Tensor::MarkGradOverwriteOnNextAccum() { grad_overwrite_once_ = true; } + +bool Tensor::ConsumeGradOverwriteFlag() { + bool flag = grad_overwrite_once_; + grad_overwrite_once_ = false; + return flag; +} + void Tensor::Backward(std::shared_ptr gradient, bool retain_graph, bool create_graph) const { CHECK(!retain_graph && !create_graph) << "Not implemented yet!"; if (grad_fn_) { From 0053fa8c0c33b9496583c4eeb6a710f952817756 Mon Sep 17 00:00:00 2001 From: bolunz Date: Fri, 14 Nov 2025 17:56:34 +0800 Subject: [PATCH 2/5] feat: add Work definition, fix gradient_as_bucket_view option --- .../include/nn/parallel/process_group.h | 7 +- infini_train/include/nn/parallel/reducer.h | 19 ++- infini_train/include/nn/parallel/work.h | 72 +++++++++++ infini_train/include/tensor.h | 3 + infini_train/src/autograd/accumulate.cc | 7 +- .../nn/parallel/distributed_data_parallel.cc | 6 +- infini_train/src/nn/parallel/process_group.cc | 38 +++--- infini_train/src/nn/parallel/reducer.cc | 81 +++--------- infini_train/src/nn/parallel/work.cc | 116 ++++++++++++++++++ infini_train/src/tensor.cc | 79 +++++++++++- 10 files changed, 324 insertions(+), 104 deletions(-) create mode 100644 infini_train/include/nn/parallel/work.h create mode 100644 infini_train/src/nn/parallel/work.cc diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index 3bcf9c2a..8f581966 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -11,6 +11,7 @@ #endif #include "infini_train/include/nn/parallel/reduce_op_type.h" +#include "infini_train/include/nn/parallel/work.h" namespace infini_train { class Tensor; @@ -55,10 +56,8 @@ class ProcessGroup { std::vector> NcclRecv(std::vector> tensors, int src_rank) const; - // 
Overlap helper functions - void EnqueueAllReduce(cudaEvent_t ready_event, cudaEvent_t done_event, const std::shared_ptr &tensor, - function::ReduceOpType reduce_op) const; - void WaitAllReduceDone(cudaEvent_t done_event, const std::shared_ptr &tensor) const; + // Async communication functions + std::shared_ptr AllReduceAsync(const std::shared_ptr &tensor, function::ReduceOpType reduce_op) const; private: std::vector comms_; diff --git a/infini_train/include/nn/parallel/reducer.h b/infini_train/include/nn/parallel/reducer.h index 4685094a..3a559681 100644 --- a/infini_train/include/nn/parallel/reducer.h +++ b/infini_train/include/nn/parallel/reducer.h @@ -30,12 +30,15 @@ std::vector> ComputeBucketAssignmentBySize(const std::vector const std::vector &tensor_indices = {}); struct ReducerOptions { + // Pack all Reducer-related args together + // Ref: https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html + // Max capacity for each bucket(in MB) size_t first_bucket_cap_mb = 128; size_t normal_bucket_cap_mb = 512; // When set true, map param.grad directly to the slice of bucket.flat(same address in memory) instead of memcpy - bool gradient_as_bucket_view = false; + bool gradient_as_bucket_view = true; }; // DDP Reducer that handles gradient bucketing in backward @@ -50,7 +53,9 @@ class Reducer : public std::enable_shared_from_this { */ explicit Reducer(std::vector> parameters, std::vector> bucket_indices, const ReducerOptions &opts); - ~Reducer(); + + // Attach PostAllReduceHooks to params + void AttachHooksToParameters(); // Prepare bucket info for next step void PrepareForBackward(); @@ -91,7 +96,7 @@ class Reducer : public std::enable_shared_from_this { // Views into the `gradients` tensor for each individual gradient std::vector> bucket_views_in; - // NOTE(zbl): reserved for occasions where grads have different stride/layout + // TODO(zbl): reserved for occasions where grads have different stride/layout std::vector> bucket_views_out; // Number of gradients left to be computed before the bucket is ready to be reduced @@ -104,18 +109,10 @@ class Reducer : public std::enable_shared_from_this { // If `true`, then this implies that `bucket.variables.size() == 1`. 
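A small usage sketch of the new Work-based asynchronous all-reduce (assuming a CUDA/NCCL build; the flat gradient tensor and thread rank are placeholders, and GetDataParallelProcessGroupName is assumed to live in nn/parallel/utils.h, matching its existing use in reducer.cc):

#include <chrono>
#include <memory>

#include "infini_train/include/nn/parallel/process_group.h"
#include "infini_train/include/nn/parallel/utils.h"
#include "infini_train/include/nn/parallel/work.h"

using namespace infini_train::nn::parallel;

void AllReduceFlatGradAsync(const std::shared_ptr<infini_train::Tensor> &flat_grad, int thread_rank) {
    auto ddp_pg = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(thread_rank));

    // AllReduceAsync records a "ready" event on the compute stream, makes the dedicated
    // high-priority comm stream wait on it, runs ncclAllReduce there, then records a "done"
    // event that the compute stream waits on; the returned Work wraps those two events.
    std::shared_ptr<Work> work = ddp_pg->AllReduceAsync(flat_grad, function::ReduceOpType::kAvg);

    // Host-side, Wait() with a timeout polls the "done" event and returns false if the
    // deadline passes; Wait() with no argument blocks until the collective has finished.
    if (!work->Wait(std::chrono::milliseconds(500))) {
        work->Wait();
    }
}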
// TODO(zbl): support logics for sparse gradient later bool expect_sparse_gradient = false; - -#ifdef USE_CUDA - // Event to mark that AllReduce is completed - cudaEvent_t allreduce_done = nullptr; - // Event to mark that all tensors' grad in bucket are ready - cudaEvent_t bucket_ready = nullptr; -#endif }; private: void InitializeBuckets(const std::vector> &bucket_indices); - void AttachHooksToParameters(); // NOTE(zbl): all grads are assumed dense and stored continously in bucket for now void MarkVariableReadyDense(size_t variable_index); diff --git a/infini_train/include/nn/parallel/work.h b/infini_train/include/nn/parallel/work.h new file mode 100644 index 00000000..d9f6520a --- /dev/null +++ b/infini_train/include/nn/parallel/work.h @@ -0,0 +1,72 @@ +#pragma once + +#include +#include +#include +#include +#include + +#ifdef USE_CUDA +#include +#endif +#ifdef USE_NCCL +#include +#endif + +#include "infini_train/include/device.h" + +namespace infini_train::nn::parallel { + +class Work { +public: + virtual ~Work() = default; + + virtual bool Wait(std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) = 0; + + virtual bool IsCompleted() const = 0; + virtual bool IsSuccess() const = 0; + + virtual void Synchronize() const = 0; + + virtual std::exception_ptr exception() const = 0; + + virtual void *ready_event() const = 0; + virtual void *done_event() const = 0; +}; + +#ifdef USE_NCCL +class WorkNccl final : public Work { +public: + WorkNccl(const Device *device, ncclComm_t comm); + ~WorkNccl() override; + + bool Wait(std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) override; + + bool IsCompleted() const override; + bool IsSuccess() const override; + + void Synchronize() const override; + + std::exception_ptr exception() const override { return exception_; }; + + void *ready_event() const override { return reinterpret_cast(ready_event_); }; + void *done_event() const override { return reinterpret_cast(done_event_); }; + +private: + bool CheckNcclStatus(); + void SetException(std::exception_ptr e); + +private: + Device *device_ = nullptr; + cudaEvent_t ready_event_; + cudaEvent_t done_event_; + ncclComm_t comm_; + + mutable std::mutex mutex_; + std::exception_ptr exception_; + std::atomic completed_{false}; + std::atomic success_{false}; +}; +#endif + +} // namespace infini_train::nn::parallel diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index 4932bce4..a7c4bc1e 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -82,6 +82,9 @@ class Tensor : public std::enable_shared_from_this { Tensor To(const Device *device); Tensor To(DataType dtype); + void CopyFrom(const Tensor &src); + void CopyFrom(const std::shared_ptr &src); + // operator overloading std::shared_ptr Equals(const std::shared_ptr &other); std::shared_ptr Equals(float scalar); diff --git a/infini_train/src/autograd/accumulate.cc b/infini_train/src/autograd/accumulate.cc index 4e3204f6..76443bd1 100644 --- a/infini_train/src/autograd/accumulate.cc +++ b/infini_train/src/autograd/accumulate.cc @@ -27,13 +27,16 @@ AccumulateGrad::Backward(const std::vector> &grad_output if (grad_output) { if (grad) { if (tensor_->ConsumeGradOverwriteFlag()) { - auto new_grad = std::make_shared(*grad_output.get(), 0, grad_output->Dims()); - tensor_->set_grad(std::move(new_grad)); + // If the tensor is marked to overrite its current grad on next grad update + // See notes in `infini_train::nn::parallel::Reducer::PrepareForBackward()` + // NOTE(zbl): must copy, 
cannot change grad buffer address + grad->CopyFrom(grad_output); } else { auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "AccumulateGrad"}); kernel.Call(grad_output, learning_rate_, grad); } } else { + // NOTE(zbl): check whether need to do copying instead of slicing auto new_grad = std::make_shared(*grad_output.get(), 0, grad_output->Dims()); tensor_->set_grad(std::move(new_grad)); } diff --git a/infini_train/src/nn/parallel/distributed_data_parallel.cc b/infini_train/src/nn/parallel/distributed_data_parallel.cc index c6212770..34e232fa 100644 --- a/infini_train/src/nn/parallel/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/distributed_data_parallel.cc @@ -20,8 +20,7 @@ constexpr char kModuleName[] = "module"; DistributedDataParallel::DistributedDataParallel(std::shared_ptr module, int device_id, const ReducerOptions &opts) { for (auto ¶m : module->Parameters()) { - auto device = param->GetDevice(); - CHECK_EQ(device->Index(), device_id) << "All parameters must be on the same device as the module"; + CHECK_EQ(param->GetDevice()->Index(), device_id) << "All parameters must be on the same device as the module"; } for (auto &buffer : module->Buffers()) { CHECK_EQ(buffer->GetDevice()->Index(), device_id) << "All buffers must be on the same device as the module"; @@ -35,7 +34,8 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr mod std::vector bucket_size_limits = {first_cap_bytes, normal_cap_bytes}; auto bucket_indices = ComputeBucketAssignmentBySize(params, bucket_size_limits); - reducer_ = std::make_shared(std::move(params), bucket_indices, opts); + reducer_ = std::make_shared(params, bucket_indices, opts); + reducer_->AttachHooksToParameters(); } std::vector> diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index 64de2df1..d7937035 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -335,36 +335,34 @@ std::vector> ProcessGroup::NcclRecv(std::vector &tensor, function::ReduceOpType reduce_op) const { - CHECK(ready_event && done_event) << "Events must be created."; +std::shared_ptr ProcessGroup::AllReduceAsync(const std::shared_ptr &tensor, + function::ReduceOpType reduce_op) const { + void *buffer = tensor->DataPtr(); const auto *device = dynamic_cast(tensor->GetDevice()); - CHECK(std::find(devices_.begin(), devices_.end(), device) != devices_.end()) - << "Device of target Tensor is not in current ProcessGroup"; + device->SetDevice(); + + auto comm = device_comm_map_.at(device); cudaStream_t compute_stream = device->Stream(); cudaStream_t comm_stream = device_stream_map_.at(device); - cudaEventRecord(ready_event, compute_stream); - cudaStreamWaitEvent(comm_stream, ready_event, 0); + auto work = std::make_shared(device, comm); + + cudaEvent_t ready_event = reinterpret_cast(work->ready_event()); + cudaEvent_t done_event = reinterpret_cast(work->done_event()); + + CUDA_CHECK(cudaEventRecord(ready_event, compute_stream)); + CUDA_CHECK(cudaStreamWaitEvent(comm_stream, ready_event, 0)); // Perform NcclAllReduce on comm stream - device->SetDevice(); - NCCL_CHECK(ncclAllReduce(tensor->DataPtr(), tensor->DataPtr(), tensor->NumElements(), - kNcclDtypeMap.at(tensor->Dtype()), kNcclReduceOpMap.at(reduce_op), - device_comm_map_.at(device), comm_stream)); + NCCL_CHECK(ncclAllReduce(buffer, buffer, tensor->NumElements(), kNcclDtypeMap.at(tensor->Dtype()), + kNcclReduceOpMap.at(reduce_op), comm, comm_stream)); - cudaEventRecord(done_event, 
comm_stream); -} + CUDA_CHECK(cudaEventRecord(done_event, comm_stream)); + CUDA_CHECK(cudaStreamWaitEvent(compute_stream, done_event, 0)); -void ProcessGroup::WaitAllReduceDone(cudaEvent_t done_event, const std::shared_ptr &tensor) const { - CHECK(done_event) << "Events must be created."; - const auto *device = dynamic_cast(tensor->GetDevice()); - CHECK(std::find(devices_.begin(), devices_.end(), device) != devices_.end()) - << "Device of target Tensor is not in current ProcessGroup"; - cudaStreamWaitEvent(device->Stream(), done_event, 0); + return std::move(work); } - #endif ProcessGroupFactory *ProcessGroupFactory::Instance() { diff --git a/infini_train/src/nn/parallel/reducer.cc b/infini_train/src/nn/parallel/reducer.cc index 4800c90a..fcb0ca89 100644 --- a/infini_train/src/nn/parallel/reducer.cc +++ b/infini_train/src/nn/parallel/reducer.cc @@ -16,6 +16,7 @@ #include "infini_train/include/common/cuda/common_cuda.h" #include "infini_train/include/device.h" #include "infini_train/include/nn/parallel/utils.h" +#include "infini_train/include/nn/parallel/work.h" namespace infini_train::nn::parallel { namespace { @@ -178,43 +179,9 @@ Reducer::Reducer(std::vector> parameters, std::vectorGetDevice()->Type() == DeviceType::kCUDA) { - if (b.allreduce_done) { - CUDA_CHECK(cudaEventDestroy(b.allreduce_done)); - } - if (b.bucket_ready) { - CUDA_CHECK(cudaEventDestroy(b.bucket_ready)); - } - } - } -#endif } void Reducer::InitializeBuckets(const std::vector> &bucket_indices) { -#ifdef USE_CUDA - for (auto &b : buckets_) { - if (!b.contents) { - continue; - } - if (b.contents->GetDevice()->Type() == DeviceType::kCUDA) { - if (b.allreduce_done) { - CUDA_CHECK(cudaEventDestroy(b.allreduce_done)); - } - if (b.bucket_ready) { - CUDA_CHECK(cudaEventDestroy(b.bucket_ready)); - } - } - } -#endif buckets_.clear(); locators_.clear(); next_bucket_ = 0; @@ -235,16 +202,6 @@ void Reducer::InitializeBucketViews(Bucket &bucket) { } // Set (out == in) by default when all grads are dense bucket.bucket_views_out = bucket.bucket_views_in; - - if (opts_.gradient_as_bucket_view) { - for (size_t i = 0; i < bucket.variables.size(); ++i) { - auto &v = bucket.variables[i]; - auto g = v->grad(); - if (g && g.get() != bucket.bucket_views_in[i].get()) { - v->set_grad(bucket.bucket_views_in[i]); - } - } - } } void Reducer::BuildBuckets(const std::vector> &bucket_indices) { @@ -280,16 +237,8 @@ void Reducer::BuildBuckets(const std::vector> &bucket_indice auto dev = bucket.variables.front()->GetDevice(); bucket.contents = std::make_shared(std::vector{static_cast(total_elems)}, bucket.dtype, dev); - // bucket.contents->Fill(0); bucket.pending = bucket.variables.size(); -#ifdef USE_CUDA - if (bucket.contents->GetDevice()->Type() == DeviceType::kCUDA) { - CUDA_CHECK(cudaEventCreateWithFlags(&bucket.allreduce_done, cudaEventDisableTiming)); - CUDA_CHECK(cudaEventCreateWithFlags(&bucket.bucket_ready, cudaEventDisableTiming)); - } -#endif - bucket.variable_indices = bucket_indices[bucket_idx]; InitializeBucketViews(bucket); buckets_.push_back(std::move(bucket)); @@ -368,11 +317,18 @@ void Reducer::PrepareForBackward() { auto view = bucket.bucket_views_in[i]; auto grad = param->grad(); - if (grad == nullptr) { - param->MarkGradOverwriteOnNextAccum(); + // NOTE(zbl): This will affect behaviors in `infini_train::autograd::AccumulateGrad::Backward()` + // If ZeroGrad(set_to_none=True), grad is nullptr at this point + // If ZeroGrad(set_to_none=False), grad is set to view of bucket.contents (or modified by user) + // Either way, we reset 
grad to view of bucket.contents + // Since bucket.contents might not be zeroed, we need to overwrite it on next grad accumulation + if (!grad || (grad.get() != view.get())) { + if (grad) { + LOG(WARNING) << "gradient_as_bucket_view is enabled, but param " << param + << " has a non-view grad tensor. Automatically overwriting it with bucket view."; + } param->set_grad(view); - } else { - CHECK_EQ(grad.get(), view.get()) << "Param's gradient should be a slice of bucket's flat buffer."; + param->MarkGradOverwriteOnNextAccum(); } } } @@ -456,6 +412,7 @@ void Reducer::FinalizeBucketDense(size_t bucket_index) { auto &bucket = buckets_.at(bucket_index); auto ddp_pg = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(bucket.device_rank)); + std::shared_ptr work; if (comm_hook_) { std::vector> bucket_view{bucket.contents}; // NOTE(zbl): Custom hook should do in-place operations @@ -463,18 +420,16 @@ void Reducer::FinalizeBucketDense(size_t bucket_index) { // FIXME(zbl): support custom hook later LOG(FATAL) << "Custom hook is not supported now"; } else { - ddp_pg->EnqueueAllReduce(bucket.bucket_ready, bucket.allreduce_done, bucket.contents, - function::ReduceOpType::kAvg); + work = ddp_pg->AllReduceAsync(bucket.contents, function::ReduceOpType::kAvg); } if (!opts_.gradient_as_bucket_view) { for (size_t i = 0; i < bucket.variables.size(); ++i) { - // Directly assgin bucket slice to grad instead of copying - // Same behavior as `CopyBucketToGrad(bucket.contents, bucket.variables[i]->grad(), bucket.offsets[i]);` - bucket.variables[i]->set_grad(bucket.bucket_views_in[i]); + // NOTE(zbl): For better performance, try `bucket.variables[i]->set_grad(bucket.bucket_views_in[i]);` + // to directly assgin bucket slice to grad instead of copying + CopyBucketToGrad(bucket.contents, bucket.variables[i]->grad(), bucket.offsets[i]); + // bucket.variables[i]->set_grad(bucket.bucket_views_in[i]); } } - - ddp_pg->WaitAllReduceDone(bucket.allreduce_done, bucket.contents); } } // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/work.cc b/infini_train/src/nn/parallel/work.cc new file mode 100644 index 00000000..fef7517d --- /dev/null +++ b/infini_train/src/nn/parallel/work.cc @@ -0,0 +1,116 @@ +#include "infini_train/include/nn/parallel/work.h" + +#include "glog/logging.h" + +#ifdef USE_CUDA +#include "infini_train/include/common/cuda/common_cuda.h" +#endif + +namespace infini_train::nn::parallel { +#ifdef USE_NCCL +namespace { +std::exception_ptr makeCudaError(cudaError_t err) { + return std::make_exception_ptr(std::runtime_error(cudaGetErrorString(err))); +} +} // namespace + +WorkNccl::WorkNccl(const Device *device, ncclComm_t comm) : comm_(comm) { + CUDA_CHECK(cudaEventCreateWithFlags(&ready_event_, cudaEventDisableTiming)); + CUDA_CHECK(cudaEventCreateWithFlags(&done_event_, cudaEventDisableTiming)); +} + +WorkNccl::~WorkNccl() { + if (ready_event_) { + CUDA_CHECK(cudaEventDestroy(ready_event_)); + } + if (done_event_) { + CUDA_CHECK(cudaEventDestroy(done_event_)); + } +} + +bool WorkNccl::Wait(std::chrono::milliseconds timeout) { + // Block wait on host + device_->SetDevice(); + + // If timeout is not set, then wait till it finishes + if (timeout <= std::chrono::milliseconds::zero()) { + if (auto status = cudaEventSynchronize(done_event_); status != cudaSuccess) { + SetException(makeCudaError(status)); + return false; + } + // Check NCCL status + return CheckNcclStatus(); + } + + // If timeout is set, keep querying till time's up + const auto deadline = 
std::chrono::steady_clock::now() + timeout; + while (std::chrono::steady_clock::now() < deadline) { + cudaError_t query = cudaEventQuery(done_event_); + if (query == cudaSuccess) { + return CheckNcclStatus(); + } + if (query != cudaErrorNotReady) { + SetException(makeCudaError(query)); + return false; + } + // NOTE(zbl): sleep for a while in case of busy waiting + std::this_thread::sleep_for(std::chrono::microseconds(50)); + } + + if (exception_) { + // NOTE(zbl): do not throw any c++ exception + LOG(FATAL) << "Error occurs while wait(). "; + } + + return false; +} + +void WorkNccl::Synchronize() const { CUDA_CHECK(cudaEventSynchronize(done_event_)); } + +bool WorkNccl::IsCompleted() const { + if (completed_.load(std::memory_order_acquire)) { + return true; + } + cudaError_t query = cudaEventQuery(done_event_); + if (query == cudaSuccess) { + const_cast(this)->completed_.store(true, std::memory_order_release); + const_cast(this)->success_.store(true, std::memory_order_release); + return true; + } + if (query != cudaErrorNotReady) { + const_cast(this)->SetException(makeCudaError(query)); + return true; + } + return false; +} + +bool WorkNccl::IsSuccess() const { + if (!IsCompleted()) { + return false; + } + return success_.load(std::memory_order_acquire) && !exception_; +} + +bool WorkNccl::CheckNcclStatus() { + ncclResult_t async_error; + if (comm_ && ncclCommGetAsyncError(comm_, &async_error) == ncclSuccess && async_error != ncclSuccess) { + SetException(std::make_exception_ptr( + std::runtime_error(std::string("NCCL async error: ") + ncclGetErrorString(async_error)))); + return false; + } + success_.store(true, std::memory_order_release); + completed_.store(true, std::memory_order_release); + return true; +} + +void WorkNccl::SetException(std::exception_ptr e) { + std::lock_guard g(mutex_); + if (!exception_) { + exception_ = std::move(e); + } + completed_.store(true, std::memory_order_release); + success_.store(false, std::memory_order_release); +} +#endif + +} // namespace infini_train::nn::parallel diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index db888009..4752fe0a 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -225,6 +225,83 @@ Tensor Tensor::To(DataType dtype) { return new_tensor; } +void Tensor::CopyFrom(const Tensor &src) { + CHECK(Dtype() == src.Dtype()) << "Tensor::CopyFrom dtype mismatch: dst=" << static_cast(Dtype()) + << " src=" << static_cast(src.Dtype()); + CHECK_EQ(NumElements(), src.NumElements()) << "Tensor::CopyFrom element count mismatch"; + CHECK(Dims() == src.Dims()) << "Tensor::CopyFrom shape mismatch"; + + const size_t nbytes = SizeInBytes(); + const Device *dst_dev = GetDevice(); + const Device *src_dev = src.GetDevice(); + + switch (dst_dev->Type()) { + case DeviceType::kCPU: { + switch (src_dev->Type()) { + case DeviceType::kCPU: { + std::memcpy(DataPtr(), src.DataPtr(), nbytes); + break; + } +#ifdef USE_CUDA + case DeviceType::kCUDA: { + // CUDA -> CPU + CUDA_CHECK(cudaMemcpy(DataPtr(), src.DataPtr(), nbytes, cudaMemcpyDeviceToHost)); + break; + } +#endif + default: + LOG(FATAL) << "Unsupported src device type: " << static_cast(src_dev->Type()); + } + break; + } + +#ifdef USE_CUDA + case DeviceType::kCUDA: { + int current_device = -1; + CUDA_CHECK(cudaGetDevice(¤t_device)); + dst_dev->SetDevice(); + + const auto *dst_cuda = dynamic_cast(dst_dev); + switch (src_dev->Type()) { + case DeviceType::kCPU: { + // CPU -> CUDA + CUDA_CHECK(cudaMemcpyAsync(DataPtr(), src.DataPtr(), nbytes, cudaMemcpyHostToDevice, 
dst_cuda->Stream())); + break; + } + case DeviceType::kCUDA: { + const auto *src_cuda = dynamic_cast(src_dev); + if (src_cuda->Index() == dst_cuda->Index()) { + CUDA_CHECK( + cudaMemcpyAsync(DataPtr(), src.DataPtr(), nbytes, cudaMemcpyDeviceToDevice, dst_cuda->Stream())); + } else { + int canAccessPeer = 0; + CUDA_CHECK(cudaDeviceCanAccessPeer(&canAccessPeer, dst_cuda->Index(), src_cuda->Index())); + if (canAccessPeer) { + CUDA_CHECK(cudaMemcpyPeerAsync(DataPtr(), dst_cuda->Index(), src.DataPtr(), src_cuda->Index(), + nbytes, dst_cuda->Stream())); + } else { + LOG(FATAL) << "Check accessibility between Device " << src_cuda->Index() << " and Device " + << dst_cuda->Index(); + } + } + break; + } + default: + LOG(FATAL) << "Unsupported src device type: " << static_cast(src_dev->Type()); + } + + CUDA_CHECK(cudaSetDevice(current_device)); + break; + } +#endif + + default: + LOG(FATAL) << "Unsupported dst device type: " << static_cast(dst_dev->Type()); + } +} + +void Tensor::CopyFrom(const std::shared_ptr &src) { CopyFrom(*src); } + // operator overloading std::shared_ptr Tensor::Equals(const std::shared_ptr &other) { return std::make_shared()->Apply({shared_from_this(), other})[0]; @@ -472,7 +549,7 @@ void Tensor::set_grad(std::shared_ptr grad) { CHECK(grad->GetDevice() == GetDevice()); CHECK(grad->Dtype() == Dtype()); CHECK(grad->Dims() == Dims()); - grad_ = grad; + grad_ = std::move(grad); } else { grad_.reset(); } From ed1a608244f5e0e0c2f743e4c3e9fedb96f322bb Mon Sep 17 00:00:00 2001 From: bolunz Date: Tue, 25 Nov 2025 17:58:57 +0800 Subject: [PATCH 3/5] fix: fix requested changes and add sync in profiler --- .../nn/parallel/distributed_data_parallel.h | 2 +- .../include/nn/parallel/process_group.h | 6 ++-- infini_train/include/nn/parallel/reducer.h | 29 +++++++++++++++---- infini_train/include/tensor.h | 2 +- infini_train/src/autograd/accumulate.cc | 4 +-- .../nn/parallel/distributed_data_parallel.cc | 28 ++++++++++++------ infini_train/src/nn/parallel/process_group.cc | 6 ++-- infini_train/src/nn/parallel/reducer.cc | 25 +++++++--------- infini_train/src/profiler.cc | 5 ++++ infini_train/src/tensor.cc | 4 +-- 10 files changed, 70 insertions(+), 41 deletions(-) diff --git a/infini_train/include/nn/parallel/distributed_data_parallel.h b/infini_train/include/nn/parallel/distributed_data_parallel.h index df214156..6001a17a 100644 --- a/infini_train/include/nn/parallel/distributed_data_parallel.h +++ b/infini_train/include/nn/parallel/distributed_data_parallel.h @@ -20,7 +20,7 @@ class DistributedDataParallel : public nn::Module { std::vector> Forward(const std::vector> &input_tensors) override; private: - std::shared_ptr reducer_; + std::shared_ptr reducer_ = nullptr; }; } // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index 8f581966..14be5c06 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -11,14 +11,16 @@ #endif #include "infini_train/include/nn/parallel/reduce_op_type.h" -#include "infini_train/include/nn/parallel/work.h" namespace infini_train { class Tensor; class Device; namespace nn { class Module; -} +namespace parallel { +class Work; +} // namespace parallel +} // namespace nn } // namespace infini_train diff --git a/infini_train/include/nn/parallel/reducer.h b/infini_train/include/nn/parallel/reducer.h index 3a559681..adb636a6 100644 --- a/infini_train/include/nn/parallel/reducer.h +++ 
b/infini_train/include/nn/parallel/reducer.h @@ -1,21 +1,34 @@ #pragma once +#include #include #include #include -#include "infini_train/include/autograd/function_hook.h" +#include "infini_train/include/datatype.h" #include "infini_train/include/nn/parallel/parallel_functional.h" -#include "infini_train/include/tensor.h" + +namespace infini_train { +class Tensor; +class Device; +namespace autograd { +class PostAccumulateGradHook; +} // namespace autograd +} // namespace infini_train namespace infini_train::nn::parallel { +namespace { +constexpr int kFirstBucketCapMB = 25; +constexpr int kNormalBucketCapMB = 25; +constexpr size_t kBytesPerMB = 1024ULL * 1024ULL; +} // namespace // GradBucket passes bucket contents tensor to DDP communication hook. // ref: https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/comm.hpp class GradBucket { public: explicit GradBucket(const std::vector> &tensors) : tensors_(tensors) {} - const std::vector> &getTensors() const { return tensors_; } + const std::vector> &tensors() const { return tensors_; } private: std::vector> tensors_; @@ -34,11 +47,15 @@ struct ReducerOptions { // Ref: https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html // Max capacity for each bucket(in MB) - size_t first_bucket_cap_mb = 128; - size_t normal_bucket_cap_mb = 512; + size_t first_bucket_cap_mb = kFirstBucketCapMB; + size_t normal_bucket_cap_mb = kNormalBucketCapMB; // When set true, map param.grad directly to the slice of bucket.flat(same address in memory) instead of memcpy bool gradient_as_bucket_view = true; + + // Whether to enable gradient bucketing + // FIXME(zbl): should enable gradient bucketing by default + bool gradient_bucketing_enabled = true; }; // DDP Reducer that handles gradient bucketing in backward @@ -60,7 +77,7 @@ class Reducer : public std::enable_shared_from_this { // Prepare bucket info for next step void PrepareForBackward(); - // For custom DDP hook to overwrite the default AllReduce. T + // For custom DDP hook to overwrite the default AllReduce. // This can be used for algorithms like Gradient Compression/GossipGrad. // Hook is registered using `Reducer::RegisterCommHook()`. 
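A hedged sketch of how the reworked options are meant to be combined after this change (the wrapper function is made up for illustration; the numeric values mirror the new kFirstBucketCapMB/kNormalBucketCapMB defaults, and the fallback path taken when bucketing is disabled, one per-parameter AllReduce hook and no Reducer, is wired up in DistributedDataParallel further down in this patch):

#include <memory>

#include "infini_train/include/nn/parallel/distributed_data_parallel.h"

using namespace infini_train::nn::parallel;

std::shared_ptr<DistributedDataParallel> WrapForDataParallel(std::shared_ptr<infini_train::nn::Module> model,
                                                             int device_id, bool enable_bucketing) {
    ReducerOptions opts;
    opts.gradient_bucketing_enabled = enable_bucketing;  // false -> per-parameter hooks, no Reducer
    opts.gradient_as_bucket_view = true;                 // param.grad aliases its slice of the flat bucket
    opts.first_bucket_cap_mb = 25;                       // matches kFirstBucketCapMB
    opts.normal_bucket_cap_mb = 25;                      // matches kNormalBucketCapMB

    return std::make_shared<DistributedDataParallel>(std::move(model), device_id, opts);
}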
// TODO(zbl): Leave the placeholder for the moment diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index a7c4bc1e..58e02025 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -196,7 +196,7 @@ class Tensor : public std::enable_shared_from_this { std::shared_ptr RequiresGrad(); std::shared_ptr grad() const; - void set_grad(std::shared_ptr grad); + void set_grad(std::shared_ptr &grad); bool requires_grad() const; void set_requires_grad(bool requires_grad); diff --git a/infini_train/src/autograd/accumulate.cc b/infini_train/src/autograd/accumulate.cc index 76443bd1..def9cad8 100644 --- a/infini_train/src/autograd/accumulate.cc +++ b/infini_train/src/autograd/accumulate.cc @@ -36,9 +36,9 @@ AccumulateGrad::Backward(const std::vector> &grad_output kernel.Call(grad_output, learning_rate_, grad); } } else { - // NOTE(zbl): check whether need to do copying instead of slicing + // FIXME(zbl): check whether need to do copying instead of slicing auto new_grad = std::make_shared(*grad_output.get(), 0, grad_output->Dims()); - tensor_->set_grad(std::move(new_grad)); + tensor_->set_grad(new_grad); } auto hook = tensor_->post_accumulate_grad_hook(); if (hook != nullptr) { diff --git a/infini_train/src/nn/parallel/distributed_data_parallel.cc b/infini_train/src/nn/parallel/distributed_data_parallel.cc index 34e232fa..7b75a050 100644 --- a/infini_train/src/nn/parallel/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/distributed_data_parallel.cc @@ -20,22 +20,32 @@ constexpr char kModuleName[] = "module"; DistributedDataParallel::DistributedDataParallel(std::shared_ptr module, int device_id, const ReducerOptions &opts) { for (auto ¶m : module->Parameters()) { - CHECK_EQ(param->GetDevice()->Index(), device_id) << "All parameters must be on the same device as the module"; + auto device = param->GetDevice(); + CHECK_EQ(device->Index(), device_id) << "All parameters must be on the same device as the module"; + if (!opts.gradient_bucketing_enabled) { + auto ddp_pg + = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(device->rank().thread_rank())); + auto hook = std::make_unique( + function::ReduceOpType::kAvg, ddp_pg); + param->RegisterPostAccumulateGradHook(std::move(hook)); + } } for (auto &buffer : module->Buffers()) { CHECK_EQ(buffer->GetDevice()->Index(), device_id) << "All buffers must be on the same device as the module"; } modules_[kModuleName] = std::move(module); - // Bucket Assignment - auto params = modules_[kModuleName]->Parameters(); - const size_t first_cap_bytes = opts.first_bucket_cap_mb * 1024ULL * 1024ULL; - const size_t normal_cap_bytes = opts.normal_bucket_cap_mb * 1024ULL * 1024ULL; - std::vector bucket_size_limits = {first_cap_bytes, normal_cap_bytes}; - auto bucket_indices = ComputeBucketAssignmentBySize(params, bucket_size_limits); + if (opts.gradient_bucketing_enabled) { + // Bucket Assignment + auto params = modules_[kModuleName]->Parameters(); + const size_t first_cap_bytes = opts.first_bucket_cap_mb * kBytesPerMB; + const size_t normal_cap_bytes = opts.normal_bucket_cap_mb * kBytesPerMB; + std::vector bucket_size_limits = {first_cap_bytes, normal_cap_bytes}; + auto bucket_indices = ComputeBucketAssignmentBySize(params, bucket_size_limits); - reducer_ = std::make_shared(params, bucket_indices, opts); - reducer_->AttachHooksToParameters(); + reducer_ = std::make_shared(params, bucket_indices, opts); + reducer_->AttachHooksToParameters(); + } } std::vector> diff --git 
a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index d7937035..bd85deec 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -1,6 +1,5 @@ #include "infini_train/include/nn/parallel/process_group.h" -#include #include #include @@ -14,6 +13,7 @@ #include "infini_train/include/datatype.h" #include "infini_train/include/device.h" #include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/nn/parallel/work.h" #include "infini_train/include/tensor.h" namespace infini_train { @@ -57,8 +57,8 @@ ProcessGroup::ProcessGroup(const std::vector &device_indices) : comm_size_( device->SetDevice(); int low, high; - cudaDeviceGetStreamPriorityRange(&low, &high); - cudaStreamCreateWithPriority(&comm_streams_[i], cudaStreamNonBlocking, high); + CUDA_CHECK(cudaDeviceGetStreamPriorityRange(&low, &high)); + CUDA_CHECK(cudaStreamCreateWithPriority(&comm_streams_[i], cudaStreamNonBlocking, high)); device_stream_map_[device] = comm_streams_[i]; } diff --git a/infini_train/src/nn/parallel/reducer.cc b/infini_train/src/nn/parallel/reducer.cc index fcb0ca89..318d7489 100644 --- a/infini_train/src/nn/parallel/reducer.cc +++ b/infini_train/src/nn/parallel/reducer.cc @@ -14,9 +14,9 @@ #include "infini_train/include/autograd/function_hook.h" #include "infini_train/include/common/cuda/common_cuda.h" -#include "infini_train/include/device.h" #include "infini_train/include/nn/parallel/utils.h" #include "infini_train/include/nn/parallel/work.h" +#include "infini_train/include/tensor.h" namespace infini_train::nn::parallel { namespace { @@ -106,7 +106,6 @@ std::vector> ComputeBucketAssignmentBySize(const std::vector return (std::hash()(k.dev) << 1) ^ std::hash()(static_cast(k.dtype)); } }; - auto key_of = [&](size_t i) -> Key { return Key{tensors[i]->GetDevice()->Index(), tensors[i]->Dtype()}; }; // Maintain the current state of each bucket struct State { @@ -117,8 +116,6 @@ std::vector> ComputeBucketAssignmentBySize(const std::vector std::unordered_map states; std::vector key_order; - // NOTE(zbl): Assume combinations of (device, dtype) <= 8 - states.reserve(8); std::vector> buckets_all; buckets_all.reserve(tensors.size()); @@ -130,9 +127,7 @@ std::vector> ComputeBucketAssignmentBySize(const std::vector } }; - auto current_cap = [&](const State &s) -> size_t { return bucket_size_limits[s.limit_idx]; }; - - auto flush_current_bucket = [&](State &s) { + auto flushCurrentBucket = [&](State &s) { if (!s.current_tensors.empty()) { buckets_all.push_back(std::move(s.current_tensors)); s.current_tensors.clear(); @@ -146,7 +141,7 @@ std::vector> ComputeBucketAssignmentBySize(const std::vector const auto &tensor = tensors[idx_in_order]; CHECK(tensor); - const Key k = key_of(idx_in_order); + const Key k = Key{tensors[idx_in_order]->GetDevice()->Index(), tensors[idx_in_order]->Dtype()}; auto it = states.find(k); if (it == states.end()) { it = states.emplace(k, State{}).first; @@ -156,7 +151,7 @@ std::vector> ComputeBucketAssignmentBySize(const std::vector const size_t element_size_in_bytes = kDataTypeToSize.at(tensor->Dtype()); const size_t bytes = tensor->NumElements() * element_size_in_bytes; - const size_t cap = current_cap(state); + const size_t cap = bucket_size_limits[state.limit_idx]; // Assign current tensor to current bucket first state.current_tensors.push_back(idx_in_order); @@ -164,12 +159,12 @@ std::vector> ComputeBucketAssignmentBySize(const std::vector // If current bucket is out of 
capacity, then flush and move on to the next bucket if (state.current_bytes >= cap) { - flush_current_bucket(state); + flushCurrentBucket(state); } } // Flush the last bucket of each group manually - for (auto &key : key_order) { flush_current_bucket(states[key]); } + for (auto &key : key_order) { flushCurrentBucket(states[key]); } return buckets_all; } @@ -215,6 +210,7 @@ void Reducer::BuildBuckets(const std::vector> &bucket_indice CHECK(!bucket_indices[bucket_idx].empty()); const auto &first_param = params_[bucket_indices[bucket_idx][0]]; bucket.dtype = first_param->Dtype(); + // FIXME(zbl): use global_rank() in multi-node settings bucket.device_rank = first_param->GetDevice()->rank().thread_rank(); size_t total_elems = 0; @@ -274,8 +270,8 @@ void Reducer::RebuildBuckets() { tensors_in_order.push_back(params_[global_idx]); } - const size_t first_cap_bytes = opts_.first_bucket_cap_mb * 1024ULL * 1024ULL; - const size_t normal_cap_bytes = opts_.normal_bucket_cap_mb * 1024ULL * 1024ULL; + const size_t first_cap_bytes = opts_.first_bucket_cap_mb * kBytesPerMB; + const size_t normal_cap_bytes = opts_.normal_bucket_cap_mb * kBytesPerMB; std::vector bucket_size_limits = {first_cap_bytes, normal_cap_bytes}; auto new_bucket_indices = ComputeBucketAssignmentBySize(tensors_in_order, bucket_size_limits, full_order); @@ -364,8 +360,7 @@ void Reducer::MarkVariableReadyDense(size_t variable_index) { auto &bucket = buckets_.at(loc.bucket_index); // Record real order of bucket being ready - if (!has_rebuilt_bucket_ && variable_index < ready_seen_this_iter_.size() - && !ready_seen_this_iter_[variable_index]) { + if (!has_rebuilt_bucket_ && !ready_seen_this_iter_[variable_index]) { grad_ready_order_indices_.push_back(variable_index); ready_seen_this_iter_[variable_index] = 1; } diff --git a/infini_train/src/profiler.cc b/infini_train/src/profiler.cc index a04c9e14..f2be2f4b 100644 --- a/infini_train/src/profiler.cc +++ b/infini_train/src/profiler.cc @@ -84,6 +84,11 @@ void Profiler::StartRecord(const std::string &name, DeviceType device) { cudaStream_t stream = GetCudaStream(); CUDA_CHECK(cudaEventCreate(&start)); CUDA_CHECK(cudaEventCreate(&stop)); + + // Make sure the compute stream has done waiting, and ready for the execution of next op + CUDA_CHECK(cudaStreamSynchronize(stream)); + // Start record after waiting + cpu_timing_map_[name] = std::chrono::high_resolution_clock::now(); CUDA_CHECK(cudaEventRecord(start, stream)); cuda_timing_map_[name] = {reinterpret_cast(start), reinterpret_cast(stop)}; break; diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index 4752fe0a..c7f7669c 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -544,12 +544,12 @@ std::shared_ptr Tensor::RequiresGrad() { } std::shared_ptr Tensor::grad() const { return grad_; }; -void Tensor::set_grad(std::shared_ptr grad) { +void Tensor::set_grad(std::shared_ptr &grad) { if (grad) { CHECK(grad->GetDevice() == GetDevice()); CHECK(grad->Dtype() == Dtype()); CHECK(grad->Dims() == Dims()); - grad_ = std::move(grad); + grad_ = grad; } else { grad_.reset(); } From 0a2f97d642b20d22649faba8f7c52c5e6e1d85f0 Mon Sep 17 00:00:00 2001 From: bolunz Date: Wed, 26 Nov 2025 15:34:39 +0800 Subject: [PATCH 4/5] fix: fix stream wait logics in compute-comm overlap --- infini_train/include/nn/parallel/reducer.h | 12 +++- infini_train/include/nn/parallel/work.h | 12 ++-- .../nn/parallel/distributed_data_parallel.cc | 3 +- infini_train/src/nn/parallel/process_group.cc | 2 +- infini_train/src/nn/parallel/reducer.cc | 
63 +++++++++++++++---- infini_train/src/nn/parallel/work.cc | 12 +++- 6 files changed, 83 insertions(+), 21 deletions(-) diff --git a/infini_train/include/nn/parallel/reducer.h b/infini_train/include/nn/parallel/reducer.h index adb636a6..f729f723 100644 --- a/infini_train/include/nn/parallel/reducer.h +++ b/infini_train/include/nn/parallel/reducer.h @@ -14,11 +14,15 @@ class Device; namespace autograd { class PostAccumulateGradHook; } // namespace autograd +namespace nn::parallel { +class Work; +} // namespace nn::parallel } // namespace infini_train namespace infini_train::nn::parallel { namespace { -constexpr int kFirstBucketCapMB = 25; +// Default bucket size in alignment with PyTorch +constexpr int kFirstBucketCapMB = 1; constexpr int kNormalBucketCapMB = 25; constexpr size_t kBytesPerMB = 1024ULL * 1024ULL; } // namespace @@ -126,6 +130,9 @@ class Reducer : public std::enable_shared_from_this { // If `true`, then this implies that `bucket.variables.size() == 1`. // TODO(zbl): support logics for sparse gradient later bool expect_sparse_gradient = false; + + // The result of async communication op + std::shared_ptr work = nullptr; }; private: @@ -135,6 +142,7 @@ class Reducer : public std::enable_shared_from_this { void MarkVariableReadyDense(size_t variable_index); void MarkBucketReady(size_t bucket_index); void FinalizeBucketDense(size_t bucket_index); + void FinalizeBackward(); void BuildBuckets(const std::vector> &bucket_indices); void InitializeBucketViews(Bucket &bucket); @@ -161,6 +169,8 @@ class Reducer : public std::enable_shared_from_this { bool need_rebuild_ = false; // Whether to buckets have already been rebuilt on the second step bool has_rebuilt_bucket_ = false; + // Whether all buckets are ready and backward can be finalized + bool all_buckets_ready_this_iter_ = false; }; } // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/work.h b/infini_train/include/nn/parallel/work.h index d9f6520a..3f304b7a 100644 --- a/infini_train/include/nn/parallel/work.h +++ b/infini_train/include/nn/parallel/work.h @@ -13,7 +13,9 @@ #include #endif -#include "infini_train/include/device.h" +namespace infini_train { +class Device; +} // namespace infini_train namespace infini_train::nn::parallel { @@ -21,7 +23,8 @@ class Work { public: virtual ~Work() = default; - virtual bool Wait(std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) = 0; + virtual bool WaitBlocking(std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) = 0; + virtual bool WaitNonBlocking() = 0; virtual bool IsCompleted() const = 0; virtual bool IsSuccess() const = 0; @@ -40,7 +43,8 @@ class WorkNccl final : public Work { WorkNccl(const Device *device, ncclComm_t comm); ~WorkNccl() override; - bool Wait(std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) override; + bool WaitBlocking(std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) override; + bool WaitNonBlocking(); bool IsCompleted() const override; bool IsSuccess() const override; @@ -57,7 +61,7 @@ class WorkNccl final : public Work { void SetException(std::exception_ptr e); private: - Device *device_ = nullptr; + const Device *device_ = nullptr; cudaEvent_t ready_event_; cudaEvent_t done_event_; ncclComm_t comm_; diff --git a/infini_train/src/nn/parallel/distributed_data_parallel.cc b/infini_train/src/nn/parallel/distributed_data_parallel.cc index 7b75a050..9b8704a0 100644 --- a/infini_train/src/nn/parallel/distributed_data_parallel.cc +++ 
b/infini_train/src/nn/parallel/distributed_data_parallel.cc @@ -50,10 +50,11 @@ DistributedDataParallel::DistributedDataParallel(std::shared_ptr mod std::vector> DistributedDataParallel::Forward(const std::vector> &input_tensors) { + auto outputs = modules_[kModuleName]->Forward(input_tensors); if (reducer_) { reducer_->PrepareForBackward(); } - return modules_[kModuleName]->Forward(input_tensors); + return outputs; } } // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index bd85deec..e4d16682 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -359,8 +359,8 @@ std::shared_ptr ProcessGroup::AllReduceAsync(const std::shared_ptr kNcclReduceOpMap.at(reduce_op), comm, comm_stream)); CUDA_CHECK(cudaEventRecord(done_event, comm_stream)); - CUDA_CHECK(cudaStreamWaitEvent(compute_stream, done_event, 0)); + // Do not let compute stream wait for done event here return std::move(work); } #endif diff --git a/infini_train/src/nn/parallel/reducer.cc b/infini_train/src/nn/parallel/reducer.cc index 318d7489..3b029deb 100644 --- a/infini_train/src/nn/parallel/reducer.cc +++ b/infini_train/src/nn/parallel/reducer.cc @@ -355,7 +355,7 @@ void Reducer::AttachHooksToParameters() { } void Reducer::MarkVariableReadyDense(size_t variable_index) { - std::lock_guard g(mutex_); + std::unique_lock lock(mutex_); const auto loc = locators_.at(variable_index); auto &bucket = buckets_.at(loc.bucket_index); @@ -379,6 +379,13 @@ void Reducer::MarkVariableReadyDense(size_t variable_index) { if (should_launch_next) { MarkBucketReady(loc.bucket_index); } + + // Release mutex + lock.unlock(); + + if (all_buckets_ready_this_iter_) { + FinalizeBackward(); + } } void Reducer::MarkBucketReady(size_t bucket_index) { @@ -395,19 +402,23 @@ void Reducer::MarkBucketReady(size_t bucket_index) { ++next_bucket_; } - // If all buckets are ready, then try to rebuild them in real order - if (next_bucket_ == buckets_.size() && !has_rebuilt_bucket_) { - if (!grad_ready_order_indices_.empty()) { + // If all buckets are ready + if (next_bucket_ == buckets_.size()) { + // Mark that it's time to finalize backward + all_buckets_ready_this_iter_ = true; + + // Try to rebuild them in real order in the first round + if (!has_rebuilt_bucket_ && !grad_ready_order_indices_.empty()) { need_rebuild_ = true; } } } void Reducer::FinalizeBucketDense(size_t bucket_index) { + // NOTE(zbl): Assume mutex is on when entering this function auto &bucket = buckets_.at(bucket_index); auto ddp_pg = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(bucket.device_rank)); - std::shared_ptr work; if (comm_hook_) { std::vector> bucket_view{bucket.contents}; // NOTE(zbl): Custom hook should do in-place operations @@ -415,16 +426,44 @@ void Reducer::FinalizeBucketDense(size_t bucket_index) { // FIXME(zbl): support custom hook later LOG(FATAL) << "Custom hook is not supported now"; } else { - work = ddp_pg->AllReduceAsync(bucket.contents, function::ReduceOpType::kAvg); + bucket.work = ddp_pg->AllReduceAsync(bucket.contents, function::ReduceOpType::kAvg); } +} - if (!opts_.gradient_as_bucket_view) { - for (size_t i = 0; i < bucket.variables.size(); ++i) { - // NOTE(zbl): For better performance, try `bucket.variables[i]->set_grad(bucket.bucket_views_in[i]);` - // to directly assgin bucket slice to grad instead of copying - CopyBucketToGrad(bucket.contents, bucket.variables[i]->grad(), bucket.offsets[i]); 
- // bucket.variables[i]->set_grad(bucket.bucket_views_in[i]); +} + +void Reducer::FinalizeBackward() { + // NOTE(zbl): Assume mutex is off when entering this function + // Collect all works with mutex on + std::vector<std::shared_ptr<Work>> works; + { + std::lock_guard lock(mutex_); + for (auto &bucket : buckets_) { + if (bucket.work) { + works.push_back(bucket.work); + } + } + } + + // Wait for works to be done with mutex off + // NOTE(zbl): Use non-blocking stream wait instead of sync on host + for (auto &work : works) { work->WaitNonBlocking(); } + + // Write grad back and reset with mutex on + { + std::lock_guard lock(mutex_); + for (auto &bucket : buckets_) { + if (!bucket.work) { + continue; + } + if (!opts_.gradient_as_bucket_view) { + for (size_t i = 0; i < bucket.variables.size(); ++i) { + // NOTE(zbl): For better performance, try to directly assign bucket slice to grad instead of copying + // i.e. bucket.variables[i]->set_grad(bucket.bucket_views_in[i]); + CopyBucketToGrad(bucket.contents, bucket.variables[i]->grad(), bucket.offsets[i]); + } + } + bucket.work.reset(); } + all_buckets_ready_this_iter_ = false; } } } // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/work.cc b/infini_train/src/nn/parallel/work.cc index fef7517d..53fd465a 100644 --- a/infini_train/src/nn/parallel/work.cc +++ b/infini_train/src/nn/parallel/work.cc @@ -5,6 +5,7 @@ #ifdef USE_CUDA #include "infini_train/include/common/cuda/common_cuda.h" #endif +#include "infini_train/include/device.h" namespace infini_train::nn::parallel { #ifdef USE_NCCL @@ -14,7 +15,7 @@ std::exception_ptr makeCudaError(cudaError_t err) { } } // namespace -WorkNccl::WorkNccl(const Device *device, ncclComm_t comm) : comm_(comm) { +WorkNccl::WorkNccl(const Device *device, ncclComm_t comm) : device_(device), comm_(comm) { CUDA_CHECK(cudaEventCreateWithFlags(&ready_event_, cudaEventDisableTiming)); CUDA_CHECK(cudaEventCreateWithFlags(&done_event_, cudaEventDisableTiming)); } @@ -28,7 +29,7 @@ WorkNccl::~WorkNccl() { } } -bool WorkNccl::Wait(std::chrono::milliseconds timeout) { +bool WorkNccl::WaitBlocking(std::chrono::milliseconds timeout) { // Block wait on host device_->SetDevice(); @@ -65,6 +66,13 @@ bool WorkNccl::Wait(std::chrono::milliseconds timeout) { return false; } +bool WorkNccl::WaitNonBlocking() { + // Non-blocking wait on compute stream + device_->SetDevice(); + CUDA_CHECK(cudaStreamWaitEvent(dynamic_cast(device_)->Stream(), done_event_, 0)); + return true; +} + void WorkNccl::Synchronize() const { CUDA_CHECK(cudaEventSynchronize(done_event_)); } bool WorkNccl::IsCompleted() const { From d7d43b4526d72bf3560ef47299e5afe788b0c928 Mon Sep 17 00:00:00 2001 From: bolunz Date: Thu, 27 Nov 2025 13:08:20 +0800 Subject: [PATCH 5/5] fix: fix requested changes --- infini_train/include/tensor.h | 2 +- infini_train/src/nn/parallel/reducer.cc | 18 +++++++----------- infini_train/src/tensor.cc | 2 +- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index 58e02025..5dea3853 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -196,7 +196,7 @@ class Tensor : public std::enable_shared_from_this<Tensor> { std::shared_ptr<Tensor> RequiresGrad(); std::shared_ptr<Tensor> grad() const; - void set_grad(std::shared_ptr<Tensor> &grad); + void set_grad(const std::shared_ptr<Tensor> &grad); bool requires_grad() const; void set_requires_grad(bool requires_grad); diff --git a/infini_train/src/nn/parallel/reducer.cc b/infini_train/src/nn/parallel/reducer.cc index 3b029deb..429cd44f
100644 --- a/infini_train/src/nn/parallel/reducer.cc +++ b/infini_train/src/nn/parallel/reducer.cc @@ -120,19 +120,15 @@ std::vector<std::vector<size_t>> ComputeBucketAssignmentBySize(const std::vector std::vector<std::vector<size_t>> buckets_all; buckets_all.reserve(tensors.size()); - auto advance_limit = [&](State &s) { - // Iterate along bucket_size_limits till the last one everytime a bucket is completed - if (s.limit_idx + 1 < bucket_size_limits.size()) { - ++s.limit_idx; - } - }; - - auto flushCurrentBucket = [&](State &s) { + auto FlushCurrentBucket = [&](State &s) { if (!s.current_tensors.empty()) { buckets_all.push_back(std::move(s.current_tensors)); s.current_tensors.clear(); s.current_bytes = 0; - advance_limit(s); + // Iterate along bucket_size_limits till the last one every time a bucket is completed + if (s.limit_idx + 1 < bucket_size_limits.size()) { + ++s.limit_idx; + } } }; @@ -159,12 +155,12 @@ std::vector<std::vector<size_t>> ComputeBucketAssignmentBySize(const std::vector // If current bucket is out of capacity, then flush and move on to the next bucket if (state.current_bytes >= cap) { - flushCurrentBucket(state); + FlushCurrentBucket(state); } } // Flush the last bucket of each group manually - for (auto &key : key_order) { flushCurrentBucket(states[key]); } + for (auto &key : key_order) { FlushCurrentBucket(states[key]); } return buckets_all; } diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index c7f7669c..491b90b7 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -544,7 +544,7 @@ std::shared_ptr<Tensor> Tensor::RequiresGrad() { } std::shared_ptr<Tensor> Tensor::grad() const { return grad_; }; -void Tensor::set_grad(std::shared_ptr<Tensor> &grad) { +void Tensor::set_grad(const std::shared_ptr<Tensor> &grad) { if (grad) { CHECK(grad->GetDevice() == GetDevice()); CHECK(grad->Dtype() == Dtype());
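The compute/comm overlap touched by the stream-wait fix above reduces to a standard CUDA event-ordering pattern between the compute stream and the dedicated high-priority communication stream: AllReduceAsync enqueues the collective on the comm stream once the bucket's gradients are ready, and the compute stream is only ordered after the collective when the reduced gradients are actually consumed (WaitNonBlocking, called from FinalizeBackward). A minimal sketch of that pattern in raw CUDA/NCCL follows; the names (grad_buf, count, compute_stream, comm_stream, ready_event, done_event) are illustrative placeholders rather than infini_train's API, and error checking is omitted for brevity.

#include <cuda_runtime.h>
#include <nccl.h>

// Sketch: overlap the all-reduce of one gradient bucket with ongoing backward compute.
void AllReduceBucketSketch(float *grad_buf, size_t count, ncclComm_t comm,
                           cudaStream_t compute_stream, cudaStream_t comm_stream,
                           cudaEvent_t ready_event, cudaEvent_t done_event) {
    // Mark the point on the compute stream at which this bucket's grads are produced.
    cudaEventRecord(ready_event, compute_stream);
    // The comm stream waits for that point, so the collective never reads unfinished grads.
    cudaStreamWaitEvent(comm_stream, ready_event, 0);
    // Launch the collective on the comm stream; backward kernels for later buckets
    // keep running on the compute stream in the meantime. (ncclAvg needs NCCL >= 2.10.)
    ncclAllReduce(grad_buf, grad_buf, count, ncclFloat, ncclAvg, comm, comm_stream);
    // Mark completion of the collective on the comm stream.
    cudaEventRecord(done_event, comm_stream);
}

// Deferred wait (cf. WaitNonBlocking): order the compute stream after the collective
// so anything that consumes the reduced gradients sees the final values.
void WaitBucketSketch(cudaStream_t compute_stream, cudaEvent_t done_event) {
    cudaStreamWaitEvent(compute_stream, done_event, 0);
}

Splitting the second wait out of the launch path mirrors the change to AllReduceAsync above: every bucket can queue its collective as soon as it is ready, and the compute stream is made to wait only once, at the end of backward, when FinalizeBackward walks the outstanding works.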