2021/10/2 Learning Log

2021/10/2 - 2021/10/4

Reading MQBench

MQBench/mqbench/prepare_by_platform.py/prepare_qat_fx_by_platform

The main body of the code is as follows:

def prepare_qat_fx_by_platform(
        model: torch.nn.Module,
        deploy_backend: BackendType,
        prepare_custom_config_dict: Dict[str, Any] = {}):
    """
    Args:
        model (torch.nn.Module):
        deploy_backend (BackendType):
    >>> prepare_custom_config_dict : {
            extra_qconfig_dict : Dict, find explanations in get_qconfig_by_platform,
            extra_quantizer_dict: Extra params for quantizer.
            preserve_attr: Dict, specify attributes of the model which should be preserved
                after prepare, since symbolic_trace only stores attributes which are
                used in forward. If model.func1 and model.backbone.func2 should be preserved,
                {"": ["func1"], "backbone": ["func2"]} should work.
            Attr below is inherited from Pytorch.
            concrete_args: Specify input for model tracing.
            extra_fuse_dict: Specify extra fusing patterns and functions.
        }
    """
    assert model.training, 'prepare_qat_fx_custom only works for models in  ' + \
        'train mode'

    logger.info("Quantize model using {} scheme.".format(deploy_backend))

    _swap_ff_with_fxff(model)
    # Get Qconfig
    extra_qconfig_dict = prepare_custom_config_dict.get('extra_qconfig_dict', {})
    qconfig = get_qconfig_by_platform(deploy_backend, extra_qconfig_dict)
    # Preserve attr.
    preserve_attr_dict = dict()
    if 'preserve_attr' in prepare_custom_config_dict:
        for submodule_name in prepare_custom_config_dict['preserve_attr']:
            cur_module = model
            if submodule_name != "":
                cur_module = getattr(model, submodule_name)
            preserve_attr_list = prepare_custom_config_dict['preserve_attr'][submodule_name]
            preserve_attr_dict[submodule_name] = {}
            for attr in preserve_attr_list:
                preserve_attr_dict[submodule_name][attr] = getattr(cur_module, attr)
    # Symbolic trace
    concrete_args = prepare_custom_config_dict.get('concrete_args', None)
    graph_module = symbolic_trace(model, concrete_args=concrete_args)
    # Model fusion.
    extra_fuse_dict = prepare_custom_config_dict.get('extra_fuse_dict', {})
    extra_fuse_dict.update(fuse_custom_config_dict)
    # Prepare
    import mqbench.custom_quantizer  # noqa: F401
    extra_quantizer_dict = prepare_custom_config_dict.get('extra_quantizer_dict', {})
    quantizer = DEFAULT_MODEL_QUANTIZER[deploy_backend](extra_quantizer_dict, extra_fuse_dict)
    prepared = quantizer.prepare(graph_module, qconfig)
    # Restore attr.
    if 'preserve_attr' in prepare_custom_config_dict:
        for submodule_name in prepare_custom_config_dict['preserve_attr']:
            cur_module = prepared
            if submodule_name != "":
                cur_module = getattr(prepared, submodule_name)
            preserve_attr_list = prepare_custom_config_dict['preserve_attr'][submodule_name]
            for attr in preserve_attr_list:
                logger.info("Preserve attr: {}.{}".format(submodule_name, attr))
                setattr(cur_module, attr, preserve_attr_dict[submodule_name][attr])
    return prepared  
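
As a quick reference for myself, here is a minimal usage sketch (not taken from the MQBench docs; the import path, the ToyNet model and the get_input_size helper are made up for illustration, and I am assuming BackendType.Tensorrt is a valid backend since the notes below end up in TRTModelQuantizer):

import torch
import torch.nn as nn
from mqbench.prepare_by_platform import prepare_qat_fx_by_platform, BackendType

class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
        self.head = nn.Linear(8, 10)
        self.input_size = (224, 224)          # plain attribute, never used in forward

    def forward(self, x):
        x = self.backbone(x)
        return self.head(x.mean(dim=(2, 3)))

    def get_input_size(self):                 # helper that symbolic_trace would drop
        return self.input_size

model = ToyNet().train()                      # prepare asserts model.training
prepare_custom_config_dict = {
    # keep ToyNet.get_input_size reachable on the prepared GraphModule
    'preserve_attr': {'': ['get_input_size']},
    # 'extra_qconfig_dict', 'extra_quantizer_dict', 'concrete_args' and
    # 'extra_fuse_dict' are optional; see get_qconfig_by_platform for the keys
}
prepared = prepare_qat_fx_by_platform(model, BackendType.Tensorrt,
                                      prepare_custom_config_dict)
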
  • At Line 240 of the original file there is a call to _swap_ff_with_fxff(model), imported from torch.quantization.quantize_fx. Its job is to replace the FloatFunctional modules inside the model with FXFloatFunctional, but in my breakpoint tests during both training and evaluation the list of modules to swap was always empty (see the _swap_ff_with_fxff sketch after this list).
  • Around Line 242 it looks like you can pass your own quantization parameters in, so there is no need to build a separate quantizer!
  • The first place to change is Line 202 and Line 185, where the quantization bit width is already fixed; there is really no need to pin it down that early.
  • Line 245: I have not yet figured out what preserve_attr_dict is for.
  • Line 257 symbolic_trace: found a good translation of the docs.
  • Line 259, the model fusion part: my current understanding is that this handles operator fusion, e.g. Linear + BN. Jumping to the corresponding MQBench code, the default dictionary contains two variants (Linear before or after the BN), one provided by PyTorch and one added by MQBench itself (a Linear + BN fusion sketch follows this list).
  • Line 264, quantizer = DEFAULT_MODEL_QUANTIZER[deploy_backend](extra_quantizer_dict, extra_fuse_dict), is where a lot happens!
    • The backend key is hashed to look up the corresponding quantizer; here it jumps straight to Line 345 of custom_quantizer.py, which is TRTModelQuantizer (the TensorRT model quantizer).
  • Line 265, prepared = quantizer.prepare(graph_module, qconfig), is very important! This is where the quantization points get inserted and the ops get replaced!
    • Tracing quantizer.prepare -> model = self._weight_quant(model, qconfig) -> propagate_qconfig_(model, flattened_qconfig_dict) -> _propagate_qconfig_helper(module, qconfig_dict, allow_list) -> allow_list = get_default_qconfig_propagation_list() -> torch/quantization/quantization_mappings.py, I found that DEFAULT_QAT_MODULE_MAPPINGS has no entry for transposed convolution! I will probably have to write one myself (see the QAT ConvTranspose sketch after this list)!
    • In model = self._weight_quant(model, qconfig) -> propagate_qconfig_(model, flattened_qconfig_dict), a qconfig is attached to every submodule iteratively; this should be a good place to start!
    • It should be possible to add my own QAT module under MQBench/mqbench/nn/intrinsic/?
    • model = self._weight_quant(model, qconfig) -> self._qat_swap_modules(model, self._additional_qat_module_mapping) replaces the ordinary modules with QAT modules.
      • Line 217 is where your own extra QAT module mapping list can be plugged in, so the swapping itself does not have to be rewritten!
  • Line 59, model = self._insert_fake_quantize_for_act_quant(model, qconfig), inserts the activation quantization points.
    • For ConvTranspose there is a mismatch in its quantization: the activations are quantized but the parameters are not.
    • The placeholder x is not quantized in the first round because it is not the output of any layer; the quantization is applied to the input of conv1_1 instead.
    • Inside _find_act_quants there is a check, self._is_implicit_merge(modules, (node, _node)), which tests whether the current node is an input of an add/mul; the details are somewhat involved and I have not read them yet.
      • No quantization point is inserted between Conv and LeakyReLU, but one is inserted between LeakyReLU and MaxPool, and another one after the MaxPool.
      • The function itself exists to find the anchors for inserting quantization points; node refers to the input of the corresponding layer.
      • However, what actually gets stored (into node_need_to_quantize_output) is the previous node. For example, x is not in self.module_type_to_quant_input / self.function_type_to_quant_input, but as the input of conv1_1 it still has to be appended to the list node_need_to_quantize_output.
      • The search for quantization points is exhaustive; the set returned at the end strips out the duplicates (a rough paraphrase of this search is sketched after this list).
    • Line 75, setattr(model, quantizer_name, fake_quantizer), merely creates an attribute named quantizer_name on the GraphModule (e.g. lrelu_4_post_act_fake_quantizer); the with block that follows is what wires this attribute to the corresponding node, which in essence inserts a new node right after the current one (see the node-insertion sketch after this list).
    • Line 80, _node.args = self._fix_succ_recursivly(_node.args, node, inserted_node), rewrites the inputs/outputs of the affected nodes.
    • Line 82, model.recompile(), recompiles the whole graph from the GraphModule's graph attribute; it looks like it regenerates the code from the graph and then runs one forward pass.
    • Line 83, model.graph.lint(), checks that the graph was generated correctly.
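
For the _swap_ff_with_fxff note above: a minimal sketch of what that swap amounts to (my own paraphrase, not the torch source; swap_ff_with_fxff is a made-up name). It also explains why the to-swap list stays empty for a model that never uses FloatFunctional.

import torch.nn as nn
from torch.nn.quantized import FloatFunctional, FXFloatFunctional

def swap_ff_with_fxff(model: nn.Module) -> None:
    """Recursively replace FloatFunctional children with FXFloatFunctional."""
    to_swap = [name for name, child in model.named_children()
               if isinstance(child, FloatFunctional)]
    for name in to_swap:                      # empty if the model has no FloatFunctional
        setattr(model, name, FXFloatFunctional())
    for child in model.children():
        swap_ff_with_fxff(child)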
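
For the model fusion note: a sketch of a Linear + BN fuser in the spirit of PyTorch's fuser method mappings. The function name and the assumption that extra_fuse_dict can take a (pattern -> fuser function) entry are mine; MQBench's actual key format and pattern ordering may differ.

import torch
import torch.nn as nn

def fuse_linear_bn_eval(linear: nn.Linear, bn: nn.BatchNorm1d) -> nn.Linear:
    """Fold BatchNorm1d statistics into the preceding Linear (eval-time folding)."""
    fused = nn.Linear(linear.in_features, linear.out_features)
    with torch.no_grad():
        # y = gamma * (Wx + b - mean) / sqrt(var + eps) + beta
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        bias = linear.bias if linear.bias is not None else torch.zeros(linear.out_features)
        fused.weight = nn.Parameter(linear.weight * scale[:, None])
        fused.bias = nn.Parameter((bias - bn.running_mean) * scale + bn.bias)
    return fused

# illustrative (pattern -> fuser function) entry; key ordering is an assumption
extra_fuse_dict_entry = {(nn.Linear, nn.BatchNorm1d): fuse_linear_bn_eval}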
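
For the missing-TransposeConv note: a hypothetical minimal QAT wrapper for nn.ConvTranspose2d, plus the extra mapping that the additional QAT module mapping hook could consume. All names here are mine, not MQBench's, and the real interface may require more than this.

import torch
import torch.nn as nn

class QATConvTranspose2d(nn.ConvTranspose2d):
    """ConvTranspose2d that fake-quantizes its weight, mimicking torch.nn.qat modules."""
    def __init__(self, *args, qconfig=None, **kwargs):
        super().__init__(*args, **kwargs)
        assert qconfig is not None, 'a QAT module needs a qconfig'
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight()   # weight fake-quantize from the qconfig

    def forward(self, x):
        # same op as the float module, but with a fake-quantized weight
        return nn.functional.conv_transpose2d(
            x, self.weight_fake_quant(self.weight), self.bias, self.stride,
            self.padding, self.output_padding, self.groups, self.dilation)

    @classmethod
    def from_float(cls, mod):
        qat = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                  stride=mod.stride, padding=mod.padding,
                  output_padding=mod.output_padding, groups=mod.groups,
                  bias=mod.bias is not None, dilation=mod.dilation,
                  qconfig=mod.qconfig)
        qat.weight = mod.weight
        qat.bias = mod.bias
        return qat

# extra mapping so the QAT swap replaces the float module with the QAT one
additional_qat_module_mapping = {nn.ConvTranspose2d: QATConvTranspose2d}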
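
For the _find_act_quants notes: a rough paraphrase of the search as I understand it. The names find_act_quant_nodes, MODULE_TYPES_TO_QUANT_INPUT and FUNCTION_TYPES_TO_QUANT_INPUT are mine, the type tuples are only illustrative, and the implicit add/mul merge check is omitted.

import operator
import torch
import torch.fx as fx

MODULE_TYPES_TO_QUANT_INPUT = (torch.nn.Conv2d, torch.nn.Linear)   # illustrative
FUNCTION_TYPES_TO_QUANT_INPUT = (torch.add, operator.add)          # illustrative

def find_act_quant_nodes(gm: fx.GraphModule):
    """Collect the nodes whose output needs a fake-quantize node after it."""
    modules = dict(gm.named_modules())
    need_quant_output = []
    for node in gm.graph.nodes:
        consumes_quant_input = (
            (node.op == 'call_module'
             and isinstance(modules[node.target], MODULE_TYPES_TO_QUANT_INPUT))
            or (node.op == 'call_function'
                and node.target in FUNCTION_TYPES_TO_QUANT_INPUT))
        if not consumes_quant_input:
            continue
        for prev in node.all_input_nodes:
            # e.g. the placeholder x is not itself a quantizable type, but as the
            # input of conv1_1 its output still gets recorded here
            need_quant_output.append(prev)
    return set(need_quant_output)              # exhaustive scan, set drops duplicates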
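
For the notes on Lines 75, 80, 82 and 83: a sketch of the insertion step using plain torch.fx calls. The helper name is mine, and I use Node.replace_input_with instead of MQBench's _fix_succ_recursivly; otherwise it follows the setattr -> create_node -> rewrite args -> recompile -> lint sequence described above.

import torch
import torch.fx as fx
from torch.quantization import FakeQuantize

def insert_post_act_fake_quant(gm: fx.GraphModule, node: fx.Node) -> fx.Node:
    fake_quantizer = FakeQuantize()
    quantizer_name = '{}_post_act_fake_quantizer'.format(node.name)
    setattr(gm, quantizer_name, fake_quantizer)    # only an attribute on the GraphModule so far
    with gm.graph.inserting_after(node):
        # wire the attribute into the graph: a call_module node right after `node`
        inserted_node = gm.graph.create_node('call_module', quantizer_name, (node,), {})
    # every user of `node` (except the new quantizer itself) must now read from inserted_node
    for user in list(node.users):
        if user is not inserted_node:
            user.replace_input_with(node, inserted_node)
    gm.recompile()       # regenerate gm.code / forward from the modified graph
    gm.graph.lint()      # sanity-check that the rewritten graph is well formed
    return inserted_node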