21/10/2 Learning Log
2021/10/2
2021/10/4
Reading MQBench
MQBench/mqbench/prepare_by_platform.py/prepare_qat_fx_by_platform
The main body of the code is as follows:
```python
def prepare_qat_fx_by_platform(
        model: torch.nn.Module,
        deploy_backend: BackendType,
        prepare_custom_config_dict: Dict[str, Any] = {}):
    """
    Args:
        model (torch.nn.Module):
        deploy_backend (BackendType):
    >>> prepare_custom_config_dict : {
            extra_qconfig_dict : Dict, Find explainations in get_qconfig_by_platform,
            extra_quantizer_dict: Extra params for quantizer.
            preserve_attr: Dict, Specify attribute of model which should be preserved
                after prepare. Since symbolic_trace only store attributes which is
                in forward. If model.func1 and model.backbone.func2 should be preserved,
                {"": ["func1"], "backbone": ["func2"] } should work.
            Attr below is inherited from Pytorch.
            concrete_args: Specify input for model tracing.
            extra_fuse_dict: Specify extra fusing patterns and functions.
        }
    """
    assert model.training, 'prepare_qat_fx_custom only works for models in ' + \
        'train mode'
    logger.info("Quantize model using {} scheme.".format(deploy_backend))
    _swap_ff_with_fxff(model)
    # Get Qconfig
    extra_qconfig_dict = prepare_custom_config_dict.get('extra_qconfig_dict', {})
    qconfig = get_qconfig_by_platform(deploy_backend, extra_qconfig_dict)
    # Preserve attr.
    preserve_attr_dict = dict()
    if 'preserve_attr' in prepare_custom_config_dict:
        for submodule_name in prepare_custom_config_dict['preserve_attr']:
            cur_module = model
            if submodule_name != "":
                cur_module = getattr(model, submodule_name)
            preserve_attr_list = prepare_custom_config_dict['preserve_attr'][submodule_name]
            preserve_attr_dict[submodule_name] = {}
            for attr in preserve_attr_list:
                preserve_attr_dict[submodule_name][attr] = getattr(cur_module, attr)
    # Symbolic trace
    concrete_args = prepare_custom_config_dict.get('concrete_args', None)
    graph_module = symbolic_trace(model, concrete_args=concrete_args)
    # Model fusion.
    extra_fuse_dict = prepare_custom_config_dict.get('extra_fuse_dict', {})
    extra_fuse_dict.update(fuse_custom_config_dict)
    # Prepare
    import mqbench.custom_quantizer  # noqa: F401
    extra_quantizer_dict = prepare_custom_config_dict.get('extra_quantizer_dict', {})
    quantizer = DEFAULT_MODEL_QUANTIZER[deploy_backend](extra_quantizer_dict, extra_fuse_dict)
    prepared = quantizer.prepare(graph_module, qconfig)
    # Restore attr.
    if 'preserve_attr' in prepare_custom_config_dict:
        for submodule_name in prepare_custom_config_dict['preserve_attr']:
            cur_module = prepared
            if submodule_name != "":
                cur_module = getattr(prepared, submodule_name)
            preserve_attr_list = prepare_custom_config_dict['preserve_attr'][submodule_name]
            for attr in preserve_attr_list:
                logger.info("Preserve attr: {}.{}".format(submodule_name, attr))
                setattr(cur_module, attr, preserve_attr_dict[submodule_name][attr])
    return prepared
```
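To make the preserve_attr behaviour concrete, here is a minimal usage sketch. ToyNet and its threshold attribute are made up for illustration, the import path just follows the file path quoted above, and BackendType.Tensorrt is assumed to be one of the supported backends.

```python
import torch
import torch.nn as nn
from mqbench.prepare_by_platform import prepare_qat_fx_by_platform, BackendType

class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.threshold = 0.5          # not used in forward(), so symbolic_trace would drop it

    def forward(self, x):
        return self.conv(x)

model = ToyNet().train()              # prepare_qat_fx_by_platform asserts model.training
prepared = prepare_qat_fx_by_platform(
    model,
    BackendType.Tensorrt,                          # assumed backend; any supported one works
    {'preserve_attr': {'': ['threshold']}})        # keep model.threshold on the prepared GraphModule
print(prepared.threshold)                          # still accessible after prepare
```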
- At Line 240 of the original file there is _swap_ff_with_fxff(model), which is imported from torch.quantization.quantize_fx. Its job is to replace the FloatFunctional modules in the model with FXFloatFunctional, but in my breakpoint tests during both training and evaluation the list of modules to be swapped was always empty. (A rough sketch of the swap follows this pair of notes.)
- Around Line 242 it looks like you can pass your own quantization parameters in, so there is no need to build a separate quantizer!
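For reference, a rough paraphrase of what _swap_ff_with_fxff does, written from memory rather than copied from the torch source: walk the module tree and swap every FloatFunctional for an FXFloatFunctional so its add/mul/cat calls become traceable.

```python
import torch.nn as nn
from torch.nn.quantized import FloatFunctional, FXFloatFunctional

def swap_ff_with_fxff(model: nn.Module) -> None:
    """Recursively replace FloatFunctional children with FXFloatFunctional (sketch)."""
    for name, child in model.named_children():
        if isinstance(child, FloatFunctional):
            setattr(model, name, FXFloatFunctional())
        else:
            swap_ff_with_fxff(child)   # recurse into submodules
```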
- The first places that need changing are Line 202 and Line 185, which already pin down the quantization bit-width; there is really no need to fix it this early.
- Line 245: I haven't figured out what preserve_attr_dict is actually useful for yet.
- Line 257: symbolic_trace. I found a nicely translated doc for it.
- Line 259: the model-fusion part. My current understanding is that this is mainly about operator fusion, e.g. Linear + BN. Jumping to the corresponding part of MQBench, the default fusion dict contains two variants (Linear first or Linear last): one ships with PyTorch and one is MQBench's own addition. (A sketch of the fusion idea follows this item.)
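To remember what Linear + BN fusion actually computes, a hedged sketch of the inference-time folding. The QAT fuser keeps a fused intrinsic module instead of folding BN away, but the algebra is the same; fuse_linear_bn is my own helper name, not MQBench's code.

```python
import torch
import torch.nn as nn

def fuse_linear_bn(linear: nn.Linear, bn: nn.BatchNorm1d) -> nn.Linear:
    """Fold y = BN(Wx + b) into one Linear: W' = s*W, b' = s*(b - mean) + beta."""
    b = linear.bias if linear.bias is not None else torch.zeros(linear.out_features)
    s = bn.weight / torch.sqrt(bn.running_var + bn.eps)    # per-channel scale gamma/sqrt(var+eps)
    fused = nn.Linear(linear.in_features, linear.out_features)
    fused.weight.data = linear.weight * s[:, None]         # scale each output row of W
    fused.bias.data = (b - bn.running_mean) * s + bn.bias   # fold the mean/shift into the bias
    return fused
```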
- Line 264: quantizer = DEFAULT_MODEL_QUANTIZER[deploy_backend](extra_quantizer_dict, extra_fuse_dict) has a lot going on under the hood!
  - After the backend key is looked up (hashed), the matching quantizer is found; here it jumps straight to Line 345 of custom_quantizer.py, which is TRTModelQuantizer (the TensorRT Model Quantizer). The registry pattern is sketched below.
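My mental model of that lookup is simply a registry dict keyed by backend. A sketch of the dispatch pattern; the decorator name register_model_quantizer and the class names in the comment are written from memory, so treat the exact spellings as assumptions.

```python
# Registry: BackendType -> quantizer class; looked up inside prepare_qat_fx_by_platform.
DEFAULT_MODEL_QUANTIZER = {}

def register_model_quantizer(backend):
    def insert(quantizer_cls):
        DEFAULT_MODEL_QUANTIZER[backend] = quantizer_cls
        return quantizer_cls
    return insert

# In custom_quantizer.py each backend-specific quantizer registers itself, roughly:
# @register_model_quantizer(BackendType.Tensorrt)
# class TRTModelQuantizer(ModelQuantizer):
#     ...
```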
- Line 265: prepared = quantizer.prepare(graph_module, qconfig) is very important! This is where the quantization points are inserted and the ops are swapped!
- While tracing the chain quantizer.prepare -> model = self._weight_quant(model, qconfig) -> propagate_qconfig_(model, flattened_qconfig_dict) -> _propagate_qconfig_helper(module, qconfig_dict, allow_list) -> allow_list = get_default_qconfig_propagation_list() -> torch/quantization/quantization_mappings.py, I found that DEFAULT_QAT_MODULE_MAPPINGS has no entry for transposed convolution! I will probably have to write one myself!
- In model = self._weight_quant(model, qconfig) -> propagate_qconfig_(model, flattened_qconfig_dict), a qconfig is attached to every submodule iteratively; this looks like a good place to start!
- It should be possible to add my own QAT module under MQBench/mqbench/nn/intrinsic/? model = self._weight_quant(model, qconfig) -> self._qat_swap_modules(model, self._additional_qat_module_mapping) replaces ordinary modules with their QAT counterparts.
  - Line 217 is where the user's extra QAT module mapping list gets inserted, so I don't have to write that part myself either! (See the mapping sketch below.)
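To see what such a mapping looks like, a small check against PyTorch's default QAT mapping; the ConvTranspose entry in the trailing comment is hypothetical, which is exactly the gap noted above.

```python
import torch.nn as nn
from torch.quantization.quantization_mappings import get_default_qat_module_mappings

# Float module class -> QAT module class; nn.ConvTranspose2d is absent from this dict.
print(get_default_qat_module_mappings())

# An extra mapping supplied to the quantizer would pair the float module with a QAT
# version that fake-quantizes its weight in forward(), e.g. (hypothetical class name):
# additional_qat_module_mapping = {nn.ConvTranspose2d: QATConvTranspose2d}
```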
- Line 59: model = self._insert_fake_quantize_for_act_quant(model, qconfig) inserts the activation quantization points.
  - For ConvTranspose there is a mismatch in its quantization: the activations get quantized, but the weights do not.
- The placeholder x was not quantized in the first round because it is not the output of any layer; the quantization that does happen is applied to it as the input of conv1_1.
- Inside _find_act_quants there is a check, self._is_implicit_merge(modules, (node, _node)), which decides whether the current node is an input of an add/mul; the details are somewhat involved and I haven't read them yet.
  - No quantization point is inserted between Conv and LeakyReLU, but one is inserted between LeakyReLU and MaxPool, and another one after MaxPool (illustrated below).
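Putting those observations together on a made-up conv -> leaky_relu -> max_pool chain (the node names are invented; the *_post_act_fake_quantizer naming follows the example at Line 75 below):

```python
# x  -> x_post_act_fake_quantizer           # placeholder x, quantized as the input of conv1_1
#    -> conv1_1 -> lrelu_1                  # no quantization point between Conv and LeakyReLU
#    -> lrelu_1_post_act_fake_quantizer     # inserted between LeakyReLU and MaxPool
#    -> maxpool_1
#    -> maxpool_1_post_act_fake_quantizer   # and another one after MaxPool
```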
- The function itself is responsible for finding the anchors at which quantization points will be inserted; node here refers to the input of the corresponding layer.
- However, what actually gets stored (into node_need_to_quantize_output) is the previous node. For example, x is not in self.module_type_to_quant_input / self.function_type_to_quant_input, but because it is the input of conv1_1 it has to be added to the list node_need_to_quantize_output.
  - The search for quantization points is exhaustive; returning a set at the end strips out the duplicates. A simplified sketch of this anchor-finding loop follows.
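A much-simplified sketch of that loop as I understand it. The helper is_quantable_input and the iteration details are my own assumptions; the real _find_act_quants also handles call_function nodes and the implicit add/mul merge mentioned above.

```python
def find_act_quant_anchors(graph, is_quantable_input):
    """Collect producer nodes whose outputs feed modules that need quantized inputs."""
    node_need_to_quantize_output = []
    for node in graph.nodes:
        for _node in node.users:                           # _node consumes node's output
            if is_quantable_input(_node):                  # e.g. _node calls a conv module
                node_need_to_quantize_output.append(node)  # store the *previous* node
    return set(node_need_to_quantize_output)               # exhaustive search; set removes duplicates
```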
- Line 75: setattr(model, quantizer_name, fake_quantizer) just creates an attribute named quantizer_name on the GraphModule (e.g. lrelu_4_post_act_fake_quantizer); the with scope that follows wires this attribute up to the corresponding node, which in essence inserts a new node right after the current one.
- Line 80: _node.args = self._fix_succ_recursivly(_node.args, node, inserted_node) rewrites the inputs/outputs of the affected nodes.
- Line 82: model.recompile() recompiles the whole graph from the GraphModule's graph attribute; it looks like it regenerates the code from the graph and then does one forward pass.
- Line 83: model.graph.lint() checks that the graph was generated correctly.
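The four steps above, reproduced as one self-contained helper using plain torch.fx calls. This is my own simplification: I use Node.replace_input_with where MQBench walks the args with _fix_succ_recursivly, and FakeQuantize() with defaults stands in for the backend-specific fake quantizer.

```python
from torch.fx import GraphModule, Node
from torch.quantization import FakeQuantize

def insert_post_act_fake_quant(model: GraphModule, node: Node) -> None:
    quantizer_name = node.name + '_post_act_fake_quantizer'
    setattr(model, quantizer_name, FakeQuantize())        # Line 75: new attribute on the GraphModule
    users = list(node.users)                              # snapshot successors before inserting
    with model.graph.inserting_after(node):               # the "with scope" mentioned above
        inserted_node = model.graph.create_node(
            'call_module', quantizer_name, (node,), {})   # new node calling the fake quantizer
    for user in users:
        user.replace_input_with(node, inserted_node)      # Line 80: rewire successors to the new node
    model.recompile()                                      # Line 82: regenerate forward() from the graph
    model.graph.lint()                                     # Line 83: verify the rewritten graph
```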