pytorch 6 batch_train 批訓練操作
看代碼吧~
import torch import torch.utils.data as Data torch.manual_seed(1) # reproducible # BATCH_SIZE = 5 BATCH_SIZE = 8 # 每次使用8個數(shù)據(jù)同時傳入網(wǎng)路 x = torch.linspace(1, 10, 10) # this is x data (torch tensor) y = torch.linspace(10, 1, 10) # this is y data (torch tensor) torch_dataset = Data.TensorDataset(x, y) loader = Data.DataLoader( dataset=torch_dataset, # torch TensorDataset format batch_size=BATCH_SIZE, # mini batch size shuffle=False, # 設置不隨機打亂數(shù)據(jù) random shuffle for training num_workers=2, # 使用兩個進程提取數(shù)據(jù),subprocesses for loading data ) def show_batch(): for epoch in range(3): # 全部的數(shù)據(jù)使用3遍,train entire dataset 3 times for step, (batch_x, batch_y) in enumerate(loader): # for each training step # train your data... print('Epoch: ', epoch, '| Step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy()) if __name__ == '__main__': show_batch()
BATCH_SIZE = 8 , 所有數(shù)據(jù)利用三次
Epoch: 0 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.] Epoch: 0 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.] Epoch: 1 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.] Epoch: 1 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.] Epoch: 2 | Step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.] Epoch: 2 | Step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]
補充:pytorch批訓練bug
問題描述:
在進行pytorch神經(jīng)網(wǎng)絡批訓練的時候,有時會出現(xiàn)報錯
TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>
解決辦法:
第一步:
檢查(重點!?。。。?:
train_dataset = Data.TensorDataset(train_x, train_y)
train_x,和train_y格式,要求是tensor類,我第一次出錯就是因為傳入的是variable
可以這樣將數(shù)據(jù)變?yōu)閠ensor類:
train_x = torch.FloatTensor(train_x)
第二步:
train_loader = Data.DataLoader( dataset=train_dataset, batch_size=batch_size, shuffle=True )
實例化一個DataLoader對象
第三步:
for epoch in range(epochs): for step, (batch_x, batch_y) in enumerate(train_loader): batch_x, batch_y = Variable(batch_x), Variable(batch_y)
這樣就可以批訓練了
需要注意的是:train_loader輸出的是tensor,在訓練網(wǎng)絡時,需要變成Variable
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
探索Python中zoneinfo模塊處理時區(qū)操作實例
這篇文章主要為大家介紹了探索Python中zoneinfo模塊的用法實例,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪2024-01-01Python 爬蟲實現(xiàn)增加播客訪問量的方法實現(xiàn)
這篇文章主要介紹了Python 爬蟲實現(xiàn)增加播客訪問量的方法實現(xiàn),文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2019-10-10Django多數(shù)據(jù)庫配置及逆向生成model教程
這篇文章主要介紹了Django多數(shù)據(jù)庫配置及逆向生成model教程,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-03-03Pycharm出現(xiàn)卡頓、反應慢及CPU占用高等問題解決
相信很多剛開始使用pycharm不太熟練的小伙伴,每天一開機打開pycharm總是卡半天,不知道的還以為是電腦卡了或者啥問題的,下面這篇文章主要給大家介紹了關于Pycharm出現(xiàn)卡頓、反應慢及CPU占用高等問題解決的相關資料,需要的朋友可以參考下2023-06-06學習Python,你還不知道m(xù)ain函數(shù)嗎
Python?中的?main?函數(shù)充當程序的執(zhí)行點,在?Python?編程中定義?main?函數(shù)是啟動程序執(zhí)行的必要條件。本文就來帶大家深入了解一下main函數(shù),感興趣的可以了解一下2022-09-09