feat: DDP gradient bucketing #92
Merged
Changes from all commits (5 commits):
- b3c0512 feat: add DDP gradient bucketing, support compute/comm dual stream (Chamberlain0w0)
- 0053fa8 feat: add Work definition, fix gradient_as_bucket_view option (Chamberlain0w0)
- ed1a608 fix: fix requested changes and add sync in profiler (Chamberlain0w0)
- 0a2f97d fix: fix stream wait logics in compute-comm overlap (Chamberlain0w0)
- d7d43b4 fix: fix requested changes (Chamberlain0w0)
New file (+176 lines): declares `GradBucket`, `ComputeBucketAssignmentBySize`, `ReducerOptions`, and the DDP `Reducer`.

```cpp
#pragma once

#include <atomic>
#include <memory>
#include <mutex>
#include <vector>

#include "infini_train/include/datatype.h"
#include "infini_train/include/nn/parallel/parallel_functional.h"

namespace infini_train {
class Tensor;
class Device;
namespace autograd {
class PostAccumulateGradHook;
} // namespace autograd
namespace nn::parallel {
class Work;
} // namespace nn::parallel
} // namespace infini_train

namespace infini_train::nn::parallel {
namespace {
// Default bucket sizes, in alignment with PyTorch
constexpr int kFirstBucketCapMB = 1;
constexpr int kNormalBucketCapMB = 25;
constexpr size_t kBytesPerMB = 1024ULL * 1024ULL;
} // namespace

// GradBucket passes a bucket's contents tensor to the DDP communication hook.
// ref: https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/comm.hpp
class GradBucket {
public:
    explicit GradBucket(const std::vector<std::shared_ptr<Tensor>> &tensors) : tensors_(tensors) {}
    const std::vector<std::shared_ptr<Tensor>> &tensors() const { return tensors_; }

private:
    std::vector<std::shared_ptr<Tensor>> tensors_;
};

// Compute the bucket assignment from the size of each tensor and the bucket capacities.
// Returns the indices of the tensors in the corresponding bucket, i.e. output[bucket_i] = {tensor_j, tensor_k, ...}.
// The index recorded for tensors[idx] (j and k above) is tensor_indices[idx];
// when tensor_indices is empty, the index recorded for tensors[idx] is idx itself.
std::vector<std::vector<size_t>> ComputeBucketAssignmentBySize(const std::vector<std::shared_ptr<Tensor>> &tensors,
                                                               const std::vector<size_t> &bucket_size_limits,
                                                               const std::vector<size_t> &tensor_indices = {});

struct ReducerOptions {
    // Pack all Reducer-related args together
    // Ref: https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html

    // Max capacity of each bucket (in MB)
    size_t first_bucket_cap_mb = kFirstBucketCapMB;
    size_t normal_bucket_cap_mb = kNormalBucketCapMB;

    // When true, map param.grad directly onto its slice of the bucket's flat `contents` tensor
    // (same memory address) instead of copying via memcpy
    bool gradient_as_bucket_view = true;

    // Whether to enable gradient bucketing
    // FIXME(zbl): should enable gradient bucketing by default
    bool gradient_bucketing_enabled = true;
};

// DDP Reducer that handles gradient bucketing during backward
// ref: https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/reducer.hpp
class Reducer : public std::enable_shared_from_this<Reducer> {
public:
    /** @brief Constructor of Reducer
     *
     * @param parameters     A list of parameters for this process's single model replica
     * @param bucket_indices The bucket assignment for this reducer
     * @param opts           Other options; see the definition of ReducerOptions
     */
    explicit Reducer(std::vector<std::shared_ptr<Tensor>> parameters, std::vector<std::vector<size_t>> bucket_indices,
                     const ReducerOptions &opts);

    // Attach PostAllReduceHooks to params
    void AttachHooksToParameters();

    // Prepare bucket info for the next step
    void PrepareForBackward();

    // Lets a custom DDP hook overwrite the default AllReduce.
    // This can be used for algorithms like Gradient Compression / GossipGrad.
    // The hook is registered using `Reducer::RegisterCommHook()`.
    // TODO(zbl): Leave the placeholder for the moment
    void RegisterCommHook(std::shared_ptr<autograd::PostAccumulateGradHook> hook);

    // Return, for each bucket, the tensors held in its flat buffer
    std::vector<std::vector<std::shared_ptr<Tensor>>> GetBucketTensors() const;

private:
    // A variable locator locates a particular variable in the reducer's buckets
    struct VariableLocator {
        // Index of the bucket containing the variable in the `buckets_` vector
        size_t bucket_index = 0;
        // Index of the variable within the bucket
        size_t intra_bucket_index = 0;
    };

    // Bucket used in DDP backward
    struct Bucket {
        // Gradients of the bucket flattened into a 1-dimensional tensor
        std::shared_ptr<Tensor> contents;
        DataType dtype;
        int device_rank = 0;

        // Variables whose gradients are held in this bucket
        std::vector<std::shared_ptr<Tensor>> variables;

        // Per-variable offset/length into the flattened `contents` tensor, used to build
        // the corresponding `GradBucket` instance for communication hooks.
        // In terms of element count, not bytes.
        std::vector<size_t> offsets;
        std::vector<size_t> lengths;

        // Views into the `contents` tensor for each individual gradient
        std::vector<std::shared_ptr<Tensor>> bucket_views_in;
        // TODO(zbl): reserved for cases where grads have a different stride/layout
        std::vector<std::shared_ptr<Tensor>> bucket_views_out;

        // Number of gradients left to be computed before the bucket is ready to be reduced
        size_t pending;

        // Global indices of the participating variables in the bucket
        std::vector<size_t> variable_indices;

        // Whether this bucket should expect a single sparse gradient.
        // If `true`, this implies that `bucket.variables.size() == 1`.
        // TODO(zbl): support logic for sparse gradients later
        bool expect_sparse_gradient = false;

        // The result of the async communication op
        std::shared_ptr<Work> work = nullptr;
    };

private:
    void InitializeBuckets(const std::vector<std::vector<size_t>> &bucket_indices);

    // NOTE(zbl): all grads are assumed dense and stored contiguously in the bucket for now
    void MarkVariableReadyDense(size_t variable_index);
    void MarkBucketReady(size_t bucket_index);
    void FinalizeBucketDense(size_t bucket_index);
    void FinalizeBackward();

    void BuildBuckets(const std::vector<std::vector<size_t>> &bucket_indices);
    void InitializeBucketViews(Bucket &bucket);
    void RebuildBuckets();

private:
    mutable std::mutex mutex_;
    std::vector<std::shared_ptr<Tensor>> params_;
    std::vector<Bucket> buckets_;
    std::vector<VariableLocator> locators_;

    std::atomic<size_t> buckets_finished_{0};
    std::shared_ptr<autograd::PostAccumulateGradHook> comm_hook_ = nullptr;
    ReducerOptions opts_;

    // Next bucket to be reduced.
    // This makes sure that bucket all-reduces are launched in the order we expect.
    size_t next_bucket_ = 0;
    // Records the order in which params become ready on the first step
    std::vector<size_t> grad_ready_order_indices_;
    // Records whether each param has been marked ready on the first step
    std::vector<uint8_t> ready_seen_this_iter_;
    // Whether to rebuild buckets on the next train step
    bool need_rebuild_ = false;
    // Whether buckets have already been rebuilt (on the second step)
    bool has_rebuilt_bucket_ = false;
    // Whether all buckets are ready and backward can be finalized
    bool all_buckets_ready_this_iter_ = false;
};

} // namespace infini_train::nn::parallel
```
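To make the contract of `ComputeBucketAssignmentBySize` concrete, here is a small standalone sketch of a greedy, in-order assignment over raw byte sizes. It only illustrates the policy described in the comment above (fill a bucket until its cap would be exceeded, with the first cap smaller than later ones); the function name `GreedyAssign` and the reuse of the last cap for all remaining buckets are assumptions, not the PR's actual implementation.

```cpp
// Standalone illustration (an assumption, not the PR's implementation) of a greedy,
// in-order bucket assignment: accumulate tensor indices until the current cap would be
// exceeded, then start a new bucket; later buckets reuse the last cap in `caps`.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

std::vector<std::vector<size_t>> GreedyAssign(const std::vector<size_t> &byte_sizes,
                                              const std::vector<size_t> &caps) {
    assert(!caps.empty());
    std::vector<std::vector<size_t>> buckets;
    std::vector<size_t> current;
    size_t used = 0;
    size_t cap_idx = 0;
    for (size_t i = 0; i < byte_sizes.size(); ++i) {
        const size_t cap = caps[std::min(cap_idx, caps.size() - 1)];
        if (!current.empty() && used + byte_sizes[i] > cap) {
            // Close the current bucket and move on to the next (larger) cap.
            buckets.push_back(current);
            current.clear();
            used = 0;
            cap_idx = std::min(cap_idx + 1, caps.size() - 1);
        }
        current.push_back(i);
        used += byte_sizes[i];
    }
    if (!current.empty()) {
        buckets.push_back(current);
    }
    return buckets; // buckets[b] = indices of the tensors assigned to bucket b
}
```

With caps of {1 MB, 25 MB}, as the defaults above suggest, the earliest gradients land in a small first bucket, so their all-reduce can start while the rest of backward is still running.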
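For orientation, here is a hedged sketch of how the declarations above might be wired together by a DDP wrapper. Nothing in it is taken from the PR: the include path, the way parameters are obtained from the model, and the convention that the first size limit applies only to the first bucket (as in PyTorch) are all assumptions.

```cpp
// Hypothetical wiring sketch (not code from this PR). The include path is assumed;
// obtaining `params` from the model is left to the caller.
#include <cstddef>
#include <memory>
#include <vector>

#include "infini_train/include/nn/parallel/reducer.h" // assumed path of the header above

namespace example {
using namespace infini_train::nn::parallel;

std::shared_ptr<Reducer> MakeReducer(const std::vector<std::shared_ptr<infini_train::Tensor>> &params) {
    ReducerOptions opts;
    opts.gradient_as_bucket_view = true; // view grads into the flat buffer instead of memcpy

    // Assumption (aligned with PyTorch): the first limit applies to the first bucket only,
    // and later buckets use the larger cap. Values are the defaults declared above.
    const std::vector<size_t> limits = {opts.first_bucket_cap_mb * kBytesPerMB,
                                        opts.normal_bucket_cap_mb * kBytesPerMB};
    auto bucket_indices = ComputeBucketAssignmentBySize(params, limits);

    auto reducer = std::make_shared<Reducer>(params, bucket_indices, opts);
    reducer->AttachHooksToParameters(); // hooks fire as each grad finishes accumulating
    // PrepareForBackward() would then presumably be called once per training step (assumption).
    return reducer;
}
} // namespace example
```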
New file (+76 lines): declares the asynchronous `Work` handle and its NCCL-backed `WorkNccl` implementation.

```cpp
#pragma once

#include <atomic>
#include <chrono>
#include <exception>
#include <memory>
#include <mutex>

#ifdef USE_CUDA
#include <cuda_runtime.h>
#endif
#ifdef USE_NCCL
#include <nccl.h>
#endif

namespace infini_train {
class Device;
} // namespace infini_train

namespace infini_train::nn::parallel {

class Work {
public:
    virtual ~Work() = default;

    virtual bool WaitBlocking(std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) = 0;
    virtual bool WaitNonBlocking() = 0;

    virtual bool IsCompleted() const = 0;
    virtual bool IsSuccess() const = 0;

    virtual void Synchronize() const = 0;

    virtual std::exception_ptr exception() const = 0;

    virtual void *ready_event() const = 0;
    virtual void *done_event() const = 0;
};

#ifdef USE_NCCL
class WorkNccl final : public Work {
public:
    WorkNccl(const Device *device, ncclComm_t comm);
    ~WorkNccl() override;

    bool WaitBlocking(std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) override;
    bool WaitNonBlocking() override;

    bool IsCompleted() const override;
    bool IsSuccess() const override;

    void Synchronize() const override;

    std::exception_ptr exception() const override { return exception_; }

    void *ready_event() const override { return reinterpret_cast<void *>(ready_event_); }
    void *done_event() const override { return reinterpret_cast<void *>(done_event_); }

private:
    bool CheckNcclStatus();
    void SetException(std::exception_ptr e);

private:
    const Device *device_ = nullptr;
    cudaEvent_t ready_event_;
    cudaEvent_t done_event_;
    ncclComm_t comm_;

    mutable std::mutex mutex_;
    std::exception_ptr exception_;
    std::atomic<bool> completed_{false};
    std::atomic<bool> success_{false};
};
#endif

} // namespace infini_train::nn::parallel
```
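Finally, a hedged caller-side sketch of how a `Work` handle returned by an async collective (for example, a Reducer bucket's `work` member) might be consumed. The include path, the meaning of a zero timeout, and the stream-ordering semantics of `Synchronize()` (assumed to mirror PyTorch's `Work::synchronize()`) are assumptions; only the interface itself comes from the file above.

```cpp
// Hypothetical caller-side sketch (not code from this PR): consuming a Work handle
// produced by an async collective. The include path and the exact blocking semantics
// are assumptions.
#include <chrono>
#include <exception>
#include <memory>

#include "infini_train/include/nn/parallel/work.h" // assumed path of the header above

namespace example {
using infini_train::nn::parallel::Work;

void FinishStep(const std::shared_ptr<Work> &work) {
    if (!work) {
        return; // nothing was launched (e.g. gradient bucketing disabled)
    }
    if (!work->IsCompleted()) {
        // Block the host until the collective finishes.
        // Zero is the declared default timeout; assumed here to mean "no timeout".
        work->WaitBlocking(std::chrono::milliseconds::zero());
    }
    // Assumed to mirror PyTorch's Work::synchronize(): order the consumer stream
    // after the communication before the reduced gradients are read.
    work->Synchronize();
    if (!work->IsSuccess() && work->exception() != nullptr) {
        std::rethrow_exception(work->exception());
    }
}
} // namespace example
```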