@@ -50,194 +50,37 @@ namespace caffe {
5050template <typename Dtype>
5151void MKLDNNConcatLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
5252 const vector<Blob<Dtype>*>& top) {
53- VLOG (1 ) << " MKLDNNConcatLayer<Dtype>::LayerSetUp: " << this ->layer_param_ .name ();
54-
55- const ConcatParameter& concat_param = this ->layer_param_ .concat_param ();
56- CHECK (!(concat_param.has_axis () && concat_param.has_concat_dim ()))
57- << " Either axis or concat_dim should be specified; not both." ;
53+ // VLOG(1) << "MKLDNNConcatLayer<Dtype>::LayerSetUp: " << this->layer_param_.name();
5854
5955 int dim_src = bottom[0 ]->shape ().size ();
6056 // int dim_dst = dim_src;
6157
6258 num_concats_ = bottom.size ();
63-
64- const int num_axes = bottom[0 ]->num_axes ();
65- if (concat_param.has_concat_dim ()) {
66- concat_dimension = static_cast <int >(concat_param.concat_dim ());
67- // Don't allow negative indexing for concat_dim, a uint32 -- almost certainly unintended.
68- CHECK_GE (concat_dimension, 0 ) << " casting concat_dim from uint32 to int32 "
69- << " produced negative result; concat_dim must satisfy "
70- << " 0 <= concat_dimension < " << kMaxBlobAxes ;
71- CHECK_LT (concat_dimension, num_axes) << " concat_dimension out of range." ;
72- } else {
73- concat_dimension = bottom[0 ]->CanonicalAxisIndex (concat_param.axis ());
74- }
59+ channels_ = 0 ;
7560
7661 for (auto i = 1 ; i < num_concats_; ++i) {
77- if (concat_dimension == 0 )
78- {
79- CHECK_EQ (bottom[0 ]->channels (), bottom[i]->channels ());
80- CHECK_EQ (bottom[0 ]->height (), bottom[i]->height ());
81- CHECK_EQ (bottom[0 ]->width (), bottom[i]->width ());
82- break ;
83- }
84- else if (concat_dimension == 1 )
85- {
86- CHECK_EQ (bottom[0 ]->num (), bottom[i]->num ());
87- CHECK_EQ (bottom[0 ]->height (), bottom[i]->height ());
88- CHECK_EQ (bottom[0 ]->width (), bottom[i]->width ());
89- break ;
90- }
91- else if (concat_dimension == 2 )
92- {
93- CHECK_EQ (bottom[0 ]->num (), bottom[i]->num ());
94- CHECK_EQ (bottom[0 ]->channels (), bottom[i]->channels ());
95- CHECK_EQ (bottom[0 ]->width (), bottom[i]->width ());
96- break ;
97- }
98- else if (concat_dimension == 3 )
99- {
100- CHECK_EQ (bottom[0 ]->num (), bottom[i]->num ());
101- CHECK_EQ (bottom[0 ]->channels (), bottom[i]->channels ());
102- CHECK_EQ (bottom[0 ]->height (), bottom[i]->height ());
103- break ;
104- }
62+ CHECK_EQ (bottom[0 ]->num (), bottom[i]->num ());
63+ CHECK_EQ (bottom[0 ]->height (), bottom[i]->height ());
64+ CHECK_EQ (bottom[0 ]->width (), bottom[i]->width ());
10565 }
10666
107- split_dims.reserve (num_concats_);
108- if (concat_dimension == 0 )
109- {
110- num_ = 0 ;
111- channels_ = bottom[0 ]->channels ();
112- height_ = bottom[0 ]->height ();
113- width_ = bottom[0 ]->width ();
114- for (auto i = 0 ; i < num_concats_; ++i) {
115- CHECK_EQ (dim_src, bottom[i]->shape ().size ());
116- split_dims[i] = bottom[i]->num ();
117- num_ += split_dims[i];
118- }
119- }
120- else if (concat_dimension == 1 )
121- {
122- num_ = bottom[0 ]->num ();
123- channels_ = 0 ;
124- height_ = bottom[0 ]->height ();
125- width_ = bottom[0 ]->width ();
126- for (auto i = 0 ; i < num_concats_; ++i) {
127- CHECK_EQ (dim_src, bottom[i]->shape ().size ());
128- split_dims[i] = bottom[i]->channels ();
129- channels_ += split_dims[i];
130- }
131- }
132- else if (concat_dimension == 2 )
133- {
134- num_ = bottom[0 ]->num ();
135- channels_ = bottom[0 ]->channels ();
136- height_ = 0 ;
137- width_ = bottom[0 ]->width ();
138- for (auto i = 0 ; i < num_concats_; ++i) {
139- CHECK_EQ (dim_src, bottom[i]->shape ().size ());
140- split_dims[i] = bottom[i]->height ();
141- height_ += split_dims[i];
142- }
143- }
144- else if (concat_dimension == 3 )
145- {
146- num_ = bottom[0 ]->num ();
147- channels_ = bottom[0 ]->channels ();
148- height_ = bottom[0 ]->height ();
149- width_ = 0 ;
150- for (auto i = 0 ; i < num_concats_; ++i) {
151- CHECK_EQ (dim_src, bottom[i]->shape ().size ());
152- split_dims[i] = bottom[i]->width ();
153- width_ += split_dims[i];
154- }
67+ split_channels.reserve (num_concats_);
68+ for (auto i = 0 ; i < num_concats_; ++i) {
69+ CHECK_EQ (dim_src, bottom[i]->shape ().size ());
70+
71+ split_channels[i] = bottom[i]->channels ();
72+ channels_ += split_channels[i];
15573 }
15674}
15775
15876template <typename Dtype>
15977void MKLDNNConcatLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
16078 const vector<Blob<Dtype>*>& top) {
161- VLOG (1 ) << " MKLDNNConcatLayer<Dtype>::Reshape: " << this ->layer_param_ .name ();
79+ // VLOG(1) << "MKLDNNConcatLayer<Dtype>::Reshape: " << this->layer_param_.name();
16280
163- if (concat_dimension == 0 )
164- {
165- // Need to re-calculate the shape duo to the change of batch size
166- num_ = 0 ;
167- channels_ = bottom[0 ]->channels ();
168- height_ = bottom[0 ]->height ();
169- width_ = bottom[0 ]->width ();
170- // Also need to reshape the concat dim, in case the concat dim is just be reshaped by batch size
171- for (auto i = 0 ; i < num_concats_; ++i) {
172- split_dims[i] = bottom[i]->num ();
173- num_ += split_dims[i];
174- }
175-
176- if (this ->channels_ == bottom[0 ]->channels () &&
177- this ->height_ == bottom[0 ]->height () &&
178- this ->width_ == bottom[0 ]->width ()) {
179- reshape = false ;
180- } else {
181- reshape = true ;
182- }
183- }
184- else if (concat_dimension == 1 )
185- {
186- num_ = bottom[0 ]->num ();
187- channels_ = 0 ;
188- height_ = bottom[0 ]->height ();
189- width_ = bottom[0 ]->width ();
190- for (auto i = 0 ; i < num_concats_; ++i) {
191- split_dims[i] = bottom[i]->channels ();
192- channels_ += split_dims[i];
193- }
194-
195- if (this ->num_ == bottom[0 ]->num () &&
196- this ->height_ == bottom[0 ]->height () &&
197- this ->width_ == bottom[0 ]->width ()) {
198- reshape = false ;
199- } else {
200- reshape = true ;
201- }
202- }
203- else if (concat_dimension == 2 )
204- {
205- num_ = bottom[0 ]->num ();
206- channels_ = bottom[0 ]->channels ();
207- height_ = 0 ;
208- width_ = bottom[0 ]->width ();
209- for (auto i = 0 ; i < num_concats_; ++i) {
210- split_dims[i] = bottom[i]->height ();
211- height_ += split_dims[i];
212- }
213-
214- if (this ->num_ == bottom[0 ]->num () &&
215- this ->channels_ == bottom[0 ]->channels () &&
216- this ->width_ == bottom[0 ]->width ()) {
217- reshape = false ;
218- } else {
219- reshape = true ;
220- }
221- }
222- else if (concat_dimension == 3 )
223- {
224- num_ = bottom[0 ]->num ();
225- channels_ = bottom[0 ]->channels ();
226- height_ = bottom[0 ]->height ();
227- width_ = 0 ;
228- for (auto i = 0 ; i < num_concats_; ++i) {
229- split_dims[i] = bottom[i]->width ();
230- width_ += split_dims[i];
231- }
232-
233- if (this ->num_ == bottom[0 ]->num () &&
234- this ->channels_ == bottom[0 ]->channels () &&
235- this ->height_ == bottom[0 ]->height ()) {
236- reshape = false ;
237- } else {
238- reshape = true ;
239- }
240- }
81+ num_ = bottom[0 ]->num ();
82+ height_ = bottom[0 ]->height ();
83+ width_ = bottom[0 ]->width ();
24184
24285 top[0 ]->Reshape (num_, channels_, height_, width_);
24386}
@@ -283,25 +126,7 @@ void MKLDNNConcatLayer<Dtype>::InitConcatFwd(const vector<Blob<Dtype>*>& bottom,
283126 std::vector<primitive::at> srcs;
284127 for (auto i = 0 ; i < num_concats_; i++) {
285128 fwd_bottom_data.push_back (boost::shared_ptr<MKLDNNData<Dtype> >());
286-
287- memory::dims input_tz = {0 , 0 , 0 , 0 };
288- if (concat_dimension == 0 )
289- {
290- input_tz = {split_dims[i], channels_, height_, width_};
291- }
292- else if (concat_dimension == 1 )
293- {
294- input_tz = {num_, split_dims[i], height_, width_};
295- }
296- else if (concat_dimension == 2 )
297- {
298- input_tz = {num_, channels_, split_dims[i], width_};
299- }
300- else if (concat_dimension == 3 )
301- {
302- input_tz = {num_, channels_, height_, split_dims[i]};
303- }
304-
129+ memory::dims input_tz = {num_, split_channels[i], height_, width_};
305130 memory::format src_mfmt = mfmt_nchw;
306131 shared_ptr<memory::primitive_desc> prv_src_mpd;
307132 shared_ptr<memory::primitive_desc> usr_src_mpd (
@@ -329,6 +154,8 @@ void MKLDNNConcatLayer<Dtype>::InitConcatFwd(const vector<Blob<Dtype>*>& bottom,
329154 shared_ptr<memory::primitive_desc> usr_dst_mpd (new memory::primitive_desc (
330155 {output_tz, data_type, mfmt_nchw}, cpu_engine));
331156
157+ // FIXME: concat dimension
158+ concat_dimension = 1 ;
332159 concatFwd_pd.reset (new concat::primitive_desc (concat_dimension, srcs_mpd));
333160
334161 shared_ptr<memory::primitive_desc> prv_dst_mpd (new memory::primitive_desc (
@@ -364,6 +191,9 @@ void MKLDNNConcatLayer<Dtype>::InitConcatBwd(const vector<Blob<Dtype>*>& top,
364191 memory::dims input_tz = {num_, channels_, height_, width_};
365192 memory::dims offsets = {0 , 0 , 0 , 0 };
366193
194+ // FIXME: concat dimension
195+ concat_dimension = 1 ;
196+
367197 shared_ptr<memory::primitive_desc> prv_diff_dst_mpd;
368198 shared_ptr<memory::primitive_desc> usr_diff_dst_mpd (
369199 new memory::primitive_desc ({input_tz, data_type, mfmt_nchw},
@@ -388,25 +218,7 @@ void MKLDNNConcatLayer<Dtype>::InitConcatBwd(const vector<Blob<Dtype>*>& top,
388218 for (auto i = 0 ; i < num_concats_; i++) {
389219 bwd_bottom_diff.push_back (boost::shared_ptr<MKLDNNDiff<Dtype> >());
390220 reorders.push_back (MKLDNNPrimitive<Dtype>());
391-
392- memory::dims dims = {0 , 0 , 0 , 0 };
393- if (concat_dimension == 0 )
394- {
395- dims = {split_dims[i], channels_, height_, width_};
396- }
397- else if (concat_dimension == 1 )
398- {
399- dims = {num_, split_dims[i], height_, width_};
400- }
401- else if (concat_dimension == 2 )
402- {
403- dims = {num_, channels_, split_dims[i], width_};
404- }
405- else if (concat_dimension == 3 )
406- {
407- dims = {num_, channels_, height_, split_dims[i]};
408- }
409-
221+ memory::dims dims = {num_, split_channels[i], height_, width_};
410222 shared_ptr<memory::primitive_desc> usr_diff_src_mpd (
411223 new memory::primitive_desc ({dims, data_type, mfmt_nchw},
412224 cpu_engine));
@@ -447,7 +259,7 @@ void MKLDNNConcatLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
447259 LOG (INFO) << " MKLDNNConcatLayer<Dtype>::Forward_cpu: " << this ->layer_param_ .name ();
448260#endif
449261
450- if (( NULL == concatFwd_pd) || ( true == reshape) )
262+ if (NULL == concatFwd_pd)
451263 InitConcatFwd (bottom, top);
452264 for (auto i = 0 ; i < num_concats_; i++) {
453265 // making reorders if needed.
@@ -472,7 +284,7 @@ void MKLDNNConcatLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top
472284 LOG (INFO) << " MKLDNNConcatLayer<Dtype>::Backward_cpu: " << this ->layer_param_ .name ();
473285#endif
474286
475- if (( reorders.size () == 0 ) || ( true == reshape) )
287+ if (reorders.size () == 0 )
476288 InitConcatBwd (top, propagate_down, bottom);
477289 bwd_top_diff->sync_before_read ();
478290 for (auto i = 0 ; i < num_concats_; ++i) {
0 commit comments