Pytorch中的數(shù)據(jù)集劃分&正則化方法
1.訓(xùn)練集&驗(yàn)證集&測(cè)試集
訓(xùn)練集:訓(xùn)練數(shù)據(jù)
驗(yàn)證集:驗(yàn)證不同算法(比如利用網(wǎng)格搜索對(duì)超參數(shù)進(jìn)行調(diào)整等),檢驗(yàn)?zāi)姆N更有效
測(cè)試集:正確評(píng)估分類器的性能
正常流程:驗(yàn)證集會(huì)記錄每個(gè)時(shí)間戳的參數(shù),在加載test數(shù)據(jù)前會(huì)加載那個(gè)最好的參數(shù),再來評(píng)估。比方說訓(xùn)練完6000個(gè)epoch后,發(fā)現(xiàn)在第3520個(gè)epoch的validation表現(xiàn)最好,測(cè)試時(shí)會(huì)加載第3520個(gè)epoch的參數(shù)。
import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms #超參數(shù) batch_size=200 learning_rate=0.01 epochs=10 #獲取訓(xùn)練數(shù)據(jù) train_db = datasets.MNIST('../data', train=True, download=True, #train=True則得到的是訓(xùn)練集 transform=transforms.Compose([ #transform進(jìn)行數(shù)據(jù)預(yù)處理 transforms.ToTensor(), #轉(zhuǎn)成Tensor類型的數(shù)據(jù) transforms.Normalize((0.1307,), (0.3081,)) #進(jìn)行數(shù)據(jù)標(biāo)準(zhǔn)化(減去均值除以方差) ])) #DataLoader把訓(xùn)練數(shù)據(jù)分成多個(gè)小組,此函數(shù)每次拋出一組數(shù)據(jù)。直至把所有的數(shù)據(jù)都拋出。就是做一個(gè)數(shù)據(jù)的初始化 train_loader = torch.utils.data.DataLoader(train_db, batch_size=batch_size, shuffle=True) #獲取測(cè)試數(shù)據(jù) test_db = datasets.MNIST('../data', train=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])) test_loader = torch.utils.data.DataLoader(test_db, batch_size=batch_size, shuffle=True) #將訓(xùn)練集拆分成訓(xùn)練集和驗(yàn)證集 print('train:', len(train_db), 'test:', len(test_db)) #train: 60000 test: 10000 train_db, val_db = torch.utils.data.random_split(train_db, [50000, 10000]) print('db1:', len(train_db), 'db2:', len(val_db)) #db1: 50000 db2: 10000 train_loader = torch.utils.data.DataLoader(train_db, batch_size=batch_size, shuffle=True) val_loader = torch.utils.data.DataLoader(val_db, batch_size=batch_size, shuffle=True) class MLP(nn.Module): def __init__(self): super(MLP, self).__init__() self.model = nn.Sequential( #定義網(wǎng)絡(luò)的每一層, nn.Linear(784, 200), nn.ReLU(inplace=True), nn.Linear(200, 200), nn.ReLU(inplace=True), nn.Linear(200, 10), nn.ReLU(inplace=True), ) def forward(self, x): x = self.model(x) return x net = MLP() #定義sgd優(yōu)化器,指明優(yōu)化參數(shù)、學(xué)習(xí)率,net.parameters()得到這個(gè)類所定義的網(wǎng)絡(luò)的參數(shù)[[w1,b1,w2,b2,...] optimizer = optim.SGD(net.parameters(), lr=learning_rate) criteon = nn.CrossEntropyLoss() for epoch in range(epochs): for batch_idx, (data, target) in enumerate(train_loader): data = data.view(-1, 28*28) #將二維的圖片數(shù)據(jù)攤平[樣本數(shù),784] logits = net(data) #前向傳播 loss = criteon(logits, target) #nn.CrossEntropyLoss()自帶Softmax optimizer.zero_grad() #梯度信息清空 loss.backward() #反向傳播獲取梯度 optimizer.step() #優(yōu)化器更新 if batch_idx % 100 == 0: #每100個(gè)batch輸出一次信息 print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.item())) #驗(yàn)證集用來檢測(cè)訓(xùn)練是否過擬合 val_loss = 0 correct = 0 for data, target in val_loader: data = data.view(-1, 28 * 28) logits = net(data) val_loss += criteon(logits, target).item() pred = logits.data.max(dim=1)[1] correct += pred.eq(target.data).sum() val_loss /= len(val_loader.dataset) print('\nVAL set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( val_loss, correct, len(val_loader.dataset), 100. * correct / len(val_loader.dataset))) #測(cè)試集用來評(píng)估 test_loss = 0 correct = 0 #correct記錄正確分類的樣本數(shù) for data, target in test_loader: data = data.view(-1, 28 * 28) logits = net(data) test_loss += criteon(logits, target).item() #其實(shí)就是criteon(logits, target)的值,標(biāo)量 pred = logits.data.max(dim=1)[1] #也可以寫成pred=logits.argmax(dim=1) correct += pred.eq(target.data).sum() test_loss /= len(test_loader.dataset) print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( test_loss, correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset)))
2.正則化
正則化可以解決過擬合問題。
2.1L2范數(shù)(更常用)
在定義優(yōu)化器的時(shí)候設(shè)定weigth_decay,即L2范數(shù)前面的λ參數(shù)。
optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.01)
2.2L1范數(shù)
Pytorch沒有直接可以調(diào)用的方法,實(shí)現(xiàn)如下:
3.動(dòng)量(Momentum)
Adam優(yōu)化器內(nèi)置了momentum,SGD需要手動(dòng)設(shè)置。
optimizer = torch.optim.SGD(model.parameters(), args=lr, momentum=args.momentum, weight_decay=args.weight_decay)
4.學(xué)習(xí)率衰減
torch.optim.lr_scheduler 中提供了基于多種epoch數(shù)目調(diào)整學(xué)習(xí)率的方法。
4.1torch.optim.lr_scheduler.ReduceLROnPlateau:基于測(cè)量指標(biāo)對(duì)學(xué)習(xí)率進(jìn)行動(dòng)態(tài)的下降
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)
訓(xùn)練過程中,optimizer會(huì)把learning rate 交給scheduler管理,當(dāng)指標(biāo)(比如loss)連續(xù)patience次數(shù)還沒有改進(jìn)時(shí),需要降低學(xué)習(xí)率,factor為每次下降的比例。
scheduler.step(loss_val)每調(diào)用一次就會(huì)監(jiān)聽一次loss_val。
4.2torch.optim.lr_scheduler.StepLR:基于epoch
torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)
當(dāng)epoch每過stop_size時(shí),學(xué)習(xí)率都變?yōu)槌跏紝W(xué)習(xí)率的gamma倍。
5.提前停止(防止overfitting)
基于經(jīng)驗(yàn)值。
6.Dropout隨機(jī)失活
遍歷每一層,設(shè)置消除神經(jīng)網(wǎng)絡(luò)中的節(jié)點(diǎn)概率,得到精簡(jiǎn)后的一個(gè)樣本。
torch.nn.Dropout(p=dropout_prob)
p表示的示的是刪除節(jié)點(diǎn)數(shù)的比例(Tip:tensorflow中keep_prob表示保留節(jié)點(diǎn)數(shù)的比例,不要混淆)
測(cè)試階段無需使用dropout,所以在train之前執(zhí)行net_dropped.train()相當(dāng)于啟用dropout,測(cè)試之前執(zhí)行net_dropped.eval()相當(dāng)于不啟用dropout。
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Django 連接sql server數(shù)據(jù)庫(kù)的方法
這篇文章主要介紹了Django 連接sql server數(shù)據(jù)庫(kù)的方法,小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2018-06-06基于sklearn實(shí)現(xiàn)Bagging算法(python)
這篇文章主要為大家詳細(xì)介紹了基于sklearn實(shí)現(xiàn)Bagging算法,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2019-07-07用python實(shí)現(xiàn)爬取奧特曼圖片實(shí)例
大家好,本篇文章主要講的是用python實(shí)現(xiàn)爬取奧特曼圖片實(shí)例,感興趣的同學(xué)趕快來看一看吧,對(duì)你有幫助的話記得收藏一下2022-02-02如何在Python3中使用telnetlib模塊連接網(wǎng)絡(luò)設(shè)備
這篇文章主要介紹了如何在Python3中使用telnetlib模塊連接網(wǎng)絡(luò)設(shè)備,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-09-09跟老齊學(xué)Python之賦值,簡(jiǎn)單也不簡(jiǎn)單
在《初識(shí)永遠(yuǎn)強(qiáng)大的函數(shù)》一文中,有一節(jié)專門討論“取名字的學(xué)問”,就是有關(guān)變量名稱的問題,本溫故而知新的原則,這里要復(fù)習(xí)一下2014-09-09使用Python編寫簡(jiǎn)單網(wǎng)絡(luò)爬蟲抓取視頻下載資源
從上一篇文章的評(píng)論中看出似乎很多童鞋都比較關(guān)注爬蟲的源代碼。所有本文就使用Python編寫簡(jiǎn)單網(wǎng)絡(luò)爬蟲抓取視頻下載資源做了很詳細(xì)的記錄,幾乎每一步都介紹給大家,希望對(duì)大家能有所幫助2014-11-11