亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

pytorch中fuse_modules源碼解讀

 更新時(shí)間:2023年05月18日 14:10:48   作者:weixin_45919003  
這篇文章主要介紹了pytorch中fuse_modules,fuse_known_modules將給定的模塊列表mod_list中的一些常見模塊進(jìn)行融合,返回融合后的模塊列表,本文通過實(shí)例代碼詳細(xì)講解,需要的朋友可以參考下

1. 官方代碼

FUSE_MODULES
TORCH.AO.QUANTIZATION.FUSE_MODULES的源代碼

2. fuse_modules源碼解讀

僅融合以下序列:

  • conv, bn
  • conv, bn, relu
  • conv, relu
  • linear, relu
  • bn, relu

網(wǎng)絡(luò)中所有其他序列保持不變,對(duì)于上述序列,用融合的模塊替換列表中的第一項(xiàng),用identity替換其余模塊。

fuse_modules

def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
  • model:要進(jìn)行操作的模型名稱
  • modules_to_fuse:要融合的模塊名稱的列表。如果只有一個(gè)要融合的模塊列表,可以是一個(gè)字符串列表,如:[‘conv1’, ‘bn1’, ‘relu’]
  • inplace:bool類型參數(shù),默認(rèn)為false。融合發(fā)生在模型上,默認(rèn)會(huì)返回一個(gè)新模型
  • fuser_func:接收模塊列表并輸出相同長(zhǎng)度的融合模塊列表的函數(shù)。例如,fuser_func([convModule, BNModule]) 返回 [ConvBNModule, nn.Identity()] 。 默認(rèn)為 fuse_known_modules
  • fuse_custom_config_dict :自定義配置,默認(rèn)為none

fuse_known_modules

將給定的模塊列表mod_list中的一些常見模塊進(jìn)行融合,返回融合后的模塊列表。融合后的模塊可以有效地減少模型計(jì)算量和內(nèi)存占用,從而提高模型的計(jì)算效率。

參數(shù)

  • mod_list:一個(gè)包含了一系列PyTorch模塊對(duì)象的列表,這些模塊可以是常見的卷積、線性、批歸一化等模塊。
  • is_qat:指定模型是否使用量化感知訓(xùn)練(true使用,false不使用)
  • additional_fuser_method_mapping:一個(gè)可選的字典,用于指定額外的融合方法。字典的key是要融合的模塊類型,value是一個(gè)融合函數(shù),它將被用于融合指定類型的模塊。默認(rèn)為None。
def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
    r"""Returns a list of modules that fuses the operations specified
     in the input module list.
    Fuses only the following sequence of modules:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, bn
    linear, relu
    For these sequences, the first element in the output module list performs
    the fused operation. The rest of the elements are set to nn.Identity()
    """
    types = tuple(type_before_parametrizations(m) for m in mod_list)
    fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
    if fuser_method is None:
        raise NotImplementedError("Cannot fuse modules: {}".format(types))
    new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
    fused = fuser_method(is_qat, *mod_list)
    # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
    # Move pre forward hooks of the base module to resulting fused module
    for handle_id, pre_hook_fn in mod_list[0]._forward_pre_hooks.items():
        fused.register_forward_pre_hook(pre_hook_fn)
        del mod_list[0]._forward_pre_hooks[handle_id]
    # Move post forward hooks of the last module to resulting fused module
    for handle_id, hook_fn in mod_list[-1]._forward_hooks.items():
        fused.register_forward_hook(hook_fn)
        del mod_list[-1]._forward_hooks[handle_id]
    new_mod[0] = fused
    for i in range(1, len(mod_list)):
        identity = nn.Identity()
        identity.training = mod_list[0].training
        new_mod[i] = identity
    return new_mod
  • 在融合前,首先獲取mod_list中每個(gè)模塊的類型,并將它們作為一個(gè)元組存儲(chǔ)在types變量中。這個(gè)元組中的類型用于選擇要使用的模塊融合方法。在默認(rèn)情況下,該函數(shù)支持一些特定的模塊序列進(jìn)行融合。如果輸入模塊序列不符合這些支持的模式,則函數(shù)會(huì)嘗試使用 additional_fuser_method_mapping 中定義的自定義融合函數(shù)fuser_method。
  • 融合方法fuser_method :使用get_fuser_method() 函數(shù)根據(jù)types來選擇一個(gè)合適的融合函數(shù)。
  • – 在 get_fuser_method函數(shù)中調(diào)用了字典DEFAULT_OP_LIST_TO_FUSER_METHOD(定義了元組和融合函數(shù)之間的映射關(guān)系)。下面僅展示部分2d模塊融合
DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
    (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
    (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu,
    (nn.Conv2d, nn.ReLU): sequential_wrapper2(nni.ConvReLU2d),
    (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
    (nn.Linear, nn.ReLU): sequential_wrapper2(nni.LinearReLU),
    (nn.BatchNorm2d, nn.ReLU): sequential_wrapper2(nni.BNReLU2d),
}
  • 如果在特定模塊序列的additional_fuser_method_mapping中提供了自定義fuser函數(shù),則將使用該函數(shù)來代替默認(rèn)的fuser函數(shù)。如果找不到合適的fuser函數(shù),該函數(shù)將引發(fā)NotImplementedError
  • 定義new_mod :使用 [None] * len(mod_list)創(chuàng)建一個(gè)長(zhǎng)度為len(mod_list)的列表,這個(gè)列表中,每個(gè)元素都是一個(gè)nn.Module類型的可選對(duì)象,初始值為None。
  • 融合后的新模塊fused:使用fuser_method調(diào)用對(duì)應(yīng)的融合函數(shù),如 fuse_conv_bn(is_qat, conv, bn)得到一個(gè)模塊融合后的新的模塊(ConvBn2d)。該模塊包含了卷積層和BN層的參數(shù),并將其組合成一個(gè)新的運(yùn)算,該融合模塊的名稱默認(rèn)為ConvBn2d、ConvBn1d或ConvBn3d。fuse_conv_bn函數(shù)在后面進(jìn)行介紹。
  • 融合后,第一個(gè)for循環(huán)遍歷 mod_list列表中第一個(gè)模塊(mod_list[0])的handle_id(前向預(yù)處理鉤子函數(shù)的ID)和hook_fn(前向預(yù)處理鉤子函數(shù),在模塊前向傳遞時(shí)會(huì)被自動(dòng)調(diào)用,用于執(zhí)行某些操作,如記錄中間結(jié)果、打印日志等。)。
  • – 然后,將這些鉤子函數(shù)注冊(cè)到fused模塊中,使其能夠在后續(xù)計(jì)算中被調(diào)用。
  • – 接著,從mod_list[0]._forward_pre_hooks字典中刪除這些鉤子函數(shù),避免這些鉤子函數(shù)被重復(fù)調(diào)用。
  • 第一個(gè)for循環(huán)的作用是將mod_list列表中第一個(gè)模塊的前向預(yù)處理鉤子函數(shù)從原始模塊對(duì)象中轉(zhuǎn)移到融合模塊對(duì)象中,以確保在使用融合模塊進(jìn)行前向傳遞時(shí),所有需要的操作都能夠被執(zhí)行。
  • 第二個(gè)for循環(huán)將mod_list列表中最后一個(gè)模塊的前向鉤子函數(shù)注冊(cè)到fused模塊中,并從原始模塊對(duì)象的鉤子字典中刪除這些鉤子函數(shù)。
  • 與前向預(yù)處理鉤子函數(shù)不同,前向鉤子函數(shù)是在模塊的前向傳遞過程中執(zhí)行的,通常用于在模塊輸出計(jì)算完成后執(zhí)行某些操作,如統(tǒng)計(jì)模型輸出分布、進(jìn)行可視化等。
  • 最后,將融合好的fused模塊賦給前面定義的new_mod 列表的第一個(gè)元素,最后使用for循環(huán)補(bǔ)充identity()到new_mod列表,使其長(zhǎng)度和原始模塊長(zhǎng)度一致。

fuse_conv_bn

將給定的conv和bn模塊融合并返回融合后的模塊。

在此函數(shù)中構(gòu)建了一個(gè)fused_module_class_map字典,用于指定模塊類型與對(duì)應(yīng)的融合模塊類型之間的映射關(guān)系。

如果其類型在fused_module_class_map字典中有對(duì)應(yīng)的融合模塊類型,則將這些模塊融合為一個(gè)新的模塊(ConvBn2d),如果沒有對(duì)應(yīng)的融合模塊類型,則不對(duì)其進(jìn)行融合處理。

def fuse_conv_bn(is_qat, conv, bn):
    assert(conv.training == bn.training),\
        "Conv and BN both must be in the same mode (train or eval)."
    fused_module_class_map = {
        nn.Conv1d: nni.ConvBn1d,
        nn.Conv2d: nni.ConvBn2d,
        nn.Conv3d: nni.ConvBn3d,
    }
    if is_qat:
        assert bn.num_features == conv.out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
        assert bn.affine, 'Only support fusing BatchNorm2d with affine set to True'
        assert bn.track_running_stats, 'Only support fusing BatchNorm2d with tracking_running_stats set to True'
        fused_module_class = fused_module_class_map.get((type(conv)), None)
        if fused_module_class is not None:
            return fused_module_class(conv, bn)
        else:
            raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn)))
    else:
        return nn.utils.fuse_conv_bn_eval(conv, bn)

返回調(diào)用的 fuse_conv_bn_eval(conv, bn) 函數(shù)如下

返回一個(gè)新的融合模塊,該模塊包含了卷積層和BN層的參數(shù),并將其組合成一個(gè)新的運(yùn)算。

def fuse_conv_bn_eval(conv, bn, transpose=False):
    assert(not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)
    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias, transpose)
    return fused_conv

3. fuse_modules實(shí)際測(cè)試

3.1 modules_to_fuse參數(shù)的使用方法

1. 此參數(shù)的列表可以包含多個(gè)需要融合的組合,子模塊列表也可以,使用方法一

方法一:

modules_to_fuse = [ [‘conv1’, ‘bn1’, ‘relu1’], [‘submodule.conv’, ‘submodule.relu’]]

融合ResNet18中l(wèi)ayer1的conv和bn層如下:

print('\n Before fusion \n\n', r18_o.layer1)
r18_o.eval()
r18 = torch.quantization.fuse_modules(
    r18_o,
    [['conv1', 'bn1', 'relu'],
     ['layer1.0.conv1', 'layer1.0.bn1'], # , 'layer1.0.relu'],
     ['layer1.0.conv2', 'layer1.0.bn2'],
     ['layer1.1.conv1', 'layer1.1.bn1'], #, 'layer1.1.relu'],
     ['layer1.1.conv2', 'layer1.1.bn2']]
)
print('\n After fusion\n\n', r18.layer1)

結(jié)果:

ResNet18融合前:(僅顯示ResNet18中l(wèi)ayer1的網(wǎng)絡(luò)結(jié)構(gòu))

ResNet18融合后

此融合只將Conv2d和BN層進(jìn)行融合,從上面對(duì)比可以看到融合后的 (bn) 變成了 identity(),(conv) 中的Conv2d是原本Conv2d和BN融合的。

2. 如果要融合的module被Sequential封裝了,可使用方法二

方法二:

torch.quantization.fuse_modules(m, [‘0’, ‘1’, ‘2’], inplace=True)

1. 使用方法二對(duì)ResNet18中模塊進(jìn)行融合操作,融合代碼如下:

def fuse_model(self):
    for m in self.modules():
        if type(m) == BasicBlock:
            torch.quantization.fuse_modules(m, [['conv1', 'bn1', 'relu'], ['conv2', 'bn2']], inplace=True)

此處代碼是仿pytorch官方寫MobileNetV2模塊融合,這部分代碼寫在 class ResNet(nn.Module) 中,后面融合直接使用model.fuse_model(),得到的方法二融合ResNet18結(jié)果如下:

此處是分別對(duì)(conv2d、bn、relu)和(conv2d、bn)進(jìn)行融合融合

2. 使用方法二對(duì)MobileNetv2中模塊進(jìn)行融合操作

def fuse_model(self):
    for m in self.modules():
        if type(m) == ConvBNReLU:
            torch.quantization.fuse_modacules(m, ['0', '1', '2'], inplace=True)
        if type(m) == InvertedResidual:
            for idx in range(len(m.conv)):
                if type(m.conv[idx]) == nn.Conv2d:
                    torch.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)

結(jié)果

MobileNetv2融合前(下面結(jié)果展示的是第一個(gè)殘差模塊,因此沒有第一個(gè)1x1的卷積)

MobileNetv2融合后

從此對(duì)比可以看到,融合前的conv2d、bn、relu融合成了ConvRelu2d(Conv2d,ReLU),這里面的Conv2d是融合前的Conv2d和BN融合的。

到此這篇關(guān)于pytorch中fuse_modules的文章就介紹到這了,更多相關(guān)pytorch中fuse_modules內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • python計(jì)算波峰波谷值的方法(極值點(diǎn))

    python計(jì)算波峰波谷值的方法(極值點(diǎn))

    這篇文章主要介紹了python求極值點(diǎn)(波峰波谷)求極值點(diǎn)主要用到了scipy庫,本文通過實(shí)例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2020-02-02
  • Python全局變量與局部變量區(qū)別及用法分析

    Python全局變量與局部變量區(qū)別及用法分析

    這篇文章主要介紹了Python全局變量與局部變量區(qū)別及用法,結(jié)合實(shí)例形式分析了Python全局變量與局部變量的定義、常見用法、區(qū)別及相關(guān)操作注意事項(xiàng),需要的朋友可以參考下
    2018-09-09
  • Python(TensorFlow框架)實(shí)現(xiàn)手寫數(shù)字識(shí)別系統(tǒng)的方法

    Python(TensorFlow框架)實(shí)現(xiàn)手寫數(shù)字識(shí)別系統(tǒng)的方法

    這篇文章主要介紹了Python(TensorFlow框架)實(shí)現(xiàn)手寫數(shù)字識(shí)別系統(tǒng)的方法。小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧
    2018-05-05
  • python3.6根據(jù)m3u8下載mp4視頻

    python3.6根據(jù)m3u8下載mp4視頻

    這篇文章主要為大家詳細(xì)介紹了python3.6根據(jù)m3u8下載mp4視頻,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2019-06-06
  • 5 分鐘讀懂Python 中的 Hook 鉤子函數(shù)

    5 分鐘讀懂Python 中的 Hook 鉤子函數(shù)

    這篇文章主要介紹了5 分鐘掌握 Python 中的 Hook 鉤子函數(shù),本文通過實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2020-12-12
  • 解決Python列表字符不區(qū)分大小寫的問題

    解決Python列表字符不區(qū)分大小寫的問題

    今天小編就為大家分享一篇解決Python列表字符不區(qū)分大小寫的問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2019-12-12
  • python中的zip模塊

    python中的zip模塊

    這篇文章主要介紹了zip文件格式是通用的文檔壓縮標(biāo)準(zhǔn),在ziplib模塊中,使用ZipFile類來操作zip文件,感興趣的朋友參考如下
    2021-08-08
  • 學(xué)習(xí)python (1)

    學(xué)習(xí)python (1)

    學(xué)習(xí)python (1)...
    2006-10-10
  • python Dataframe字符串合并的操作方法

    python Dataframe字符串合并的操作方法

    Dataframe的字符串合并包括2種場(chǎng)景,1.合并df中其中幾列字符串;2.將df中的字符串與外部字符串合并,本文主要介紹在Python下對(duì)Dataframe進(jìn)行字符串合并操作的方法,感興趣的朋友跟隨小編一起看看吧
    2024-06-06
  • python_opencv用線段畫封閉矩形的實(shí)例

    python_opencv用線段畫封閉矩形的實(shí)例

    今天小編就為大家分享一篇python_opencv用線段畫封閉矩形的實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2018-12-12

最新評(píng)論