在pytorch中如何查看模型model參數(shù)parameters
pytorch查看模型model參數(shù)parameters
示例1:pytorch自帶的faster r-cnn模型
import torch import torchvision model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) for name, p in model.named_parameters(): print(name) print(p.requires_grad) print(...) #或者 for p in model.parameters(): print(p) print(...)
示例2:自定義網(wǎng)絡(luò)模型
class Net(nn.Module): def __init__(self): super(Net, self).__init__() cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512] self.features = self._vgg_layers(cfg) def _vgg_layers(self, cfg): layers = [] in_channels = 3 for x in cfg: if x == 'M': layers += [nn.MaxPool2d(kernel_size=2, stride=2)] else: layers += [nn.Conv2d(in_channels, x ,kernel_size=3, padding=1), nn.BatchNorm2d(x), nn.ReLU(inplace=True) ] in_channels = x return nn.Sequential(*layers) def forward(self, data): out_map = self.features(data) return out_map Model = Net() for name, p in model.named_parameters(): print(name) print(p.requires_grad) print(...) #或者 for p in model.parameters(): print(p) print(...)
在自定義網(wǎng)絡(luò)中,model.parameters()方法繼承自nn.Module
pytorch查看模型參數(shù)總結(jié)
1:DNN_printer
其中(3, 32, 32)是輸入的大小,其他方法中的參數(shù)同理
from DNN_printer import DNN_printer batch_size = 512 def train(epoch): print('\nEpoch: %d' % epoch) net.train() train_loss = 0 correct = 0 total = 0 // put the code here and you can get the result DNN_printer(net, (3, 32, 32),batch_size)
結(jié)果
2:parameters
def cnn_paras_count(net): """cnn參數(shù)量統(tǒng)計, 使用方式cnn_paras_count(net)""" # Find total parameters and trainable parameters total_params = sum(p.numel() for p in net.parameters()) print(f'{total_params:,} total parameters.') total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad) print(f'{total_trainable_params:,} training parameters.') return total_params, total_trainable_params cnn_paras_count(net)
直接輸出參數(shù)量,然后自己計算
需要注意的是,一般模型中參數(shù)是以float32保存的,也就是一個參數(shù)由4個bytes表示,那么就可以將參數(shù)量轉(zhuǎn)化為存儲大小。
例如:
- 44426個參數(shù)*4 / 1024 ≈ 174KB
3:get_model_complexity_info()
from ptflops import get_model_complexity_info from torchvision import models net = models.mobilenet_v2() ops, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True, print_per_layer_stat=True, verbose=True)
4:torchstat
from torchstat import stat import torchvision.models as models model = models.resnet152() stat(model, (3, 224, 224))
輸出
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
python實現(xiàn)的登錄與提交表單數(shù)據(jù)功能示例
這篇文章主要介紹了python實現(xiàn)的登錄與提交表單數(shù)據(jù)功能,結(jié)合實例形式分析了Python表單登錄相關(guān)的請求與響應(yīng)操作實現(xiàn)技巧,需要的朋友可以參考下2019-09-09Python實現(xiàn)讀取并寫入Excel文件過程解析
這篇文章主要介紹了Python實現(xiàn)讀取并寫入Excel文件過程解析,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2020-05-05Python中各類Excel表格批量合并問題的實現(xiàn)思路與案例
在日常工作中,可能會遇到各類表格合并的需求。本文主要介紹了Python中各類Excel表格批量合并問題的實現(xiàn)思路與案例,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-01-01python爬蟲使用requests發(fā)送post請求示例詳解
這篇文章主要介紹了python爬蟲使用requests發(fā)送post請求示例詳解,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-08-08Tensorflow與RNN、雙向LSTM等的踩坑記錄及解決
這篇文章主要介紹了Tensorflow與RNN、雙向LSTM等的踩坑記錄及解決方案,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2021-05-05對網(wǎng)站內(nèi)嵌gradio應(yīng)用的輸入輸出做審核實現(xiàn)詳解
這篇文章主要為大家介紹了對網(wǎng)站內(nèi)嵌gradio應(yīng)用的輸入輸出做審核實現(xiàn)詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-04-04