 #include "infini_train/include/common/cuda/common_cuda.h"
 #include "infini_train/include/device.h"
 #include "infini_train/include/nn/parallel/utils.h"
+#include "infini_train/include/nn/parallel/work.h"

 namespace infini_train::nn::parallel {
 namespace {
@@ -178,43 +179,9 @@ Reducer::Reducer(std::vector<std::shared_ptr<Tensor>> parameters, std::vector<st
     : params_(std::move(parameters)), opts_(opts) {
     BuildBuckets(bucket_indices);
     ready_seen_this_iter_.assign(params_.size(), 0);
-    AttachHooksToParameters();
-}
-
-Reducer::~Reducer() {
-#ifdef USE_CUDA
-    for (auto &b : buckets_) {
-        if (!b.contents) {
-            continue;
-        }
-        if (b.contents->GetDevice()->Type() == DeviceType::kCUDA) {
-            if (b.allreduce_done) {
-                CUDA_CHECK(cudaEventDestroy(b.allreduce_done));
-            }
-            if (b.bucket_ready) {
-                CUDA_CHECK(cudaEventDestroy(b.bucket_ready));
-            }
-        }
-    }
-#endif
 }

 void Reducer::InitializeBuckets(const std::vector<std::vector<size_t>> &bucket_indices) {
-#ifdef USE_CUDA
-    for (auto &b : buckets_) {
-        if (!b.contents) {
-            continue;
-        }
-        if (b.contents->GetDevice()->Type() == DeviceType::kCUDA) {
-            if (b.allreduce_done) {
-                CUDA_CHECK(cudaEventDestroy(b.allreduce_done));
-            }
-            if (b.bucket_ready) {
-                CUDA_CHECK(cudaEventDestroy(b.bucket_ready));
-            }
-        }
-    }
-#endif
     buckets_.clear();
     locators_.clear();
     next_bucket_ = 0;
@@ -235,16 +202,6 @@ void Reducer::InitializeBucketViews(Bucket &bucket) {
     }
     // Set (out == in) by default when all grads are dense
     bucket.bucket_views_out = bucket.bucket_views_in;
-
-    if (opts_.gradient_as_bucket_view) {
-        for (size_t i = 0; i < bucket.variables.size(); ++i) {
-            auto &v = bucket.variables[i];
-            auto g = v->grad();
-            if (g && g.get() != bucket.bucket_views_in[i].get()) {
-                v->set_grad(bucket.bucket_views_in[i]);
-            }
-        }
-    }
 }

 void Reducer::BuildBuckets(const std::vector<std::vector<size_t>> &bucket_indices) {
@@ -280,16 +237,8 @@ void Reducer::BuildBuckets(const std::vector<std::vector<size_t>> &bucket_indice
         auto dev = bucket.variables.front()->GetDevice();
         bucket.contents
             = std::make_shared<Tensor>(std::vector<int64_t>{static_cast<int64_t>(total_elems)}, bucket.dtype, dev);
-        // bucket.contents->Fill(0);
         bucket.pending = bucket.variables.size();

-#ifdef USE_CUDA
-        if (bucket.contents->GetDevice()->Type() == DeviceType::kCUDA) {
-            CUDA_CHECK(cudaEventCreateWithFlags(&bucket.allreduce_done, cudaEventDisableTiming));
-            CUDA_CHECK(cudaEventCreateWithFlags(&bucket.bucket_ready, cudaEventDisableTiming));
-        }
-#endif
-
         bucket.variable_indices = bucket_indices[bucket_idx];
         InitializeBucketViews(bucket);
         buckets_.push_back(std::move(bucket));
@@ -368,11 +317,18 @@ void Reducer::PrepareForBackward() {
             auto view = bucket.bucket_views_in[i];
             auto grad = param->grad();

-            if (grad == nullptr) {
-                param->MarkGradOverwriteOnNextAccum();
+            // NOTE(zbl): This affects the behavior of `infini_train::autograd::AccumulateGrad::Backward()`.
+            // If ZeroGrad(set_to_none=True), grad is nullptr at this point.
+            // If ZeroGrad(set_to_none=False), grad is a view of bucket.contents (or has been modified by the user).
+            // Either way, we reset grad to a view of bucket.contents.
+            // Since bucket.contents might not be zeroed, it must be overwritten on the next grad accumulation.
+            if (!grad || (grad.get() != view.get())) {
+                if (grad) {
+                    LOG(WARNING) << "gradient_as_bucket_view is enabled, but param " << param
+                                 << " has a non-view grad tensor. Automatically overwriting it with bucket view.";
+                }
                 param->set_grad(view);
-            } else {
-                CHECK_EQ(grad.get(), view.get()) << "Param's gradient should be a slice of bucket's flat buffer.";
+                param->MarkGradOverwriteOnNextAccum();
             }
         }
     }
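The NOTE in this hunk ties into how the first gradient accumulation treats a bucket-backed grad: because the flat buffer may hold stale values, the first write must overwrite rather than add. A minimal, self-contained sketch of that accumulate-vs-overwrite decision (the flag and the simplified AccumulateGrad below are illustrative stand-ins, not the repository's actual types):

// Sketch only: illustrative stand-ins, not infini_train's Tensor/AccumulateGrad.
#include <cstddef>
#include <vector>

struct Param {
    std::vector<float> grad;              // imagine this aliases a slice of bucket.contents
    bool overwrite_on_next_accum = false; // what MarkGradOverwriteOnNextAccum() would set
};

// Called when a new gradient for this parameter arrives during backward.
void AccumulateGrad(Param &p, const std::vector<float> &new_grad) {
    if (p.overwrite_on_next_accum) {
        p.grad = new_grad; // bucket slice may hold stale data, so overwrite instead of adding
        p.overwrite_on_next_accum = false;
    } else {
        for (std::size_t i = 0; i < p.grad.size(); ++i) {
            p.grad[i] += new_grad[i];
        }
    }
}

int main() {
    Param p;
    p.grad.assign(4, 123.0f);         // stale values left over in the flat buffer
    p.overwrite_on_next_accum = true; // as arranged in PrepareForBackward
    AccumulateGrad(p, std::vector<float>(4, 1.0f));
    return p.grad[0] == 1.0f ? 0 : 1; // overwritten, not accumulated
}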
@@ -456,25 +412,24 @@ void Reducer::FinalizeBucketDense(size_t bucket_index) {
     auto &bucket = buckets_.at(bucket_index);
     auto ddp_pg = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(bucket.device_rank));

+    std::shared_ptr<Work> work;
     if (comm_hook_) {
         std::vector<std::shared_ptr<Tensor>> bucket_view{bucket.contents};
         // NOTE(zbl): Custom hook should do in-place operations
         // e.g. comm_hook_(GradBucket{bucket_view})[0];
         // FIXME(zbl): support custom hook later
         LOG(FATAL) << "Custom hook is not supported now";
     } else {
-        ddp_pg->EnqueueAllReduce(bucket.bucket_ready, bucket.allreduce_done, bucket.contents,
-                                 function::ReduceOpType::kAvg);
+        work = ddp_pg->AllReduceAsync(bucket.contents, function::ReduceOpType::kAvg);
     }

     if (!opts_.gradient_as_bucket_view) {
         for (size_t i = 0; i < bucket.variables.size(); ++i) {
-            // Directly assign the bucket slice to grad instead of copying
-            // Same behavior as `CopyBucketToGrad(bucket.contents, bucket.variables[i]->grad(), bucket.offsets[i]);`
-            bucket.variables[i]->set_grad(bucket.bucket_views_in[i]);
+            // NOTE(zbl): For better performance, try `bucket.variables[i]->set_grad(bucket.bucket_views_in[i]);`
+            // to directly assign the bucket slice to grad instead of copying
+            CopyBucketToGrad(bucket.contents, bucket.variables[i]->grad(), bucket.offsets[i]);
+            // bucket.variables[i]->set_grad(bucket.bucket_views_in[i]);
         }
     }
-
-    ddp_pg->WaitAllReduceDone(bucket.allreduce_done, bucket.contents);
 }
 } // namespace infini_train::nn::parallel
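The FinalizeBucketDense change replaces the per-bucket CUDA events (bucket_ready / allreduce_done) with an AllReduceAsync call that returns a Work handle; where that handle is eventually synchronized is outside the hunks shown here. A self-contained analogue of the enqueue-then-wait pattern, using std::future in place of Work (the averaging lambda and world_size are illustrative stand-ins, not the ProcessGroup API):

// Analogy only: std::future stands in for Work, the lambda for the NCCL/MPI reduction.
#include <future>
#include <vector>

using Bucket = std::vector<float>;

// Rough analogue of AllReduceAsync: start the reduction, return a waitable handle.
std::future<void> AllReduceAsyncAnalogue(Bucket &contents, int world_size) {
    return std::async(std::launch::async, [&contents, world_size] {
        for (auto &v : contents) {
            v /= static_cast<float>(world_size); // real code averages across ranks
        }
    });
}

int main() {
    Bucket bucket(1024, 2.0f);
    auto work = AllReduceAsyncAnalogue(bucket, /*world_size=*/2); // enqueue without blocking
    // ... further buckets could be enqueued here while this one reduces ...
    work.wait(); // synchronize only when the averaged gradients are needed
    return bucket.front() == 1.0f ? 0 : 1;
}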