@@ -50,37 +50,194 @@ namespace caffe {
 template <typename Dtype>
 void MKLDNNConcatLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
-  // VLOG(1) << "MKLDNNConcatLayer<Dtype>::LayerSetUp: " << this->layer_param_.name();
+  VLOG(1) << "MKLDNNConcatLayer<Dtype>::LayerSetUp: " << this->layer_param_.name();
+
+  const ConcatParameter& concat_param = this->layer_param_.concat_param();
+  CHECK(!(concat_param.has_axis() && concat_param.has_concat_dim()))
+      << "Either axis or concat_dim should be specified; not both.";
 
   int dim_src = bottom[0]->shape().size();
   // int dim_dst = dim_src;
 
   num_concats_ = bottom.size();
-  channels_ = 0;
 
-  for (auto i = 1; i < num_concats_; ++i) {
-    CHECK_EQ(bottom[0]->num(), bottom[i]->num());
-    CHECK_EQ(bottom[0]->height(), bottom[i]->height());
-    CHECK_EQ(bottom[0]->width(), bottom[i]->width());
+  const int num_axes = bottom[0]->num_axes();
+  if (concat_param.has_concat_dim()) {
+    concat_dimension = static_cast<int>(concat_param.concat_dim());
+    // Don't allow negative indexing for concat_dim, a uint32 -- almost certainly unintended.
+    CHECK_GE(concat_dimension, 0) << "casting concat_dim from uint32 to int32 "
+        << "produced negative result; concat_dim must satisfy "
+        << "0 <= concat_dimension < " << kMaxBlobAxes;
+    CHECK_LT(concat_dimension, num_axes) << "concat_dimension out of range.";
+  } else {
+    concat_dimension = bottom[0]->CanonicalAxisIndex(concat_param.axis());
   }
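+  // concat_dimension is now a canonical, non-negative blob axis:
+  // 0 = num, 1 = channels, 2 = height, 3 = width for the 4-D NCHW blobs handled here.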
 
-  split_channels.reserve(num_concats_);
-  for (auto i = 0; i < num_concats_; ++i) {
-    CHECK_EQ(dim_src, bottom[i]->shape().size());
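+  // Check that bottom blobs agree on every axis except the concat axis.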
+  for (auto i = 1; i < num_concats_; ++i) {
+    if (concat_dimension == 0)
+    {
+      CHECK_EQ(bottom[0]->channels(), bottom[i]->channels());
+      CHECK_EQ(bottom[0]->height(), bottom[i]->height());
+      CHECK_EQ(bottom[0]->width(), bottom[i]->width());
+      break;
+    }
+    else if (concat_dimension == 1)
+    {
+      CHECK_EQ(bottom[0]->num(), bottom[i]->num());
+      CHECK_EQ(bottom[0]->height(), bottom[i]->height());
+      CHECK_EQ(bottom[0]->width(), bottom[i]->width());
+      break;
+    }
+    else if (concat_dimension == 2)
+    {
+      CHECK_EQ(bottom[0]->num(), bottom[i]->num());
+      CHECK_EQ(bottom[0]->channels(), bottom[i]->channels());
+      CHECK_EQ(bottom[0]->width(), bottom[i]->width());
+      break;
+    }
+    else if (concat_dimension == 3)
+    {
+      CHECK_EQ(bottom[0]->num(), bottom[i]->num());
+      CHECK_EQ(bottom[0]->channels(), bottom[i]->channels());
+      CHECK_EQ(bottom[0]->height(), bottom[i]->height());
+      break;
+    }
+  }
 
-    split_channels[i] = bottom[i]->channels();
-    channels_ += split_channels[i];
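+  // split_dims[i] records bottom[i]'s extent along the concat axis; the
+  // corresponding output dimension is the sum of these extents.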
+  split_dims.reserve(num_concats_);
+  if (concat_dimension == 0)
+  {
+    num_ = 0;
+    channels_ = bottom[0]->channels();
+    height_ = bottom[0]->height();
+    width_ = bottom[0]->width();
+    for (auto i = 0; i < num_concats_; ++i) {
+      CHECK_EQ(dim_src, bottom[i]->shape().size());
+      split_dims[i] = bottom[i]->num();
+      num_ += split_dims[i];
+    }
+  }
+  else if (concat_dimension == 1)
+  {
+    num_ = bottom[0]->num();
+    channels_ = 0;
+    height_ = bottom[0]->height();
+    width_ = bottom[0]->width();
+    for (auto i = 0; i < num_concats_; ++i) {
+      CHECK_EQ(dim_src, bottom[i]->shape().size());
+      split_dims[i] = bottom[i]->channels();
+      channels_ += split_dims[i];
+    }
+  }
+  else if (concat_dimension == 2)
+  {
+    num_ = bottom[0]->num();
+    channels_ = bottom[0]->channels();
+    height_ = 0;
+    width_ = bottom[0]->width();
+    for (auto i = 0; i < num_concats_; ++i) {
+      CHECK_EQ(dim_src, bottom[i]->shape().size());
+      split_dims[i] = bottom[i]->height();
+      height_ += split_dims[i];
+    }
+  }
+  else if (concat_dimension == 3)
+  {
+    num_ = bottom[0]->num();
+    channels_ = bottom[0]->channels();
+    height_ = bottom[0]->height();
+    width_ = 0;
+    for (auto i = 0; i < num_concats_; ++i) {
+      CHECK_EQ(dim_src, bottom[i]->shape().size());
+      split_dims[i] = bottom[i]->width();
+      width_ += split_dims[i];
+    }
+  }
 }
 
 template <typename Dtype>
 void MKLDNNConcatLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
-  // VLOG(1) << "MKLDNNConcatLayer<Dtype>::Reshape: " << this->layer_param_.name();
+  VLOG(1) << "MKLDNNConcatLayer<Dtype>::Reshape: " << this->layer_param_.name();
 
-  num_ = bottom[0]->num();
-  height_ = bottom[0]->height();
-  width_ = bottom[0]->width();
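+  // Recompute the output shape from the current bottom shapes; the reshape flag
+  // tells Forward/Backward to rebuild the MKL-DNN primitives when the layout changed.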
+  if (concat_dimension == 0)
+  {
+    // Need to re-calculate the shape due to the change of batch size
+    num_ = 0;
+    channels_ = bottom[0]->channels();
+    height_ = bottom[0]->height();
+    width_ = bottom[0]->width();
+    // Also re-compute the concat dim, in case it has just been changed along with the batch size
+    for (auto i = 0; i < num_concats_; ++i) {
+      split_dims[i] = bottom[i]->num();
+      num_ += split_dims[i];
+    }
+
+    if (this->channels_ == bottom[0]->channels() &&
+        this->height_ == bottom[0]->height() &&
+        this->width_ == bottom[0]->width()) {
+      reshape = false;
+    } else {
+      reshape = true;
+    }
+  }
+  else if (concat_dimension == 1)
+  {
+    num_ = bottom[0]->num();
+    channels_ = 0;
+    height_ = bottom[0]->height();
+    width_ = bottom[0]->width();
+    for (auto i = 0; i < num_concats_; ++i) {
+      split_dims[i] = bottom[i]->channels();
+      channels_ += split_dims[i];
+    }
+
+    if (this->num_ == bottom[0]->num() &&
+        this->height_ == bottom[0]->height() &&
+        this->width_ == bottom[0]->width()) {
+      reshape = false;
+    } else {
+      reshape = true;
+    }
+  }
+  else if (concat_dimension == 2)
+  {
+    num_ = bottom[0]->num();
+    channels_ = bottom[0]->channels();
+    height_ = 0;
+    width_ = bottom[0]->width();
+    for (auto i = 0; i < num_concats_; ++i) {
+      split_dims[i] = bottom[i]->height();
+      height_ += split_dims[i];
+    }
+
+    if (this->num_ == bottom[0]->num() &&
+        this->channels_ == bottom[0]->channels() &&
+        this->width_ == bottom[0]->width()) {
+      reshape = false;
+    } else {
+      reshape = true;
+    }
+  }
+  else if (concat_dimension == 3)
+  {
+    num_ = bottom[0]->num();
+    channels_ = bottom[0]->channels();
+    height_ = bottom[0]->height();
+    width_ = 0;
+    for (auto i = 0; i < num_concats_; ++i) {
+      split_dims[i] = bottom[i]->width();
+      width_ += split_dims[i];
+    }
+
+    if (this->num_ == bottom[0]->num() &&
+        this->channels_ == bottom[0]->channels() &&
+        this->height_ == bottom[0]->height()) {
+      reshape = false;
+    } else {
+      reshape = true;
+    }
+  }
 
   top[0]->Reshape(num_, channels_, height_, width_);
 }
@@ -126,7 +283,25 @@ void MKLDNNConcatLayer<Dtype>::InitConcatFwd(const vector<Blob<Dtype>*>& bottom,
   std::vector<primitive::at> srcs;
   for (auto i = 0; i < num_concats_; i++) {
     fwd_bottom_data.push_back(boost::shared_ptr<MKLDNNData<Dtype> >());
-    memory::dims input_tz = {num_, split_channels[i], height_, width_};
+
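+    // Describe bottom[i]'s memory with split_dims[i] along the concat axis and
+    // the shared sizes on the other axes.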
+    memory::dims input_tz = {0, 0, 0, 0};
+    if (concat_dimension == 0)
+    {
+      input_tz = {split_dims[i], channels_, height_, width_};
+    }
+    else if (concat_dimension == 1)
+    {
+      input_tz = {num_, split_dims[i], height_, width_};
+    }
+    else if (concat_dimension == 2)
+    {
+      input_tz = {num_, channels_, split_dims[i], width_};
+    }
+    else if (concat_dimension == 3)
+    {
+      input_tz = {num_, channels_, height_, split_dims[i]};
+    }
+
     memory::format src_mfmt = mfmt_nchw;
     shared_ptr<memory::primitive_desc> prv_src_mpd;
     shared_ptr<memory::primitive_desc> usr_src_mpd(
@@ -154,8 +329,6 @@ void MKLDNNConcatLayer<Dtype>::InitConcatFwd(const vector<Blob<Dtype>*>& bottom,
   shared_ptr<memory::primitive_desc> usr_dst_mpd(new memory::primitive_desc(
       {output_tz, data_type, mfmt_nchw}, cpu_engine));
 
-  // FIXME: concat dimension
-  concat_dimension = 1;
   concatFwd_pd.reset(new concat::primitive_desc(concat_dimension, srcs_mpd));
 
   shared_ptr<memory::primitive_desc> prv_dst_mpd(new memory::primitive_desc(
@@ -191,9 +364,6 @@ void MKLDNNConcatLayer<Dtype>::InitConcatBwd(const vector<Blob<Dtype>*>& top,
   memory::dims input_tz = {num_, channels_, height_, width_};
   memory::dims offsets = {0, 0, 0, 0};
 
-  // FIXME: concat dimension
-  concat_dimension = 1;
-
   shared_ptr<memory::primitive_desc> prv_diff_dst_mpd;
   shared_ptr<memory::primitive_desc> usr_diff_dst_mpd(
       new memory::primitive_desc({input_tz, data_type, mfmt_nchw},
@@ -218,7 +388,25 @@ void MKLDNNConcatLayer<Dtype>::InitConcatBwd(const vector<Blob<Dtype>*>& top,
   for (auto i = 0; i < num_concats_; i++) {
     bwd_bottom_diff.push_back(boost::shared_ptr<MKLDNNDiff<Dtype> >());
     reorders.push_back(MKLDNNPrimitive<Dtype>());
-    memory::dims dims = {num_, split_channels[i], height_, width_};
+
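+    // Each bottom's diff view is sized split_dims[i] along the concat axis,
+    // mirroring the forward source descriptors.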
+    memory::dims dims = {0, 0, 0, 0};
+    if (concat_dimension == 0)
+    {
+      dims = {split_dims[i], channels_, height_, width_};
+    }
+    else if (concat_dimension == 1)
+    {
+      dims = {num_, split_dims[i], height_, width_};
+    }
+    else if (concat_dimension == 2)
+    {
+      dims = {num_, channels_, split_dims[i], width_};
+    }
+    else if (concat_dimension == 3)
+    {
+      dims = {num_, channels_, height_, split_dims[i]};
+    }
+
     shared_ptr<memory::primitive_desc> usr_diff_src_mpd(
         new memory::primitive_desc({dims, data_type, mfmt_nchw},
         cpu_engine));
@@ -259,7 +447,7 @@ void MKLDNNConcatLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   LOG(INFO) << "MKLDNNConcatLayer<Dtype>::Forward_cpu: " << this->layer_param_.name();
 #endif
 
-  if (NULL == concatFwd_pd)
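+  // Rebuild the forward primitive on first use or after a shape change.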
+  if ((NULL == concatFwd_pd) || (true == reshape))
     InitConcatFwd(bottom, top);
   for (auto i = 0; i < num_concats_; i++) {
     // making reorders if needed.
@@ -284,7 +472,7 @@ void MKLDNNConcatLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top
   LOG(INFO) << "MKLDNNConcatLayer<Dtype>::Backward_cpu: " << this->layer_param_.name();
 #endif
 
-  if (reorders.size() == 0)
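+  // Rebuild the backward reorders on first use or after a shape change.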
+  if ((reorders.size() == 0) || (true == reshape))
     InitConcatBwd(top, propagate_down, bottom);
   bwd_top_diff->sync_before_read();
   for (auto i = 0; i < num_concats_; ++i) {