From b1f8be67dbe9b7b60773893f5713ee4a562696b4 Mon Sep 17 00:00:00 2001 From: long <2964901878@qq.com> Date: Tue, 7 Mar 2023 11:30:23 +0000 Subject: [PATCH] 'replace the usage of keys of self.io_module ' --- .../custom_quantizer/academic_quantizer.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/mqbench/custom_quantizer/academic_quantizer.py b/mqbench/custom_quantizer/academic_quantizer.py index 923d7834..9349efba 100644 --- a/mqbench/custom_quantizer/academic_quantizer.py +++ b/mqbench/custom_quantizer/academic_quantizer.py @@ -61,10 +61,11 @@ def module_type_to_quant_input(self) -> tuple: ) + self.additional_module_type def _get_post_act_8bit_node_name(self, model): - for node in self.io_module.values(): - for _arg in node.args: - if isinstance(_arg, torch.fx.node.Node): - self.post_act_8bit_node_name.append(_arg.name) + for nodes in self.io_module.values(): + for node in nodes: + for _arg in node.args: + if isinstance(_arg, torch.fx.node.Node): + self.post_act_8bit_node_name.append(_arg.name) def _get_io_module(self, model): total_args = [] @@ -77,11 +78,22 @@ def _get_io_module(self, model): the_first_layer = True total_args.append(_arg.name) if the_first_layer: - self.io_module[node.target] = node + if node.target in self.io_module.keys(): + # 如果已经创建过键值对了的话,列表中添加新的相关node + self.io_module[node.target].append(node) + else: + # 如果还没有创建键值对,则定义键值对 + self.io_module[node.target] = [node] if node.op == 'output': for _arg in node.args: if isinstance(_arg, torch.fx.node.Node): self.io_module[_arg.target] = _arg + if _arg.target in self.io_module.keys(): + # 如果已经创建过键值对了的话,列表中添加新的相关node + self.io_module[_arg.target].append(_arg) + else: + # 如果还没有创建键值对,则定义键值对 + self.io_module[_arg.target] = [_arg] def _find_act_quants(self, model: GraphModule) -> List: nodes = list(model.graph.nodes)