MxNet預訓練模型到Pytorch模型的轉換方式
預訓練模型在不同深度學習框架中的轉換是一種常見的任務。今天剛好DPN預訓練模型轉換問題,順手將這個過程記錄一下。
核心轉換函數(shù)如下所示:
def convert_from_mxnet(model, checkpoint_prefix, debug=False): _, mxnet_weights, mxnet_aux = mxnet.model.load_checkpoint(checkpoint_prefix, 0) remapped_state = {} for state_key in model.state_dict().keys(): k = state_key.split('.') aux = False mxnet_key = '' if k[0] == 'features': if k[1] == 'conv1_1': # input block mxnet_key += 'conv1_x_1__' if k[2] == 'bn': mxnet_key += 'relu-sp__bn_' aux, key_add = _convert_bn(k[3]) mxnet_key += key_add else: assert k[3] == 'weight' mxnet_key += 'conv_' + k[3] elif k[1] == 'conv5_bn_ac': # bn + ac at end of features block mxnet_key += 'conv5_x_x__relu-sp__bn_' assert k[2] == 'bn' aux, key_add = _convert_bn(k[3]) mxnet_key += key_add else: # middle blocks if model.b and 'c1x1_c' in k[2]: bc_block = True # b-variant split c-block special treatment else: bc_block = False ck = k[1].split('_') mxnet_key += ck[0] + '_x__' + ck[1] + '_' ck = k[2].split('_') mxnet_key += ck[0] + '-' + ck[1] if ck[1] == 'w' and len(ck) > 2: mxnet_key += '(s/2)' if ck[2] == 's2' else '(s/1)' mxnet_key += '__' if k[3] == 'bn': mxnet_key += 'bn_' if bc_block else 'bn__bn_' aux, key_add = _convert_bn(k[4]) mxnet_key += key_add else: ki = 3 if bc_block else 4 assert k[ki] == 'weight' mxnet_key += 'conv_' + k[ki] elif k[0] == 'classifier': if 'fc6-1k_weight' in mxnet_weights: mxnet_key += 'fc6-1k_' else: mxnet_key += 'fc6_' mxnet_key += k[1] else: assert False, 'Unexpected token' if debug: print(mxnet_key, '=> ', state_key, end=' ') mxnet_array = mxnet_aux[mxnet_key] if aux else mxnet_weights[mxnet_key] torch_tensor = torch.from_numpy(mxnet_array.asnumpy()) if k[0] == 'classifier' and k[1] == 'weight': torch_tensor = torch_tensor.view(torch_tensor.size() + (1, 1)) remapped_state[state_key] = torch_tensor if debug: print(list(torch_tensor.size()), torch_tensor.mean(), torch_tensor.std()) model.load_state_dict(remapped_state) return model
從中可以看出,其轉換步驟如下:
(1)創(chuàng)建pytorch的網(wǎng)絡結構模型,設為model
(2)利用mxnet來讀取其存儲的預訓練模型,得到mxnet_weights;
(3)遍歷加載后模型mxnet_weights的state_dict().keys
(4)對一些指定的key值,需要進行相應的處理和轉換
(5)對修改鍵名之后的key利用numpy之間的轉換來實現(xiàn)加載。
為了實現(xiàn)上述轉換,首先pip安裝mxnet,現(xiàn)在新版的mxnet安裝還是非常方便的。
第二步,運行轉換程序,實現(xiàn)預訓練模型的轉換。
可以看到在相當?shù)奈募A下已經(jīng)出現(xiàn)了轉換后的模型。
以上這篇MxNet預訓練模型到Pytorch模型的轉換方式就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
python 實現(xiàn)網(wǎng)上商城,轉賬,存取款等功能的信用卡系統(tǒng)
本篇文章主要介紹 基于python 實現(xiàn)信用卡系統(tǒng),附有代碼實例,對于用python 開發(fā)網(wǎng)絡上傳系統(tǒng)具有參考價值,有需要的朋友可以看下2016-07-07python實現(xiàn)socket+threading處理多連接的方法
今天小編就為大家分享一篇python實現(xiàn)socket+threading處理多連接的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-07-07python 利用 PIL 將數(shù)組值轉成圖片的實現(xiàn)
這篇文章主要介紹了python 利用 PIL 將數(shù)組值轉成圖片的實現(xiàn),文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2021-04-04pandas 實現(xiàn)字典轉換成DataFrame的方法
今天小編就為大家分享一篇pandas 實現(xiàn)字典轉換成DataFrame的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-07-07python實現(xiàn)在遍歷列表時,直接對dict元素增加字段的方法
今天小編就為大家分享一篇python實現(xiàn)在遍歷列表時,直接對dict元素增加字段的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-01-01