From 48c7a79d688e97f19a388942444d52d72786d53f Mon Sep 17 00:00:00 2001 From: Chenghao Zhang Date: Tue, 6 Nov 2018 14:56:17 -0800 Subject: [PATCH] modify caffe output shape function --- mmdnn/conversion/caffe/graph.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/mmdnn/conversion/caffe/graph.py b/mmdnn/conversion/caffe/graph.py index 64c53131..1ee978d0 100644 --- a/mmdnn/conversion/caffe/graph.py +++ b/mmdnn/conversion/caffe/graph.py @@ -267,7 +267,15 @@ def compute_output_shapes(self, model): continue for node in sorted_nodes: if node.output_shape is None: - node.output_shape = TensorShape(*NodeKind.compute_output_shape(node)) + top_name = net.top_names.get(node.name) + if top_name is not None: + value = net.blobs.get(top_name[0]) + dims = list(value.shape) + dims = dims + [1] * (4 - len(dims)) + node.output_shape = TensorShape(*dims) + else: + node.output_shape = TensorShape(*NodeKind.compute_output_shape(node)) + os.close(tmp_handle) else: for node in sorted_nodes: