This repository was archived by the owner on Aug 5, 2022. It is now read-only.

Commit 9b87dc6

Revert "Fix the crash of GoogleNet-V1 inference due to reshape of batch size."
This reverts commit 7dd72ef.
1 parent: 7dd72ef · commit: 9b87dc6
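In effect, the change being reverted sized each per-input MKL-DNN memory descriptor along whichever axis was being concatenated (via split_dims), while the code restored by this revert assumes channel-wise concatenation only (via split_channels), as the diff below shows. A minimal sketch of the two dimension computations, using hypothetical standalone helpers rather than the layer's actual members:

#include <vector>

// Sketch only; `base` = {num, channels, height, width} of the first bottom blob.
// Restored behaviour: inputs may differ only in their channel count.
std::vector<int> InputDimsChannelConcat(const std::vector<int>& base,
                                        int split_channels_i) {
  return {base[0], split_channels_i, base[2], base[3]};
}

// Reverted behaviour: the per-input size replaces whichever axis is concatenated
// (0 = num, 1 = channels, 2 = height, 3 = width), so a batch-axis concat is also
// representable.
std::vector<int> InputDimsAxisConcat(const std::vector<int>& base,
                                     int axis, int split_i) {
  std::vector<int> dims = base;
  dims[axis] = split_i;
  return dims;
}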

File tree

1 file changed: +24 -212 lines changed


src/caffe/layers/mkldnn_concat_layer.cpp

Lines changed: 24 additions & 212 deletions
@@ -50,194 +50,37 @@ namespace caffe {
 template <typename Dtype>
 void MKLDNNConcatLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
-  VLOG(1) << "MKLDNNConcatLayer<Dtype>::LayerSetUp: " << this->layer_param_.name();
-
-  const ConcatParameter& concat_param = this->layer_param_.concat_param();
-  CHECK(!(concat_param.has_axis() && concat_param.has_concat_dim()))
-      << "Either axis or concat_dim should be specified; not both.";
+  // VLOG(1) << "MKLDNNConcatLayer<Dtype>::LayerSetUp: " << this->layer_param_.name();
 
   int dim_src = bottom[0]->shape().size();
   // int dim_dst = dim_src;
 
   num_concats_ = bottom.size();
-
-  const int num_axes = bottom[0]->num_axes();
-  if (concat_param.has_concat_dim()) {
-    concat_dimension = static_cast<int>(concat_param.concat_dim());
-    // Don't allow negative indexing for concat_dim, a uint32 -- almost certainly unintended.
-    CHECK_GE(concat_dimension, 0) << "casting concat_dim from uint32 to int32 "
-        << "produced negative result; concat_dim must satisfy "
-        << "0 <= concat_dimension < " << kMaxBlobAxes;
-    CHECK_LT(concat_dimension, num_axes) << "concat_dimension out of range.";
-  } else {
-    concat_dimension = bottom[0]->CanonicalAxisIndex(concat_param.axis());
-  }
+  channels_ = 0;
 
   for (auto i = 1; i < num_concats_; ++i) {
-    if (concat_dimension == 0)
-    {
-      CHECK_EQ(bottom[0]->channels(), bottom[i]->channels());
-      CHECK_EQ(bottom[0]->height(), bottom[i]->height());
-      CHECK_EQ(bottom[0]->width(), bottom[i]->width());
-      break;
-    }
-    else if (concat_dimension == 1)
-    {
-      CHECK_EQ(bottom[0]->num(), bottom[i]->num());
-      CHECK_EQ(bottom[0]->height(), bottom[i]->height());
-      CHECK_EQ(bottom[0]->width(), bottom[i]->width());
-      break;
-    }
-    else if (concat_dimension == 2)
-    {
-      CHECK_EQ(bottom[0]->num(), bottom[i]->num());
-      CHECK_EQ(bottom[0]->channels(), bottom[i]->channels());
-      CHECK_EQ(bottom[0]->width(), bottom[i]->width());
-      break;
-    }
-    else if (concat_dimension == 3)
-    {
-      CHECK_EQ(bottom[0]->num(), bottom[i]->num());
-      CHECK_EQ(bottom[0]->channels(), bottom[i]->channels());
-      CHECK_EQ(bottom[0]->height(), bottom[i]->height());
-      break;
-    }
+    CHECK_EQ(bottom[0]->num(), bottom[i]->num());
+    CHECK_EQ(bottom[0]->height(), bottom[i]->height());
+    CHECK_EQ(bottom[0]->width(), bottom[i]->width());
   }
 
-  split_dims.reserve(num_concats_);
-  if (concat_dimension == 0)
-  {
-    num_ = 0;
-    channels_ = bottom[0]->channels();
-    height_ = bottom[0]->height();
-    width_ = bottom[0]->width();
-    for (auto i = 0; i < num_concats_; ++i) {
-      CHECK_EQ(dim_src, bottom[i]->shape().size());
-      split_dims[i] = bottom[i]->num();
-      num_ += split_dims[i];
-    }
-  }
-  else if (concat_dimension == 1)
-  {
-    num_ = bottom[0]->num();
-    channels_ = 0;
-    height_ = bottom[0]->height();
-    width_ = bottom[0]->width();
-    for (auto i = 0; i < num_concats_; ++i) {
-      CHECK_EQ(dim_src, bottom[i]->shape().size());
-      split_dims[i] = bottom[i]->channels();
-      channels_ += split_dims[i];
-    }
-  }
-  else if (concat_dimension == 2)
-  {
-    num_ = bottom[0]->num();
-    channels_ = bottom[0]->channels();
-    height_ = 0;
-    width_ = bottom[0]->width();
-    for (auto i = 0; i < num_concats_; ++i) {
-      CHECK_EQ(dim_src, bottom[i]->shape().size());
-      split_dims[i] = bottom[i]->height();
-      height_ += split_dims[i];
-    }
-  }
-  else if (concat_dimension == 3)
-  {
-    num_ = bottom[0]->num();
-    channels_ = bottom[0]->channels();
-    height_ = bottom[0]->height();
-    width_ = 0;
-    for (auto i = 0; i < num_concats_; ++i) {
-      CHECK_EQ(dim_src, bottom[i]->shape().size());
-      split_dims[i] = bottom[i]->width();
-      width_ += split_dims[i];
-    }
+  split_channels.reserve(num_concats_);
+  for (auto i = 0; i < num_concats_; ++i) {
+    CHECK_EQ(dim_src, bottom[i]->shape().size());
+
+    split_channels[i] = bottom[i]->channels();
+    channels_ += split_channels[i];
   }
 }
 
 template <typename Dtype>
 void MKLDNNConcatLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
-  VLOG(1) << "MKLDNNConcatLayer<Dtype>::Reshape: " << this->layer_param_.name();
+  // VLOG(1) << "MKLDNNConcatLayer<Dtype>::Reshape: " << this->layer_param_.name();
 
-  if (concat_dimension == 0)
-  {
-    // Need to re-calculate the shape due to the change of batch size
-    num_ = 0;
-    channels_ = bottom[0]->channels();
-    height_ = bottom[0]->height();
-    width_ = bottom[0]->width();
-    // Also need to reshape the concat dim, in case the concat dim has just been reshaped by the batch size
-    for (auto i = 0; i < num_concats_; ++i) {
-      split_dims[i] = bottom[i]->num();
-      num_ += split_dims[i];
-    }
-
-    if (this->channels_ == bottom[0]->channels() &&
-        this->height_ == bottom[0]->height() &&
-        this->width_ == bottom[0]->width()) {
-      reshape = false;
-    } else {
-      reshape = true;
-    }
-  }
-  else if (concat_dimension == 1)
-  {
-    num_ = bottom[0]->num();
-    channels_ = 0;
-    height_ = bottom[0]->height();
-    width_ = bottom[0]->width();
-    for (auto i = 0; i < num_concats_; ++i) {
-      split_dims[i] = bottom[i]->channels();
-      channels_ += split_dims[i];
-    }
-
-    if (this->num_ == bottom[0]->num() &&
-        this->height_ == bottom[0]->height() &&
-        this->width_ == bottom[0]->width()) {
-      reshape = false;
-    } else {
-      reshape = true;
-    }
-  }
-  else if (concat_dimension == 2)
-  {
-    num_ = bottom[0]->num();
-    channels_ = bottom[0]->channels();
-    height_ = 0;
-    width_ = bottom[0]->width();
-    for (auto i = 0; i < num_concats_; ++i) {
-      split_dims[i] = bottom[i]->height();
-      height_ += split_dims[i];
-    }
-
-    if (this->num_ == bottom[0]->num() &&
-        this->channels_ == bottom[0]->channels() &&
-        this->width_ == bottom[0]->width()) {
-      reshape = false;
-    } else {
-      reshape = true;
-    }
-  }
-  else if (concat_dimension == 3)
-  {
-    num_ = bottom[0]->num();
-    channels_ = bottom[0]->channels();
-    height_ = bottom[0]->height();
-    width_ = 0;
-    for (auto i = 0; i < num_concats_; ++i) {
-      split_dims[i] = bottom[i]->width();
-      width_ += split_dims[i];
-    }
-
-    if (this->num_ == bottom[0]->num() &&
-        this->channels_ == bottom[0]->channels() &&
-        this->height_ == bottom[0]->height()) {
-      reshape = false;
-    } else {
-      reshape = true;
-    }
-  }
+  num_ = bottom[0]->num();
+  height_ = bottom[0]->height();
+  width_ = bottom[0]->width();
 
   top[0]->Reshape(num_, channels_, height_, width_);
 }
@@ -283,25 +126,7 @@ void MKLDNNConcatLayer<Dtype>::InitConcatFwd(const vector<Blob<Dtype>*>& bottom,
   std::vector<primitive::at> srcs;
   for (auto i = 0; i < num_concats_; i++) {
     fwd_bottom_data.push_back(boost::shared_ptr<MKLDNNData<Dtype> >());
-
-    memory::dims input_tz = {0, 0, 0, 0};
-    if (concat_dimension == 0)
-    {
-      input_tz = {split_dims[i], channels_, height_, width_};
-    }
-    else if (concat_dimension == 1)
-    {
-      input_tz = {num_, split_dims[i], height_, width_};
-    }
-    else if (concat_dimension == 2)
-    {
-      input_tz = {num_, channels_, split_dims[i], width_};
-    }
-    else if (concat_dimension == 3)
-    {
-      input_tz = {num_, channels_, height_, split_dims[i]};
-    }
-
+    memory::dims input_tz = {num_, split_channels[i], height_, width_};
     memory::format src_mfmt = mfmt_nchw;
     shared_ptr<memory::primitive_desc> prv_src_mpd;
     shared_ptr<memory::primitive_desc> usr_src_mpd(
@@ -329,6 +154,8 @@ void MKLDNNConcatLayer<Dtype>::InitConcatFwd(const vector<Blob<Dtype>*>& bottom,
   shared_ptr<memory::primitive_desc> usr_dst_mpd(new memory::primitive_desc(
       {output_tz, data_type, mfmt_nchw}, cpu_engine));
 
+  // FIXME: concat dimension
+  concat_dimension = 1;
   concatFwd_pd.reset(new concat::primitive_desc(concat_dimension, srcs_mpd));
 
   shared_ptr<memory::primitive_desc> prv_dst_mpd(new memory::primitive_desc(
@@ -364,6 +191,9 @@ void MKLDNNConcatLayer<Dtype>::InitConcatBwd(const vector<Blob<Dtype>*>& top,
   memory::dims input_tz = {num_, channels_, height_, width_};
   memory::dims offsets = {0, 0, 0, 0};
 
+  // FIXME: concat dimension
+  concat_dimension = 1;
+
   shared_ptr<memory::primitive_desc> prv_diff_dst_mpd;
   shared_ptr<memory::primitive_desc> usr_diff_dst_mpd(
       new memory::primitive_desc({input_tz, data_type, mfmt_nchw},
@@ -388,25 +218,7 @@ void MKLDNNConcatLayer<Dtype>::InitConcatBwd(const vector<Blob<Dtype>*>& top,
   for (auto i = 0; i < num_concats_; i++) {
     bwd_bottom_diff.push_back(boost::shared_ptr<MKLDNNDiff<Dtype> >());
     reorders.push_back(MKLDNNPrimitive<Dtype>());
-
-    memory::dims dims = {0, 0, 0, 0};
-    if (concat_dimension == 0)
-    {
-      dims = {split_dims[i], channels_, height_, width_};
-    }
-    else if (concat_dimension == 1)
-    {
-      dims = {num_, split_dims[i], height_, width_};
-    }
-    else if (concat_dimension == 2)
-    {
-      dims = {num_, channels_, split_dims[i], width_};
-    }
-    else if (concat_dimension == 3)
-    {
-      dims = {num_, channels_, height_, split_dims[i]};
-    }
-
+    memory::dims dims = {num_, split_channels[i], height_, width_};
     shared_ptr<memory::primitive_desc> usr_diff_src_mpd(
         new memory::primitive_desc({dims, data_type, mfmt_nchw},
             cpu_engine));
@@ -447,7 +259,7 @@ void MKLDNNConcatLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   LOG(INFO) << "MKLDNNConcatLayer<Dtype>::Forward_cpu: " << this->layer_param_.name();
 #endif
 
-  if ((NULL == concatFwd_pd) || (true == reshape))
+  if (NULL == concatFwd_pd)
     InitConcatFwd(bottom, top);
   for (auto i = 0; i < num_concats_; i++) {
     // making reorders if needed.
@@ -472,7 +284,7 @@ void MKLDNNConcatLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top
   LOG(INFO) << "MKLDNNConcatLayer<Dtype>::Backward_cpu: " << this->layer_param_.name();
 #endif
 
-  if ((reorders.size() == 0) || (true == reshape))
+  if (reorders.size() == 0)
     InitConcatBwd(top, propagate_down, bottom);
   bwd_top_diff->sync_before_read();
   for (auto i = 0; i < num_concats_; ++i) {

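The last two hunks drop the other half of the reverted change: without the reshape flag computed in Reshape(), Forward_cpu and Backward_cpu rebuild their MKL-DNN primitives only when none exist yet, not when the blob shape (for example, the batch size) has changed since the primitives were created. A minimal sketch of the kind of shape-change test the reverted reshape flag was meant to encode for the channel-wise case; the helper below is hypothetical, not code from this commit:

#include "caffe/blob.hpp"

// Hypothetical helper (not in the commit): compare the cached output shape
// against the current first bottom blob. The concat axis (channels) is skipped
// because it is re-accumulated from the inputs on every Reshape().
template <typename Dtype>
bool ConcatShapeChanged(int cached_num, int cached_height, int cached_width,
                        const caffe::Blob<Dtype>& bottom0) {
  return cached_num != bottom0.num() ||
         cached_height != bottom0.height() ||
         cached_width != bottom0.width();
}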
0 commit comments
