Skip to content

Commit e7d57db

Browse files
Chamberlain0w0kilinchange
authored andcommitted
fix: fix requested changes
1 parent 8b427d5 commit e7d57db

File tree

3 files changed

+9
-13
lines changed

3 files changed

+9
-13
lines changed

infini_train/include/tensor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
196196
std::shared_ptr<Tensor> RequiresGrad();
197197

198198
std::shared_ptr<Tensor> grad() const;
199-
void set_grad(std::shared_ptr<Tensor> &grad);
199+
void set_grad(const std::shared_ptr<Tensor> &grad);
200200

201201
bool requires_grad() const;
202202
void set_requires_grad(bool requires_grad);

infini_train/src/nn/parallel/reducer.cc

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -120,19 +120,15 @@ std::vector<std::vector<size_t>> ComputeBucketAssignmentBySize(const std::vector
120120
std::vector<std::vector<size_t>> buckets_all;
121121
buckets_all.reserve(tensors.size());
122122

123-
auto advance_limit = [&](State &s) {
124-
// Iterate along bucket_size_limits till the last one everytime a bucket is completed
125-
if (s.limit_idx + 1 < bucket_size_limits.size()) {
126-
++s.limit_idx;
127-
}
128-
};
129-
130-
auto flushCurrentBucket = [&](State &s) {
123+
auto FlushCurrentBucket = [&](State &s) {
131124
if (!s.current_tensors.empty()) {
132125
buckets_all.push_back(std::move(s.current_tensors));
133126
s.current_tensors.clear();
134127
s.current_bytes = 0;
135-
advance_limit(s);
128+
// Iterate along bucket_size_limits till the last one everytime a bucket is completed
129+
if (s.limit_idx + 1 < bucket_size_limits.size()) {
130+
++s.limit_idx;
131+
}
136132
}
137133
};
138134

@@ -159,12 +155,12 @@ std::vector<std::vector<size_t>> ComputeBucketAssignmentBySize(const std::vector
159155

160156
// If current bucket is out of capacity, then flush and move on to the next bucket
161157
if (state.current_bytes >= cap) {
162-
flushCurrentBucket(state);
158+
FlushCurrentBucket(state);
163159
}
164160
}
165161

166162
// Flush the last bucket of each group manually
167-
for (auto &key : key_order) { flushCurrentBucket(states[key]); }
163+
for (auto &key : key_order) { FlushCurrentBucket(states[key]); }
168164

169165
return buckets_all;
170166
}

infini_train/src/tensor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ std::shared_ptr<Tensor> Tensor::RequiresGrad() {
544544
}
545545

546546
std::shared_ptr<Tensor> Tensor::grad() const { return grad_; };
547-
void Tensor::set_grad(std::shared_ptr<Tensor> &grad) {
547+
void Tensor::set_grad(const std::shared_ptr<Tensor> &grad) {
548548
if (grad) {
549549
CHECK(grad->GetDevice() == GetDevice());
550550
CHECK(grad->Dtype() == Dtype());

0 commit comments

Comments
 (0)