欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch中fuse_modules源碼解讀

 更新時間:2023年05月18日 14:10:48   作者:weixin_45919003  
這篇文章主要介紹了pytorch中fuse_modules,fuse_known_modules將給定的模塊列表mod_list中的一些常見模塊進行融合,返回融合后的模塊列表,本文通過實例代碼詳細講解,需要的朋友可以參考下

1. 官方代碼

FUSE_MODULES
TORCH.AO.QUANTIZATION.FUSE_MODULES的源代碼

2. fuse_modules源碼解讀

僅融合以下序列:

  • conv, bn
  • conv, bn, relu
  • conv, relu
  • linear, relu
  • bn, relu

網(wǎng)絡中所有其他序列保持不變,對于上述序列,用融合的模塊替換列表中的第一項,用identity替換其余模塊。

fuse_modules

def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
  • model:要進行操作的模型名稱
  • modules_to_fuse:要融合的模塊名稱的列表。如果只有一個要融合的模塊列表,可以是一個字符串列表,如:[‘conv1’, ‘bn1’, ‘relu’]
  • inplace:bool類型參數(shù),默認為false。融合發(fā)生在模型上,默認會返回一個新模型
  • fuser_func:接收模塊列表并輸出相同長度的融合模塊列表的函數(shù)。例如,fuser_func([convModule, BNModule]) 返回 [ConvBNModule, nn.Identity()] 。 默認為 fuse_known_modules
  • fuse_custom_config_dict :自定義配置,默認為none

fuse_known_modules

將給定的模塊列表mod_list中的一些常見模塊進行融合,返回融合后的模塊列表。融合后的模塊可以有效地減少模型計算量和內(nèi)存占用,從而提高模型的計算效率。

參數(shù)

  • mod_list:一個包含了一系列PyTorch模塊對象的列表,這些模塊可以是常見的卷積、線性、批歸一化等模塊。
  • is_qat:指定模型是否使用量化感知訓練(true使用,false不使用)
  • additional_fuser_method_mapping:一個可選的字典,用于指定額外的融合方法。字典的key是要融合的模塊類型,value是一個融合函數(shù),它將被用于融合指定類型的模塊。默認為None。
def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
    r"""Returns a list of modules that fuses the operations specified
     in the input module list.
    Fuses only the following sequence of modules:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, bn
    linear, relu
    For these sequences, the first element in the output module list performs
    the fused operation. The rest of the elements are set to nn.Identity()
    """
    types = tuple(type_before_parametrizations(m) for m in mod_list)
    fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
    if fuser_method is None:
        raise NotImplementedError("Cannot fuse modules: {}".format(types))
    new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
    fused = fuser_method(is_qat, *mod_list)
    # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
    # Move pre forward hooks of the base module to resulting fused module
    for handle_id, pre_hook_fn in mod_list[0]._forward_pre_hooks.items():
        fused.register_forward_pre_hook(pre_hook_fn)
        del mod_list[0]._forward_pre_hooks[handle_id]
    # Move post forward hooks of the last module to resulting fused module
    for handle_id, hook_fn in mod_list[-1]._forward_hooks.items():
        fused.register_forward_hook(hook_fn)
        del mod_list[-1]._forward_hooks[handle_id]
    new_mod[0] = fused
    for i in range(1, len(mod_list)):
        identity = nn.Identity()
        identity.training = mod_list[0].training
        new_mod[i] = identity
    return new_mod
  • 在融合前,首先獲取mod_list中每個模塊的類型,并將它們作為一個元組存儲在types變量中。這個元組中的類型用于選擇要使用的模塊融合方法。在默認情況下,該函數(shù)支持一些特定的模塊序列進行融合。如果輸入模塊序列不符合這些支持的模式,則函數(shù)會嘗試使用 additional_fuser_method_mapping 中定義的自定義融合函數(shù)fuser_method。
  • 融合方法fuser_method :使用get_fuser_method() 函數(shù)根據(jù)types來選擇一個合適的融合函數(shù)。
  • – 在 get_fuser_method函數(shù)中調用了字典DEFAULT_OP_LIST_TO_FUSER_METHOD(定義了元組和融合函數(shù)之間的映射關系)。下面僅展示部分2d模塊融合
DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv2d, nn.ReLU): sequential_wrapper2(nni.ConvReLU2d),
    (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
    (nn.Linear, nn.ReLU): sequential_wrapper2(nni.LinearReLU),
    (nn.BatchNorm2d, nn.ReLU): sequential_wrapper2(nni.BNReLU2d),
}
  • 如果在特定模塊序列的additional_fuser_method_mapping中提供了自定義fuser函數(shù),則將使用該函數(shù)來代替默認的fuser函數(shù)。如果找不到合適的fuser函數(shù),該函數(shù)將引發(fā)NotImplementedError
  • 定義new_mod :使用 [None] * len(mod_list)創(chuàng)建一個長度為len(mod_list)的列表,這個列表中,每個元素都是一個nn.Module類型的可選對象,初始值為None。
  • 融合后的新模塊fused:使用fuser_method調用對應的融合函數(shù),如 fuse_conv_bn(is_qat, conv, bn)得到一個模塊融合后的新的模塊(ConvBn2d)。該模塊包含了卷積層和BN層的參數(shù),并將其組合成一個新的運算,該融合模塊的名稱默認為ConvBn2d、ConvBn1d或ConvBn3d。fuse_conv_bn函數(shù)在后面進行介紹。
  • 融合后,第一個for循環(huán)遍歷 mod_list列表中第一個模塊(mod_list[0])的handle_id(前向預處理鉤子函數(shù)的ID)和hook_fn(前向預處理鉤子函數(shù),在模塊前向傳遞時會被自動調用,用于執(zhí)行某些操作,如記錄中間結果、打印日志等。)。
  • – 然后,將這些鉤子函數(shù)注冊到fused模塊中,使其能夠在后續(xù)計算中被調用。
  • – 接著,從mod_list[0]._forward_pre_hooks字典中刪除這些鉤子函數(shù),避免這些鉤子函數(shù)被重復調用。
  • 第一個for循環(huán)的作用是將mod_list列表中第一個模塊的前向預處理鉤子函數(shù)從原始模塊對象中轉移到融合模塊對象中,以確保在使用融合模塊進行前向傳遞時,所有需要的操作都能夠被執(zhí)行。
  • 第二個for循環(huán)將mod_list列表中最后一個模塊的前向鉤子函數(shù)注冊到fused模塊中,并從原始模塊對象的鉤子字典中刪除這些鉤子函數(shù)。
  • 與前向預處理鉤子函數(shù)不同,前向鉤子函數(shù)是在模塊的前向傳遞過程中執(zhí)行的,通常用于在模塊輸出計算完成后執(zhí)行某些操作,如統(tǒng)計模型輸出分布、進行可視化等。
  • 最后,將融合好的fused模塊賦給前面定義的new_mod 列表的第一個元素,最后使用for循環(huán)補充identity()到new_mod列表,使其長度和原始模塊長度一致。

fuse_conv_bn

將給定的conv和bn模塊融合并返回融合后的模塊。

在此函數(shù)中構建了一個fused_module_class_map字典,用于指定模塊類型與對應的融合模塊類型之間的映射關系。

如果其類型在fused_module_class_map字典中有對應的融合模塊類型,則將這些模塊融合為一個新的模塊(ConvBn2d),如果沒有對應的融合模塊類型,則不對其進行融合處理。

def fuse_conv_bn(is_qat, conv, bn):
    assert(conv.training == bn.training),\
        "Conv and BN both must be in the same mode (train or eval)."
    fused_module_class_map = {
        nn.Conv1d: nni.ConvBn1d,
        nn.Conv2d: nni.ConvBn2d,
        nn.Conv3d: nni.ConvBn3d,
    }
    if is_qat:
        assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
        assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
        fused_module_class = fused_module_class_map.get((type(conv)), None)
        if fused_module_class is not None:
            return fused_module_class(conv, bn)
        else:
            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn)))
    else:
        return nn.utils.fuse_conv_bn_eval(conv, bn)

返回調用的 fuse_conv_bn_eval(conv, bn) 函數(shù)如下

返回一個新的融合模塊,該模塊包含了卷積層和BN層的參數(shù),并將其組合成一個新的運算。

def fuse_conv_bn_eval(conv, bn, transpose=False):
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)
    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias, transpose)
    return fused_conv

3. fuse_modules實際測試

3.1 modules_to_fuse參數(shù)的使用方法

1. 此參數(shù)的列表可以包含多個需要融合的組合,子模塊列表也可以,使用方法一

方法一:

modules_to_fuse = [ [‘conv1’, ‘bn1’, ‘relu1’], [‘submodule.conv’, ‘submodule.relu’]]

融合ResNet18中l(wèi)ayer1的conv和bn層如下:

print('\n Before fusion \n\n', r18_o.layer1)
r18_o.eval()
r18 = torch.quantization.fuse_modules(
    r18_o,
    [['conv1', 'bn1', 'relu'],
     ['layer1.0.conv1', 'layer1.0.bn1'], # , 'layer1.0.relu'],
     ['layer1.0.conv2', 'layer1.0.bn2'],
     ['layer1.1.conv1', 'layer1.1.bn1'], #, 'layer1.1.relu'],
     ['layer1.1.conv2', 'layer1.1.bn2']]
)
print('\n After fusion\n\n', r18.layer1)

結果:

ResNet18融合前:(僅顯示ResNet18中l(wèi)ayer1的網(wǎng)絡結構)

ResNet18融合后

此融合只將Conv2d和BN層進行融合,從上面對比可以看到融合后的 (bn) 變成了 identity(),(conv) 中的Conv2d是原本Conv2d和BN融合的。

2. 如果要融合的module被Sequential封裝了,可使用方法二

方法二:

torch.quantization.fuse_modules(m, [‘0’, ‘1’, ‘2’], inplace=True)

1. 使用方法二對ResNet18中模塊進行融合操作,融合代碼如下:

def fuse_model(self):
    for m in self.modules():
        if type(m) == BasicBlock:
            torch.quantization.fuse_modules(m, [['conv1', 'bn1', 'relu'], ['conv2', 'bn2']], inplace=True)

此處代碼是仿pytorch官方寫MobileNetV2模塊融合,這部分代碼寫在 class ResNet(nn.Module) 中,后面融合直接使用model.fuse_model(),得到的方法二融合ResNet18結果如下:

此處是分別對(conv2d、bn、relu)和(conv2d、bn)進行融合融合

2. 使用方法二對MobileNetv2中模塊進行融合操作

def fuse_model(self):
    for m in self.modules():
        if type(m) == ConvBNReLU:
            torch.quantization.fuse_modacules(m, ['0', '1', '2'], inplace=True)
        if type(m) == InvertedResidual:
            for idx in range(len(m.conv)):
                if type(m.conv[idx]) == nn.Conv2d:
                    torch.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)

結果

MobileNetv2融合前(下面結果展示的是第一個殘差模塊,因此沒有第一個1x1的卷積)

MobileNetv2融合后

從此對比可以看到,融合前的conv2d、bn、relu融合成了ConvRelu2d(Conv2d,ReLU),這里面的Conv2d是融合前的Conv2d和BN融合的。

到此這篇關于pytorch中fuse_modules的文章就介紹到這了,更多相關pytorch中fuse_modules內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!

相關文章

  • python計算波峰波谷值的方法(極值點)

    python計算波峰波谷值的方法(極值點)

    這篇文章主要介紹了python求極值點(波峰波谷)求極值點主要用到了scipy庫,本文通過實例代碼給大家介紹的非常詳細,具有一定的參考借鑒價值,需要的朋友可以參考下
    2020-02-02
  • Python全局變量與局部變量區(qū)別及用法分析

    Python全局變量與局部變量區(qū)別及用法分析

    這篇文章主要介紹了Python全局變量與局部變量區(qū)別及用法,結合實例形式分析了Python全局變量與局部變量的定義、常見用法、區(qū)別及相關操作注意事項,需要的朋友可以參考下
    2018-09-09
  • Python(TensorFlow框架)實現(xiàn)手寫數(shù)字識別系統(tǒng)的方法

    Python(TensorFlow框架)實現(xiàn)手寫數(shù)字識別系統(tǒng)的方法

    這篇文章主要介紹了Python(TensorFlow框架)實現(xiàn)手寫數(shù)字識別系統(tǒng)的方法。小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2018-05-05
  • python3.6根據(jù)m3u8下載mp4視頻

    python3.6根據(jù)m3u8下載mp4視頻

    這篇文章主要為大家詳細介紹了python3.6根據(jù)m3u8下載mp4視頻,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2019-06-06
  • 5 分鐘讀懂Python 中的 Hook 鉤子函數(shù)

    5 分鐘讀懂Python 中的 Hook 鉤子函數(shù)

    這篇文章主要介紹了5 分鐘掌握 Python 中的 Hook 鉤子函數(shù),本文通過實例代碼給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下
    2020-12-12
  • 解決Python列表字符不區(qū)分大小寫的問題

    解決Python列表字符不區(qū)分大小寫的問題

    今天小編就為大家分享一篇解決Python列表字符不區(qū)分大小寫的問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-12-12
  • python中的zip模塊

    python中的zip模塊

    這篇文章主要介紹了zip文件格式是通用的文檔壓縮標準,在ziplib模塊中,使用ZipFile類來操作zip文件,感興趣的朋友參考如下
    2021-08-08
  • 學習python (1)

    學習python (1)

    學習python (1)...
    2006-10-10
  • python Dataframe字符串合并的操作方法

    python Dataframe字符串合并的操作方法

    Dataframe的字符串合并包括2種場景,1.合并df中其中幾列字符串;2.將df中的字符串與外部字符串合并,本文主要介紹在Python下對Dataframe進行字符串合并操作的方法,感興趣的朋友跟隨小編一起看看吧
    2024-06-06
  • python_opencv用線段畫封閉矩形的實例

    python_opencv用線段畫封閉矩形的實例

    今天小編就為大家分享一篇python_opencv用線段畫封閉矩形的實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2018-12-12

最新評論