@@ -46,77 +46,167 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
4646
4747namespace caffe {
4848
49+ #define START_ITER 1
50+
51+
52+ #ifdef CAFFE_PER_LAYER_TIMINGS
53+ #define LAYER_TIMING_START () do { \
54+ timer.Start (); \
55+ }while (0 )
56+
57+ #define LAYER_TIMING_STOP (name, index ) do { \
58+ name##_time_per_layer[index] += timer.MicroSeconds (); \
59+ }while (0 )
60+ #else
61+ #define LAYER_TIMING_START ()
62+
63+ #define LAYER_TIMING_STOP (name,index )
64+ #endif
65+
66+ template <typename Dtype>
67+ inline bool MultiSolver<Dtype>::IsSkipWaitGradient(int layer_id) {
68+ Net<Dtype>& net = *root_solver_->net ();
69+ const std::vector<shared_ptr<Layer<Dtype>>>& layers{ net.layers () };
70+ const std::vector<bool >& layer_need_backward{ net.layer_need_backward () };
71+
72+ if (!layer_need_backward[layer_id] || ((layers[layer_id]->layerOp != nullptr )
73+ && !layers[layer_id]->layerOp ->HasParameterSets ())) {
74+ DLOG (INFO) << " ForwardBackwardImpl: no need for apply_updates for layer # "
75+ << layer_id << " , skip on_delwt_wait, apply_updates, on_wtinc_ready" ;
76+ return true ;
77+ }
78+ return false ;
79+ }
80+
81+ template <typename Dtype>
82+ inline void MultiSolver<Dtype>::WaitAndUpdateGradient(int layer_id) {
83+ LAYER_TIMING_START ();
84+ for (int j = 0 ; j < callbacks_.size (); ++j) {
85+ callbacks_[j]->on_delwt_wait (layer_id);
86+ }
87+ LAYER_TIMING_STOP (waitcomm, layer_id);
88+
89+ #ifdef FW_OVERLAP_OPT
90+ if (layer_finished_flags_[layer_id]) {
91+ #endif
92+ LAYER_TIMING_START ();
93+ for (int j = 0 ; j < callbacks_.size (); ++j) {
94+ callbacks_[j]->apply_updates (layer_id);
95+ }
96+ LAYER_TIMING_STOP (update, layer_id);
97+ #ifdef FW_OVERLAP_OPT
98+ }
99+ #endif
100+ }
101+
49102template <typename Dtype>
50103Dtype MultiSolver<Dtype>::ForwardBackwardImpl (bool first, bool last) {
51104
52105 Dtype loss = 0 ;
53106 Net<Dtype>& net = *root_solver_->net ();
54107 const std::vector<shared_ptr<Layer<Dtype>>>& layers{ net.layers () };
55108 const std::vector<bool >& layer_need_backward{ net.layer_need_backward () };
109+ #ifdef FW_OVERLAP_OPT
110+ int iter = root_solver_->iter ();
111+ #endif
56112
57113#ifdef CAFFE_PER_LAYER_TIMINGS
58114 Timer& timer = root_solver_->timer ;
59115 std::vector<double >& forward_time_per_layer = root_solver_->forward_time_per_layer ;
60116 std::vector<double >& backward_time_per_layer = root_solver_->backward_time_per_layer ;
61117 std::vector<double >& update_time_per_layer = root_solver_->update_time_per_layer ;
118+ std::vector<double >& startcomm_time_per_layer = root_solver_->startcomm_time_per_layer ;
119+ std::vector<double >& waitcomm_time_per_layer = root_solver_->waitcomm_time_per_layer ;
62120#endif /* CAFFE_PER_LAYER_TIMINGS */
63121
122+
64123 for (int i = 0 ; i < layers.size (); ++i) {
65- #ifdef CAFFE_PER_LAYER_TIMINGS
66- timer.Start ();
124+ #ifdef FW_OVERLAP_OPT
125+ if (first && iter >= START_ITER + 1 ) {
126+ while (layer_finished_flags_[i] == false ) {
127+ if (IsSkipWaitGradient (i)) {
128+ break ;
129+ }
130+
131+ WaitAndUpdateGradient (i);
132+ if (layer_finished_flags_[i]) {
133+ break ;
134+ }
135+
136+ for (int k=i+1 ; k<layers.size (); k++) {
137+ if (layer_finished_flags_[k] || IsSkipWaitGradient (k)) {
138+ layer_finished_flags_[k] = true ;
139+ continue ;
140+ }
141+
142+ WaitAndUpdateGradient (k);
143+ if (layer_finished_flags_[k])
144+ break ;
145+ }
146+ }
147+ layer_finished_flags_[i] = false ;
148+ }
67149#endif
68- loss += net.ForwardFromTo (i, i);
69150
70- # ifdef CAFFE_PER_LAYER_TIMINGS
71- forward_time_per_layer[i] += timer. MicroSeconds ( );
72- # endif
151+ LAYER_TIMING_START ();
152+ loss += net. ForwardFromTo (i, i );
153+ LAYER_TIMING_STOP (forward, i);
73154 }
74155
75156 for (int i = layers.size () - 1 ; i >= 0 ; --i) {
76- #ifdef CAFFE_PER_LAYER_TIMINGS
77- timer.Start ();
78- #endif
79-
80157 if (!layer_need_backward[i]) {
81158 continue ;
82159 }
160+
161+ LAYER_TIMING_START ();
83162
84163 net.BackwardFromTo (i, i);
85164
165+ LAYER_TIMING_STOP (backward, i);
166+
86167 if (last && (layers[i]->layerOp != nullptr ) && layers[i]->layerOp ->HasParameterSets ()) {
168+ LAYER_TIMING_START ();
87169 for (int j = 0 ; j < callbacks_.size (); ++j) {
88- callbacks_[j]->on_iter_finished (i);
170+ callbacks_[j]->on_iter_finished (i);
89171 }
172+ LAYER_TIMING_STOP (startcomm, i);
90173 }
91-
92- #ifdef CAFFE_PER_LAYER_TIMINGS
93- backward_time_per_layer[i] += timer.MicroSeconds ();
94- #endif
95174 }
96175
176+ #ifdef FW_OVERLAP_OPT
177+ int max_iter = root_solver_->param ().max_iter ();
178+ bool test = (root_solver_->param ().test_interval ()
179+ && ((iter + 1 ) % root_solver_->param ().test_interval () == 0 ));
180+ if (last && (test || (iter == max_iter - 1 ))) {
181+ int finished_count = 0 ;
182+ while (finished_count < layers.size ()) {
183+ #else
97184 if (last) {
185+ #endif
98186
99187 for (int i = 0 ; i < layers.size (); ++i) {
100- #ifdef CAFFE_PER_LAYER_TIMINGS
101- timer.Start ();
188+ #ifdef FW_OVERLAP_OPT
189+ if (layer_finished_flags_[i])
190+ continue ;
102191#endif
103- if (!layer_need_backward[i] || ((layers[i]->layerOp != nullptr ) && !layers[i]->layerOp ->HasParameterSets ())) {
104- DLOG (INFO) << " ForwardBackwardImpl: no need for apply_updates for layer # " << i
105- << " , skip on_delwt_wait, apply_updates, on_wtinc_ready" ;
106- continue ;
107- }
192+ if (IsSkipWaitGradient (i)) {
193+ #ifdef FW_OVERLAP_OPT
194+ finished_count++;
195+ layer_finished_flags_[i] = true ;
196+ #endif
197+ continue ;
198+ }
108199
109- for (int j = 0 ; j < callbacks_.size (); ++j) {
110- callbacks_[j]->on_delwt_wait (i);
111- }
200+ WaitAndUpdateGradient (i);
112201
113- for (int j = 0 ; j < callbacks_.size (); ++j) {
114- callbacks_[j]->apply_updates (i);
115- }
116- #ifdef CAFFE_PER_LAYER_TIMINGS
117- update_time_per_layer[i] += timer.MicroSeconds ();
202+ #ifdef FW_OVERLAP_OPT
203+ if (layer_finished_flags_[i])
204+ finished_count++;
118205#endif
206+ }
207+ #ifdef FW_OVERLAP_OPT
119208 }
209+ #endif
120210 }
121211
122212 DLOG (WARNING) << " iter " << root_solver_->iter () << " , loss " << loss;
0 commit comments