@@ -46,16 +46,13 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
4646
4747namespace caffe {
4848
49- #define START_ITER 1
50-
51-
5249#ifdef CAFFE_PER_LAYER_TIMINGS
5350#define LAYER_TIMING_START () do { \
54- timer.Start (); \
51+ root_solver_-> timer .Start (); \
5552}while (0 )
5653
5754#define LAYER_TIMING_STOP (name, index ) do { \
58- name##_time_per_layer[index] += timer.MicroSeconds (); \
55+ root_solver_-> name ##_time_per_layer[index] += root_solver_-> timer .MicroSeconds (); \
5956}while (0 )
6057#else
6158#define LAYER_TIMING_START ()
@@ -101,50 +98,29 @@ inline void MultiSolver<Dtype>::WaitAndUpdateGradient(int layer_id) {
10198
10299template <typename Dtype>
103100Dtype MultiSolver<Dtype>::ForwardBackwardImpl (bool first, bool last) {
104-
105101 Dtype loss = 0 ;
106102 Net<Dtype>& net = *root_solver_->net ();
107103 const std::vector<shared_ptr<Layer<Dtype>>>& layers{ net.layers () };
108104 const std::vector<bool >& layer_need_backward{ net.layer_need_backward () };
109- #ifdef FW_OVERLAP_OPT
110- int iter = root_solver_->iter ();
111- #endif
112-
113- #ifdef CAFFE_PER_LAYER_TIMINGS
114- Timer& timer = root_solver_->timer ;
115- std::vector<double >& forward_time_per_layer = root_solver_->forward_time_per_layer ;
116- std::vector<double >& backward_time_per_layer = root_solver_->backward_time_per_layer ;
117- std::vector<double >& update_time_per_layer = root_solver_->update_time_per_layer ;
118- std::vector<double >& startcomm_time_per_layer = root_solver_->startcomm_time_per_layer ;
119- std::vector<double >& waitcomm_time_per_layer = root_solver_->waitcomm_time_per_layer ;
120- #endif /* CAFFE_PER_LAYER_TIMINGS */
121-
122105
123106 for (int i = 0 ; i < layers.size (); ++i) {
124107#ifdef FW_OVERLAP_OPT
125- if (first && iter >= START_ITER + 1 ) {
108+ if (first && IsSkipWaitGradient (i) == false ) {
126109 while (layer_finished_flags_[i] == false ) {
127- if (IsSkipWaitGradient (i)) {
128- break ;
129- }
130-
131110 WaitAndUpdateGradient (i);
132- if (layer_finished_flags_[i]) {
111+ if (layer_finished_flags_[i])
133112 break ;
134- }
135113
136114 for (int k=i+1 ; k<layers.size (); k++) {
137115 if (layer_finished_flags_[k] || IsSkipWaitGradient (k)) {
138116 layer_finished_flags_[k] = true ;
139117 continue ;
140118 }
141-
142119 WaitAndUpdateGradient (k);
143120 if (layer_finished_flags_[k])
144121 break ;
145122 }
146123 }
147- layer_finished_flags_[i] = false ;
148124 }
149125#endif
150126
@@ -159,12 +135,11 @@ Dtype MultiSolver<Dtype>::ForwardBackwardImpl(bool first, bool last) {
159135 }
160136
161137 LAYER_TIMING_START ();
162-
163138 net.BackwardFromTo (i, i);
164-
165139 LAYER_TIMING_STOP (backward, i);
166140
167- if (last && (layers[i]->layerOp != nullptr ) && layers[i]->layerOp ->HasParameterSets ()) {
141+ if (last && (layers[i]->layerOp != nullptr )
142+ && layers[i]->layerOp ->HasParameterSets ()) {
168143 LAYER_TIMING_START ();
169144 for (int j = 0 ; j < callbacks_.size (); ++j) {
170145 callbacks_[j]->on_iter_finished (i);
@@ -174,6 +149,7 @@ Dtype MultiSolver<Dtype>::ForwardBackwardImpl(bool first, bool last) {
174149 }
175150
176151#ifdef FW_OVERLAP_OPT
152+ int iter = root_solver_->iter ();
177153 int max_iter = root_solver_->param ().max_iter ();
178154 bool test = (root_solver_->param ().test_interval ()
179155 && ((iter + 1 ) % root_solver_->param ().test_interval () == 0 ));
@@ -183,22 +159,20 @@ Dtype MultiSolver<Dtype>::ForwardBackwardImpl(bool first, bool last) {
183159#else
184160 if (last) {
185161#endif
186-
187- for (int i = 0 ; i < layers.size (); ++i) {
188- #ifdef FW_OVERLAP_OPT
189- if (layer_finished_flags_[i])
190- continue ;
191- #endif
162+ for (int i = 0 ; i < layers.size (); ++i) {
192163 if (IsSkipWaitGradient (i)) {
193164#ifdef FW_OVERLAP_OPT
194165 finished_count++;
195166 layer_finished_flags_[i] = true ;
196167#endif
197168 continue ;
198169 }
170+ #ifdef FW_OVERLAP_OPT
171+ if (layer_finished_flags_[i])
172+ continue ;
173+ #endif
199174
200175 WaitAndUpdateGradient (i);
201-
202176#ifdef FW_OVERLAP_OPT
203177 if (layer_finished_flags_[i])
204178 finished_count++;
0 commit comments