1414
1515#include " infini_train/include/autograd/function_hook.h"
1616#include " infini_train/include/common/cuda/common_cuda.h"
17- #include " infini_train/include/device.h"
1817#include " infini_train/include/nn/parallel/utils.h"
1918#include " infini_train/include/nn/parallel/work.h"
19+ #include " infini_train/include/tensor.h"
2020
2121namespace infini_train ::nn::parallel {
2222namespace {
@@ -106,7 +106,6 @@ std::vector<std::vector<size_t>> ComputeBucketAssignmentBySize(const std::vector
106106 return (std::hash<int >()(k.dev ) << 1 ) ^ std::hash<int >()(static_cast <int >(k.dtype ));
107107 }
108108 };
109- auto key_of = [&](size_t i) -> Key { return Key{tensors[i]->GetDevice ()->Index (), tensors[i]->Dtype ()}; };
110109
111110 // Track the in-progress bucket state for each (device, dtype) group
112111 struct State {
@@ -117,8 +116,6 @@ std::vector<std::vector<size_t>> ComputeBucketAssignmentBySize(const std::vector
117116
118117 std::unordered_map<Key, State, KeyHash> states;
119118 std::vector<Key> key_order;
120- // NOTE(zbl): Assume combinations of (device, dtype) <= 8
121- states.reserve (8 );
122119
123120 std::vector<std::vector<size_t >> buckets_all;
124121 buckets_all.reserve (tensors.size ());
@@ -130,9 +127,7 @@ std::vector<std::vector<size_t>> ComputeBucketAssignmentBySize(const std::vector
130127 }
131128 };
132129
133- auto current_cap = [&](const State &s) -> size_t { return bucket_size_limits[s.limit_idx ]; };
134-
135- auto flush_current_bucket = [&](State &s) {
130+ auto flushCurrentBucket = [&](State &s) {
136131 if (!s.current_tensors .empty ()) {
137132 buckets_all.push_back (std::move (s.current_tensors ));
138133 s.current_tensors .clear ();
@@ -146,7 +141,7 @@ std::vector<std::vector<size_t>> ComputeBucketAssignmentBySize(const std::vector
146141 const auto &tensor = tensors[idx_in_order];
147142 CHECK (tensor);
148143
149- const Key k = key_of ( idx_in_order) ;
144+ const Key k = Key{tensors[idx_in_order]-> GetDevice ()-> Index (), tensors[ idx_in_order]-> Dtype ()} ;
150145 auto it = states.find (k);
151146 if (it == states.end ()) {
152147 it = states.emplace (k, State{}).first ;
@@ -156,20 +151,20 @@ std::vector<std::vector<size_t>> ComputeBucketAssignmentBySize(const std::vector
156151
157152 const size_t element_size_in_bytes = kDataTypeToSize .at (tensor->Dtype ());
158153 const size_t bytes = tensor->NumElements () * element_size_in_bytes;
159- const size_t cap = current_cap ( state) ;
154+ const size_t cap = bucket_size_limits[ state. limit_idx ] ;
160155
161156 // Assign current tensor to current bucket first
162157 state.current_tensors .push_back (idx_in_order);
163158 state.current_bytes += bytes;
164159
165160 // Once the current bucket reaches or exceeds its capacity, flush it and move on to the next bucket
166161 if (state.current_bytes >= cap) {
167- flush_current_bucket (state);
162+ flushCurrentBucket (state);
168163 }
169164 }
170165
171166 // Flush the last bucket of each group manually
172- for (auto &key : key_order) { flush_current_bucket (states[key]); }
167+ for (auto &key : key_order) { flushCurrentBucket (states[key]); }
173168
174169 return buckets_all;
175170}
@@ -215,6 +210,7 @@ void Reducer::BuildBuckets(const std::vector<std::vector<size_t>> &bucket_indice
215210 CHECK (!bucket_indices[bucket_idx].empty ());
216211 const auto &first_param = params_[bucket_indices[bucket_idx][0 ]];
217212 bucket.dtype = first_param->Dtype ();
213+ // FIXME(zbl): use global_rank() in multi-node settings
218214 bucket.device_rank = first_param->GetDevice ()->rank ().thread_rank ();
219215
220216 size_t total_elems = 0 ;
@@ -274,8 +270,8 @@ void Reducer::RebuildBuckets() {
274270 tensors_in_order.push_back (params_[global_idx]);
275271 }
276272
277- const size_t first_cap_bytes = opts_.first_bucket_cap_mb * 1024ULL * 1024ULL ;
278- const size_t normal_cap_bytes = opts_.normal_bucket_cap_mb * 1024ULL * 1024ULL ;
273+ const size_t first_cap_bytes = opts_.first_bucket_cap_mb * kBytesPerMB ;
274+ const size_t normal_cap_bytes = opts_.normal_bucket_cap_mb * kBytesPerMB ;
279275 std::vector<size_t > bucket_size_limits = {first_cap_bytes, normal_cap_bytes};
280276 auto new_bucket_indices = ComputeBucketAssignmentBySize (tensors_in_order, bucket_size_limits, full_order);
281277
@@ -364,8 +360,7 @@ void Reducer::MarkVariableReadyDense(size_t variable_index) {
364360 auto &bucket = buckets_.at (loc.bucket_index );
365361
366362 // Record real order of bucket being ready
367- if (!has_rebuilt_bucket_ && variable_index < ready_seen_this_iter_.size ()
368- && !ready_seen_this_iter_[variable_index]) {
363+ if (!has_rebuilt_bucket_ && !ready_seen_this_iter_[variable_index]) {
369364 grad_ready_order_indices_.push_back (variable_index);
370365 ready_seen_this_iter_[variable_index] = 1 ;
371366 }
0 commit comments