File tree Expand file tree Collapse file tree 3 files changed +9
-13
lines changed
Expand file tree Collapse file tree 3 files changed +9
-13
lines changed Original file line number Diff line number Diff line change @@ -196,7 +196,7 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
196196 std::shared_ptr<Tensor> RequiresGrad ();
197197
198198 std::shared_ptr<Tensor> grad () const ;
199- void set_grad (std::shared_ptr<Tensor> &grad);
199+ void set_grad (const std::shared_ptr<Tensor> &grad);
200200
201201 bool requires_grad () const ;
202202 void set_requires_grad (bool requires_grad);
Original file line number Diff line number Diff line change @@ -120,19 +120,15 @@ std::vector<std::vector<size_t>> ComputeBucketAssignmentBySize(const std::vector
120120 std::vector<std::vector<size_t >> buckets_all;
121121 buckets_all.reserve (tensors.size ());
122122
123- auto advance_limit = [&](State &s) {
124- // Iterate along bucket_size_limits till the last one everytime a bucket is completed
125- if (s.limit_idx + 1 < bucket_size_limits.size ()) {
126- ++s.limit_idx ;
127- }
128- };
129-
130- auto flushCurrentBucket = [&](State &s) {
123+ auto FlushCurrentBucket = [&](State &s) {
131124 if (!s.current_tensors .empty ()) {
132125 buckets_all.push_back (std::move (s.current_tensors ));
133126 s.current_tensors .clear ();
134127 s.current_bytes = 0 ;
135- advance_limit (s);
128+ // Iterate along bucket_size_limits till the last one everytime a bucket is completed
129+ if (s.limit_idx + 1 < bucket_size_limits.size ()) {
130+ ++s.limit_idx ;
131+ }
136132 }
137133 };
138134
@@ -159,12 +155,12 @@ std::vector<std::vector<size_t>> ComputeBucketAssignmentBySize(const std::vector
159155
160156 // If current bucket is out of capacity, then flush and move on to the next bucket
161157 if (state.current_bytes >= cap) {
162- flushCurrentBucket (state);
158+ FlushCurrentBucket (state);
163159 }
164160 }
165161
166162 // Flush the last bucket of each group manually
167- for (auto &key : key_order) { flushCurrentBucket (states[key]); }
163+ for (auto &key : key_order) { FlushCurrentBucket (states[key]); }
168164
169165 return buckets_all;
170166}
Original file line number Diff line number Diff line change @@ -544,7 +544,7 @@ std::shared_ptr<Tensor> Tensor::RequiresGrad() {
544544}
545545
546546std::shared_ptr<Tensor> Tensor::grad () const { return grad_; };
547- void Tensor::set_grad (std::shared_ptr<Tensor> &grad) {
547+ void Tensor::set_grad (const std::shared_ptr<Tensor> &grad) {
548548 if (grad) {
549549 CHECK (grad->GetDevice () == GetDevice ());
550550 CHECK (grad->Dtype () == Dtype ());
You can’t perform that action at this time.
0 commit comments