2021/10/4
The main body of the code:
```python
def prepare_qat_fx_by_platform(
        model: torch.nn.Module,
        deploy_backend: BackendType,
        prepare_custom_config_dict: Dict[str, Any] = {}):
    """
    Args:
        model (torch.nn.Module):
        deploy_backend (BackendType):

    >>> prepare_custom_config_dict : {
            extra_qconfig_dict : Dict, find explanations in get_qconfig_by_platform,
            extra_quantizer_dict: extra params for the quantizer.
            preserve_attr: Dict, specifies attributes of the model that should be
                preserved after prepare, since symbolic_trace only stores attributes
                that are used in forward. If model.func1 and model.backbone.func2
                should be preserved, {"": ["func1"], "backbone": ["func2"]} works.
            The attrs below are inherited from PyTorch.
            concrete_args: specify inputs for model tracing.
            extra_fuse_dict: specify extra fusing patterns and functions.
        }
    """
    assert model.training, 'prepare_qat_fx_by_platform only works for models in ' + \
        'train mode'
    logger.info("Quantize model using {} scheme.".format(deploy_backend))
    _swap_ff_with_fxff(model)
    # Get QConfig.
    extra_qconfig_dict = prepare_custom_config_dict.get('extra_qconfig_dict', {})
    qconfig = get_qconfig_by_platform(deploy_backend, extra_qconfig_dict)
    # Preserve attrs.
    preserve_attr_dict = dict()
    if 'preserve_attr' in prepare_custom_config_dict:
        for submodule_name in prepare_custom_config_dict['preserve_attr']:
            cur_module = model
            if submodule_name != "":
                cur_module = getattr(model, submodule_name)
            preserve_attr_list = prepare_custom_config_dict['preserve_attr'][submodule_name]
            preserve_attr_dict[submodule_name] = {}
            for attr in preserve_attr_list:
                preserve_attr_dict[submodule_name][attr] = getattr(cur_module, attr)
    # Symbolic trace.
    concrete_args = prepare_custom_config_dict.get('concrete_args', None)
    graph_module = symbolic_trace(model, concrete_args=concrete_args)
    # Model fusion.
    extra_fuse_dict = prepare_custom_config_dict.get('extra_fuse_dict', {})
    extra_fuse_dict.update(fuse_custom_config_dict)
    # Prepare.
    import mqbench.custom_quantizer  # noqa: F401
    extra_quantizer_dict = prepare_custom_config_dict.get('extra_quantizer_dict', {})
    quantizer = DEFAULT_MODEL_QUANTIZER[deploy_backend](extra_quantizer_dict, extra_fuse_dict)
    prepared = quantizer.prepare(graph_module, qconfig)
    # Restore attrs.
    if 'preserve_attr' in prepare_custom_config_dict:
        for submodule_name in prepare_custom_config_dict['preserve_attr']:
            cur_module = prepared
            if submodule_name != "":
                cur_module = getattr(prepared, submodule_name)
            preserve_attr_list = prepare_custom_config_dict['preserve_attr'][submodule_name]
            for attr in preserve_attr_list:
                logger.info("Preserve attr: {}.{}".format(submodule_name, attr))
                setattr(cur_module, attr, preserve_attr_dict[submodule_name][attr])
    return prepared
```
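Before the walkthrough, here is a minimal usage sketch. The import path, the BackendType member, and the config values are my assumptions from the code above, not verified against the repo:

```python
import torch
import torchvision.models as models
from mqbench.prepare_by_platform import prepare_qat_fx_by_platform, BackendType  # assumed import path

model = models.resnet18().train()          # the assert above requires train mode
prepared = prepare_qat_fx_by_platform(
    model,
    BackendType.Tensorrt,                  # resolves to TRTModelQuantizer further down
    prepare_custom_config_dict={
        'extra_qconfig_dict': {},          # see get_qconfig_by_platform for accepted keys
        'preserve_attr': {},               # e.g. {"": ["func1"], "backbone": ["func2"]}
    })
```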
- `_swap_ff_with_fxff(model)`: imported from torch.quantization.quantize_fx. Its job is to replace every FloatFunctional inside the model with FXFloatFunctional, but with breakpoints set during both training and testing, the list of modules to swap was always empty.
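For reference, the swap is roughly the following recursive walk (paraphrased from the torch source, so treat it as a sketch); the empty list simply means the model contains no FloatFunctional modules at all:

```python
import torch

def swap_ff_with_fxff(model: torch.nn.Module) -> None:
    """Recursively replace FloatFunctional children with FXFloatFunctional."""
    modules_to_swap = []
    for name, module in model.named_children():
        if isinstance(module, torch.nn.quantized.FloatFunctional):
            modules_to_swap.append(name)
        else:
            swap_ff_with_fxff(module)       # recurse into submodules
    for name in modules_to_swap:            # swap after iterating to avoid mutating mid-loop
        del model._modules[name]
        model._modules[name] = torch.nn.quantized.FXFloatFunctional()
```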
- `preserve_attr_dict`: I haven't figured out what it is for yet; from the code it just stashes the listed attributes before tracing so they can be restored afterwards.
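Its purpose becomes clearer with a toy experiment (the toy model is mine): symbolic_trace returns a fresh GraphModule that only carries what forward touches, so any other attribute vanishes unless it is stashed and restored exactly the way the code above does.

```python
import torch
from torch.fx import symbolic_trace

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)
        self.tag = "never used in forward"  # not referenced by forward

    def forward(self, x):
        return self.fc(x)

traced = symbolic_trace(Toy())
print(hasattr(traced, "tag"))               # False: tracing dropped the attribute
setattr(traced, "tag", "restored")          # what the restore loop above does
```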
- `symbolic_trace`: found a good translated doc for it.
- Model fusion part: my current understanding is that this is mainly about fusing, e.g. Linear+BN. Jumping to the corresponding part of MQBench, the default dict holds two such patterns (Linear first or Linear last): one ships with PyTorch, one is MQBench's own addition.
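As a reminder of what fusing Linear+BN actually computes, here is a minimal fold of the BN statistics into the Linear weights (my own sketch, not MQBench's fuser function):

```python
import torch
import torch.nn as nn

def fuse_linear_bn(linear: nn.Linear, bn: nn.BatchNorm1d) -> nn.Linear:
    # y = gamma * (Wx + b - mu) / sqrt(var + eps) + beta folds into a single Linear
    std = torch.sqrt(bn.running_var + bn.eps)
    scale = bn.weight / std                                   # gamma / std, per output channel
    fused = nn.Linear(linear.in_features, linear.out_features)
    fused.weight.data = linear.weight.data * scale[:, None]   # row-wise rescale of W
    bias = linear.bias.data if linear.bias is not None else torch.zeros(linear.out_features)
    fused.bias.data = (bias - bn.running_mean) * scale + bn.bias.data
    return fused
```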
- `quantizer = DEFAULT_MODEL_QUANTIZER[deploy_backend](extra_quantizer_dict, extra_fuse_dict)`: a lot hides in here! The hash lookup finds the matching quantizer; here it jumps straight to line 345 of custom_quantizer.py, which is TRTModelQuantizer (the TensorRT Model Quantizer).
- `prepared = quantizer.prepare(graph_module, qconfig)`: extremely important! This is where the quantization points are inserted and the ops are replaced!
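The "hash" is just a registry dict keyed by the backend enum, filled by a decorator. A self-contained sketch of the pattern (class and enum bodies abbreviated, not MQBench's exact code):

```python
from enum import Enum

class BackendType(Enum):
    Tensorrt = 'Tensorrt'
    SNPE = 'SNPE'

DEFAULT_MODEL_QUANTIZER = {}

def register_model_quantizer(backend):
    # decorator that files a quantizer class under its backend key
    def wrap(cls):
        DEFAULT_MODEL_QUANTIZER[backend] = cls
        return cls
    return wrap

@register_model_quantizer(BackendType.Tensorrt)
class TRTModelQuantizer:
    def __init__(self, extra_quantizer_dict, extra_fuse_dict):
        self.extra_quantizer_dict = extra_quantizer_dict
        self.extra_fuse_dict = extra_fuse_dict

# DEFAULT_MODEL_QUANTIZER[deploy_backend](...) then resolves to TRTModelQuantizer
quantizer = DEFAULT_MODEL_QUANTIZER[BackendType.Tensorrt]({}, {})
```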
- The call chain `quantizer.prepare` -> `model = self._weight_quant(model, qconfig)` -> `propagate_qconfig_(model, flattened_qconfig_dict)` -> `_propagate_qconfig_helper(module, qconfig_dict, allow_list)` -> `allow_list = get_default_qconfig_propagation_list()` leads into torch/quantization/quantization_mappings.py, where I noticed that DEFAULT_QAT_MODULE_MAPPINGS contains no TransposeConv! I will probably have to write it myself!
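The gap is easy to confirm from Python (get_default_qat_module_mappings is the accessor in the torch version I am reading, around 1.8/1.9):

```python
import torch.nn as nn
from torch.quantization.quantization_mappings import get_default_qat_module_mappings

mappings = get_default_qat_module_mappings()   # float module class -> QAT module class
print(nn.Conv2d in mappings)                   # True: Conv2d has a QAT counterpart
print(nn.ConvTranspose2d in mappings)          # False: no QAT TransposeConv, hence the note above
```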
- `model = self._weight_quant(model, qconfig)` -> `propagate_qconfig_(model, flattened_qconfig_dict)` iteratively attaches a qconfig to every submodule; this should be the place to start! Maybe add my own QAT module under MQBench/mqbench/nn/intrinsic/?
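If I do write one: a QAT module is essentially the float module plus a weight fake-quantizer. A minimal sketch of a hypothetical QATConvTranspose2d (the class name and wiring are mine):

```python
import torch.nn as nn
import torch.nn.functional as F

class QATConvTranspose2d(nn.ConvTranspose2d):
    """Hypothetical QAT deconv: fake-quantize the weight on every forward."""
    def __init__(self, *args, qconfig=None, **kwargs):
        super().__init__(*args, **kwargs)
        assert qconfig is not None, 'QAT modules need a qconfig'
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight()  # instantiate the weight FakeQuantize

    def forward(self, x):
        return F.conv_transpose2d(
            x, self.weight_fake_quant(self.weight), self.bias, self.stride,
            self.padding, self.output_padding, self.groups, self.dilation)
```

It would then go into a mapping like {nn.ConvTranspose2d: QATConvTranspose2d} via `_additional_qat_module_mapping`, so that the swap step in the next note picks it up.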
- `model = self._weight_quant(model, qconfig)` -> `self._qat_swap_modules(model, self._additional_qat_module_mapping)` swaps the ordinary modules for QAT modules.
- `model = self._insert_fake_quantize_for_act_quant(model, qconfig)` inserts the activation quantization points; activation quantization happens on the input side, e.g. on the input of conv1_1.
- `_find_act_quants` contains a check, `self._is_implicit_merge(modules, (node, _node))`, which decides whether the current node is an input to an add/mul; the details are a bit involved and I haven't read them.
- What ends up in `node_need_to_quantize_output` is the *previous* node: e.g. `x` itself is not in `self.module_type_to_quant_input` / `self.function_type_to_quant_input`, but as the input of conv1_1 it must be added to the list `node_need_to_quantize_output` (see the sketch below).
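A sketch of that collection logic as I understand it (the two type tuples and the helper name are stand-ins, not MQBench's actual contents):

```python
import torch
import torch.fx

module_type_to_quant_input = (torch.nn.Conv2d, torch.nn.Linear)    # assumed contents
function_type_to_quant_input = (torch.nn.functional.interpolate,)  # assumed contents

def find_act_quants(gm: torch.fx.GraphModule):
    modules = dict(gm.named_modules())
    node_need_to_quantize_output = []
    for node in gm.graph.nodes:
        consumes_quant_input = (
            (node.op == 'call_module'
             and isinstance(modules[node.target], module_type_to_quant_input))
            or (node.op == 'call_function'
                and node.target in function_type_to_quant_input))
        if consumes_quant_input:
            # It is the inputs of such a node that get fake-quantized, so the
            # predecessor (e.g. the `x` feeding conv1_1) lands in the list.
            for arg in node.args:
                if isinstance(arg, torch.fx.Node):
                    node_need_to_quantize_output.append(arg)
    return node_need_to_quantize_output
```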
- `setattr(model, quantizer_name, fake_quantizer)` merely creates an attribute named quantizer_name on the GraphModule (e.g. `lrelu_4_post_act_fake_quantizer`); the `with` scope that follows then wires this attribute to the corresponding node, which in essence inserts a new node right after the current one.
- `_node.args = self._fix_succ_recursivly(_node.args, node, inserted_node)` rewrites the inputs/outputs of the affected nodes.
- `model.recompile()` recompiles the whole graph from the GraphModule's graph attribute: it regenerates the Python code for forward from the graph (no forward pass is actually run, only the code is rebuilt).
- `model.graph.lint()` checks that the graph was generated correctly.
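Putting the last few steps together, the generic FX editing idiom behind them looks like this (helper name and signature are mine; MQBench's `_fix_succ_recursivly` handles the rewiring with more care, e.g. for args nested inside tuples and dicts):

```python
import torch
from torch.fx import GraphModule, Node

def insert_fake_quant_after(gm: GraphModule, node: Node, fake_quantizer, quantizer_name: str):
    # 1. setattr registers the fake quantizer as a submodule of the GraphModule,
    #    e.g. under a name like 'lrelu_4_post_act_fake_quantizer'
    setattr(gm, quantizer_name, fake_quantizer)
    # 2. under the inserting_after scope, create a call_module node fed by `node`
    with gm.graph.inserting_after(node):
        inserted_node = gm.graph.call_module(quantizer_name, args=(node,))
    # 3. rewire every consumer of `node` to read from the inserted node instead
    node.replace_all_uses_with(inserted_node)
    inserted_node.args = (node,)  # undo step 3's rewrite of the new node's own input
    # 4. regenerate forward() from the mutated graph, then sanity-check it
    gm.recompile()
    gm.graph.lint()
    return gm
```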