
Commit 7dd72ef
Author: Yu, Chong

Fix the crash of GoogleNet-V1 inference due to reshape of batch size.

1 parent: 6ec49fc

File tree: 1 file changed (+212, -24 lines)


src/caffe/layers/mkldnn_concat_layer.cpp

Lines changed: 212 additions & 24 deletions
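
The commit message points at GoogleNet-V1 inference crashing once the batch size is reshaped at runtime. A minimal sketch of the kind of caller that exercises this path, assuming a standard deploy prototxt and pretrained weights (the file paths, the batch size of 8, and the 224x224 input are placeholders, not taken from this commit):

    #include <caffe/caffe.hpp>

    int main() {
      // Load a GoogleNet-V1 style deploy network for CPU inference.
      caffe::Net<float> net("deploy.prototxt", caffe::TEST);      // placeholder path
      net.CopyTrainedLayersFrom("googlenet.caffemodel");          // placeholder path

      // Change the batch size of the input blob, then propagate the new shape
      // through every layer. Each inception Concat layer's Reshape() runs here,
      // which is the code path this commit repairs.
      net.input_blobs()[0]->Reshape(8, 3, 224, 224);
      net.Reshape();
      net.Forward();
      return 0;
    }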
@@ -50,37 +50,194 @@ namespace caffe {
 template <typename Dtype>
 void MKLDNNConcatLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
-  // VLOG(1) << "MKLDNNConcatLayer<Dtype>::LayerSetUp: " << this->layer_param_.name();
+  VLOG(1) << "MKLDNNConcatLayer<Dtype>::LayerSetUp: " << this->layer_param_.name();
+
+  const ConcatParameter& concat_param = this->layer_param_.concat_param();
+  CHECK(!(concat_param.has_axis() && concat_param.has_concat_dim()))
+      << "Either axis or concat_dim should be specified; not both.";

   int dim_src = bottom[0]->shape().size();
   // int dim_dst = dim_src;

   num_concats_ = bottom.size();
-  channels_ = 0;

-  for (auto i = 1; i < num_concats_; ++i) {
-    CHECK_EQ(bottom[0]->num(), bottom[i]->num());
-    CHECK_EQ(bottom[0]->height(), bottom[i]->height());
-    CHECK_EQ(bottom[0]->width(), bottom[i]->width());
+  const int num_axes = bottom[0]->num_axes();
+  if (concat_param.has_concat_dim()) {
+    concat_dimension = static_cast<int>(concat_param.concat_dim());
+    // Don't allow negative indexing for concat_dim, a uint32 -- almost certainly unintended.
+    CHECK_GE(concat_dimension, 0) << "casting concat_dim from uint32 to int32 "
+        << "produced negative result; concat_dim must satisfy "
+        << "0 <= concat_dimension < " << kMaxBlobAxes;
+    CHECK_LT(concat_dimension, num_axes) << "concat_dimension out of range.";
+  } else {
+    concat_dimension = bottom[0]->CanonicalAxisIndex(concat_param.axis());
   }

-  split_channels.reserve(num_concats_);
-  for (auto i = 0; i < num_concats_; ++i) {
-    CHECK_EQ(dim_src, bottom[i]->shape().size());
+  for (auto i = 1; i < num_concats_; ++i) {
+    if (concat_dimension == 0)
+    {
+      CHECK_EQ(bottom[0]->channels(), bottom[i]->channels());
+      CHECK_EQ(bottom[0]->height(), bottom[i]->height());
+      CHECK_EQ(bottom[0]->width(), bottom[i]->width());
+      break;
+    }
+    else if (concat_dimension == 1)
+    {
+      CHECK_EQ(bottom[0]->num(), bottom[i]->num());
+      CHECK_EQ(bottom[0]->height(), bottom[i]->height());
+      CHECK_EQ(bottom[0]->width(), bottom[i]->width());
+      break;
+    }
+    else if (concat_dimension == 2)
+    {
+      CHECK_EQ(bottom[0]->num(), bottom[i]->num());
+      CHECK_EQ(bottom[0]->channels(), bottom[i]->channels());
+      CHECK_EQ(bottom[0]->width(), bottom[i]->width());
+      break;
+    }
+    else if (concat_dimension == 3)
+    {
+      CHECK_EQ(bottom[0]->num(), bottom[i]->num());
+      CHECK_EQ(bottom[0]->channels(), bottom[i]->channels());
+      CHECK_EQ(bottom[0]->height(), bottom[i]->height());
+      break;
+    }
+  }

-    split_channels[i] = bottom[i]->channels();
-    channels_ += split_channels[i];
+  split_dims.reserve(num_concats_);
+  if (concat_dimension == 0)
+  {
+    num_ = 0;
+    channels_ = bottom[0]->channels();
+    height_ = bottom[0]->height();
+    width_ = bottom[0]->width();
+    for (auto i = 0; i < num_concats_; ++i) {
+      CHECK_EQ(dim_src, bottom[i]->shape().size());
+      split_dims[i] = bottom[i]->num();
+      num_ += split_dims[i];
+    }
+  }
+  else if (concat_dimension == 1)
+  {
+    num_ = bottom[0]->num();
+    channels_ = 0;
+    height_ = bottom[0]->height();
+    width_ = bottom[0]->width();
+    for (auto i = 0; i < num_concats_; ++i) {
+      CHECK_EQ(dim_src, bottom[i]->shape().size());
+      split_dims[i] = bottom[i]->channels();
+      channels_ += split_dims[i];
+    }
+  }
+  else if (concat_dimension == 2)
+  {
+    num_ = bottom[0]->num();
+    channels_ = bottom[0]->channels();
+    height_ = 0;
+    width_ = bottom[0]->width();
+    for (auto i = 0; i < num_concats_; ++i) {
+      CHECK_EQ(dim_src, bottom[i]->shape().size());
+      split_dims[i] = bottom[i]->height();
+      height_ += split_dims[i];
+    }
+  }
+  else if (concat_dimension == 3)
+  {
+    num_ = bottom[0]->num();
+    channels_ = bottom[0]->channels();
+    height_ = bottom[0]->height();
+    width_ = 0;
+    for (auto i = 0; i < num_concats_; ++i) {
+      CHECK_EQ(dim_src, bottom[i]->shape().size());
+      split_dims[i] = bottom[i]->width();
+      width_ += split_dims[i];
+    }
   }
 }

 template <typename Dtype>
 void MKLDNNConcatLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
-  // VLOG(1) << "MKLDNNConcatLayer<Dtype>::Reshape: " << this->layer_param_.name();
+  VLOG(1) << "MKLDNNConcatLayer<Dtype>::Reshape: " << this->layer_param_.name();

-  num_ = bottom[0]->num();
-  height_ = bottom[0]->height();
-  width_ = bottom[0]->width();
+  if (concat_dimension == 0)
+  {
+    // Need to re-calculate the shape due to the change of batch size
+    num_ = 0;
+    channels_ = bottom[0]->channels();
+    height_ = bottom[0]->height();
+    width_ = bottom[0]->width();
+    // Also need to recompute the concat dim, in case the concat dim itself has just been reshaped by the batch size change
+    for (auto i = 0; i < num_concats_; ++i) {
+      split_dims[i] = bottom[i]->num();
+      num_ += split_dims[i];
+    }
+
+    if (this->channels_ == bottom[0]->channels() &&
+        this->height_ == bottom[0]->height() &&
+        this->width_ == bottom[0]->width()) {
+      reshape = false;
+    } else {
+      reshape = true;
+    }
+  }
+  else if (concat_dimension == 1)
+  {
+    num_ = bottom[0]->num();
+    channels_ = 0;
+    height_ = bottom[0]->height();
+    width_ = bottom[0]->width();
+    for (auto i = 0; i < num_concats_; ++i) {
+      split_dims[i] = bottom[i]->channels();
+      channels_ += split_dims[i];
+    }
+
+    if (this->num_ == bottom[0]->num() &&
+        this->height_ == bottom[0]->height() &&
+        this->width_ == bottom[0]->width()) {
+      reshape = false;
+    } else {
+      reshape = true;
+    }
+  }
+  else if (concat_dimension == 2)
+  {
+    num_ = bottom[0]->num();
+    channels_ = bottom[0]->channels();
+    height_ = 0;
+    width_ = bottom[0]->width();
+    for (auto i = 0; i < num_concats_; ++i) {
+      split_dims[i] = bottom[i]->height();
+      height_ += split_dims[i];
+    }
+
+    if (this->num_ == bottom[0]->num() &&
+        this->channels_ == bottom[0]->channels() &&
+        this->width_ == bottom[0]->width()) {
+      reshape = false;
+    } else {
+      reshape = true;
+    }
+  }
+  else if (concat_dimension == 3)
+  {
+    num_ = bottom[0]->num();
+    channels_ = bottom[0]->channels();
+    height_ = bottom[0]->height();
+    width_ = 0;
+    for (auto i = 0; i < num_concats_; ++i) {
+      split_dims[i] = bottom[i]->width();
+      width_ += split_dims[i];
+    }
+
+    if (this->num_ == bottom[0]->num() &&
+        this->channels_ == bottom[0]->channels() &&
+        this->height_ == bottom[0]->height()) {
+      reshape = false;
+    } else {
+      reshape = true;
+    }
+  }

   top[0]->Reshape(num_, channels_, height_, width_);
 }
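
The core of the change above is that both LayerSetUp() and Reshape() now rebuild the full output shape: the concat axis accumulates the per-bottom extents recorded in split_dims, and the other three axes are copied from bottom[0], so a new batch size is picked up on every Reshape() call instead of being frozen at setup time. A standalone sketch of that bookkeeping in plain C++, independent of Caffe/MKL-DNN (the inception-style numbers are for illustration only):

    #include <array>
    #include <cassert>
    #include <vector>

    // Compute the concatenated output shape (NCHW) the way the layer does:
    // copy bottom[0] on the three untouched axes, sum the bottoms on the
    // concat axis, and remember each bottom's contribution in split_dims.
    std::array<int, 4> concat_shape(const std::vector<std::array<int, 4> >& bottoms,
                                    int concat_dimension,
                                    std::vector<int>* split_dims) {
      std::array<int, 4> top = bottoms[0];
      top[concat_dimension] = 0;
      split_dims->clear();
      for (const auto& b : bottoms) {
        split_dims->push_back(b[concat_dimension]);
        top[concat_dimension] += b[concat_dimension];
      }
      return top;
    }

    int main() {
      // Inception-style concat over channels (axis 1) after the batch has been
      // reshaped to 8: four bottoms of 64/128/32/32 channels at 28x28.
      std::vector<std::array<int, 4> > bottoms = {
          {{8, 64, 28, 28}}, {{8, 128, 28, 28}}, {{8, 32, 28, 28}}, {{8, 32, 28, 28}}};
      std::vector<int> split_dims;
      std::array<int, 4> top = concat_shape(bottoms, 1, &split_dims);
      assert(top[0] == 8 && top[1] == 256 && top[2] == 28 && top[3] == 28);
      return 0;
    }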
@@ -126,7 +283,25 @@ void MKLDNNConcatLayer<Dtype>::InitConcatFwd(const vector<Blob<Dtype>*>& bottom,
   std::vector<primitive::at> srcs;
   for (auto i = 0; i < num_concats_; i++) {
     fwd_bottom_data.push_back(boost::shared_ptr<MKLDNNData<Dtype> >());
-    memory::dims input_tz = {num_, split_channels[i], height_, width_};
+
+    memory::dims input_tz = {0, 0, 0, 0};
+    if (concat_dimension == 0)
+    {
+      input_tz = {split_dims[i], channels_, height_, width_};
+    }
+    else if (concat_dimension == 1)
+    {
+      input_tz = {num_, split_dims[i], height_, width_};
+    }
+    else if (concat_dimension == 2)
+    {
+      input_tz = {num_, channels_, split_dims[i], width_};
+    }
+    else if (concat_dimension == 3)
+    {
+      input_tz = {num_, channels_, height_, split_dims[i]};
+    }
+
     memory::format src_mfmt = mfmt_nchw;
     shared_ptr<memory::primitive_desc> prv_src_mpd;
     shared_ptr<memory::primitive_desc> usr_src_mpd(
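
In InitConcatFwd() each per-input memory descriptor now places split_dims[i] on the concat axis and the shared extents on the other axes. The four-way if/else chain above is equivalent to overwriting one slot of an NCHW array indexed by concat_dimension; a compact sketch of that equivalence (plain C++, names are illustrative only):

    #include <array>
    #include <cassert>

    // Equivalent way to build the per-input NCHW extents: start from the shared
    // output extents and overwrite the concat axis with this input's own size.
    std::array<int, 4> input_extents(int num, int channels, int height, int width,
                                     int concat_dimension, int split_dim_i) {
      std::array<int, 4> tz = {{num, channels, height, width}};
      tz[concat_dimension] = split_dim_i;
      return tz;
    }

    int main() {
      // Concat over channels: a bottom contributing 64 channels to an
      // 8 x 256 x 28 x 28 output gets an 8 x 64 x 28 x 28 descriptor.
      std::array<int, 4> tz = input_extents(8, 256, 28, 28, 1, 64);
      assert(tz[0] == 8 && tz[1] == 64 && tz[2] == 28 && tz[3] == 28);
      return 0;
    }

Whether to keep the explicit branches (matching the surrounding style) or fold them into an indexed write is a stylistic choice; the committed code keeps the branches.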
@@ -154,8 +329,6 @@ void MKLDNNConcatLayer<Dtype>::InitConcatFwd(const vector<Blob<Dtype>*>& bottom,
   shared_ptr<memory::primitive_desc> usr_dst_mpd(new memory::primitive_desc(
       {output_tz, data_type, mfmt_nchw}, cpu_engine));

-  // FIXME: concat dimension
-  concat_dimension = 1;
   concatFwd_pd.reset(new concat::primitive_desc(concat_dimension, srcs_mpd));

   shared_ptr<memory::primitive_desc> prv_dst_mpd(new memory::primitive_desc(
@@ -191,9 +364,6 @@ void MKLDNNConcatLayer<Dtype>::InitConcatBwd(const vector<Blob<Dtype>*>& top,
   memory::dims input_tz = {num_, channels_, height_, width_};
   memory::dims offsets = {0, 0, 0, 0};

-  // FIXME: concat dimension
-  concat_dimension = 1;
-
   shared_ptr<memory::primitive_desc> prv_diff_dst_mpd;
   shared_ptr<memory::primitive_desc> usr_diff_dst_mpd(
       new memory::primitive_desc({input_tz, data_type, mfmt_nchw},
@@ -218,7 +388,25 @@ void MKLDNNConcatLayer<Dtype>::InitConcatBwd(const vector<Blob<Dtype>*>& top,
   for (auto i = 0; i < num_concats_; i++) {
     bwd_bottom_diff.push_back(boost::shared_ptr<MKLDNNDiff<Dtype> >());
     reorders.push_back(MKLDNNPrimitive<Dtype>());
-    memory::dims dims = {num_, split_channels[i], height_, width_};
+
+    memory::dims dims = {0, 0, 0, 0};
+    if (concat_dimension == 0)
+    {
+      dims = {split_dims[i], channels_, height_, width_};
+    }
+    else if (concat_dimension == 1)
+    {
+      dims = {num_, split_dims[i], height_, width_};
+    }
+    else if (concat_dimension == 2)
+    {
+      dims = {num_, channels_, split_dims[i], width_};
+    }
+    else if (concat_dimension == 3)
+    {
+      dims = {num_, channels_, height_, split_dims[i]};
+    }
+
     shared_ptr<memory::primitive_desc> usr_diff_src_mpd(
         new memory::primitive_desc({dims, data_type, mfmt_nchw},
         cpu_engine));
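
The backward path mirrors the forward one: each bottom's diff is a slice of the concatenated top diff whose extent along the concat axis is split_dims[i]. The offsets variable visible in the context above presumably tracks where each slice starts, i.e. the running sum of the earlier splits; a small sketch of that offset arithmetic, independent of the MKL-DNN view API and shown purely as an illustration:

    #include <cassert>
    #include <vector>

    // For bottom i, the slice of the concatenated diff starts at the sum of the
    // earlier bottoms' sizes along the concat axis.
    int slice_offset(const std::vector<int>& split_dims, int i) {
      int offset = 0;
      for (int j = 0; j < i; ++j) offset += split_dims[j];
      return offset;
    }

    int main() {
      std::vector<int> split_dims = {64, 128, 32, 32};  // channels per bottom
      assert(slice_offset(split_dims, 0) == 0);
      assert(slice_offset(split_dims, 2) == 192);  // 64 + 128
      assert(slice_offset(split_dims, 3) == 224);  // 64 + 128 + 32
      return 0;
    }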
@@ -259,7 +447,7 @@ void MKLDNNConcatLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   LOG(INFO) << "MKLDNNConcatLayer<Dtype>::Forward_cpu: " << this->layer_param_.name();
 #endif

-  if (NULL == concatFwd_pd)
+  if ((NULL == concatFwd_pd) || (true == reshape))
     InitConcatFwd(bottom, top);
   for (auto i = 0; i < num_concats_; i++) {
     // making reorders if needed.
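
Before this change the forward primitive descriptor was built once (while concatFwd_pd was still null) and reused for the lifetime of the layer, so after a batch-size reshape it still described the old extents. The guard now also rebuilds it when Reshape() flagged a shape change. A minimal sketch of that cache-or-rebuild pattern in plain C++ (Descriptor and the shape key are stand-ins, not the layer's actual types):

    #include <array>
    #include <memory>

    struct Descriptor {            // stand-in for the MKL-DNN primitive descriptor
      std::array<int, 4> shape;    // extents it was built for
    };

    class CachedConcat {
     public:
      const Descriptor& get(const std::array<int, 4>& shape) {
        // Rebuild when nothing is cached yet, or when the cached descriptor was
        // built for different extents (the "reshape" case this commit adds).
        if (!pd_ || pd_->shape != shape) {
          pd_.reset(new Descriptor{shape});
        }
        return *pd_;
      }

     private:
      std::unique_ptr<Descriptor> pd_;
    };

    int main() {
      CachedConcat layer;
      layer.get({{10, 256, 28, 28}});  // first call builds
      layer.get({{10, 256, 28, 28}});  // same shape: reuse
      layer.get({{8, 256, 28, 28}});   // batch size changed: rebuild
      return 0;
    }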
@@ -284,7 +472,7 @@ void MKLDNNConcatLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top
   LOG(INFO) << "MKLDNNConcatLayer<Dtype>::Backward_cpu: " << this->layer_param_.name();
 #endif

-  if (reorders.size() == 0)
+  if ((reorders.size() == 0) || (true == reshape))
     InitConcatBwd(top, propagate_down, bottom);
   bwd_top_diff->sync_before_read();
   for (auto i = 0; i < num_concats_; ++i) {
