CNN_model.py
2021/2/5
包括:
AuxiliaryHead(nn.Module)AuxiliaryHeadImageNet(nn.Module)CNNGenotypeModel(FinalModel)CNNGenotypeCell(nn.Module)
这几个类。其中bnn_model.py中的BNNGenotypeModel直接继承自CNNGenotypeModel,所以目前只看了CNNGenotypeModel。
New Funcs
Class
CNNGenotypeModel(FinalModel)
由NAME = "cnn_final_model"推测为最终训练的CNN模型的基类。
__init__方法
出现的属性有(不含父类中的属性):search_space:搜索空间- 似乎是一个tuple,里面包括
num_cell_groupsnum_init_nodescell_layoutreduce_cell_groupsnum_layers
等元素
- 似乎是一个tuple,里面包括
device:训练设备?genotypes:是evolution方法的?但是为什么这个类也叫genotypenum_classes:应该指的是图像分类的类别数init_channels:初始通道数layer_channels:输入一个tuple,目测是每层的channel数?stem_multiplierdropout_rate:训练的dropout率dropout_path_rate:路径dropout率auxiliary_headauxiliary_cfguse_stem:默认参数为"conv_bn_3x3",*可选的类型是(list, tuple)或者bool?stem_stridestem_affineno_fc:似乎是决定最后输出是否进行分类的bool值。cell_use_preprocesscell_pool_batchnormcell_group_kwargs:应该是自定义的cell布局(包括cell class和channel数?)cell_independent_conncell_use_shortcutcell_shortcut_op_typecell_preprocess_stridecell_preprocess_normalschedule_cfg
*认为初始化中对
genotypes的处理不是很重要,没看。
*Line134(if not self.use_stem:)至Line150(init_strides = [1] * self._num_init)似乎是根据self.use_stem进行"sub module"的初始化。
*Line161(for i_layer, stride in enumerate(strides):)至Line169(num_out_channels = num_channels)计算每层的输入输出channel数,至Line182(kwargs = {})是用cell_group_kwargs中的设置获取channel数。
*Line185(cell = CNNGenotypeCell(self.search_space,)至Line205(self.cells.append(cell))是根据config生成cell,并拼到一起。
*Line215(self.global_pooling = nn.AdaptiveAvgPool2d(1))至Line225(self.to(self.device))是设置global_pooling、drop_out、final_classification、device。-
set_hook方法和_hook_intermediate_feature方法*Line232(
def set_hook(self):)至Line247(pass)注册hook用于计算参数数量。hook的写法很重要,但是还没看。 -
forward方法、forward_one_step方法和forward_one_step_callback方法*应该是前向传播、单步前传、带回调的单步前传,暂时不关心。
问题集合
stem指的是什么?forward中也出现了
prev_num_channel是干嘛的?auxiliary_net是什么?对应Line207(if i_layer == (2 * self._num_layers) // 3 and self.auxiliary_head:)至Line213(prev_num_channels[-1], num_classes, **(auxiliary_cfg or {})))是做什么的?
TO-DO
- 中间出现的
CNNGenotypeCell类没看 ops.py文件中的ops.get_op没有看- pytorch里hook的写法
- 三个
forward也没看