keras 兩種訓(xùn)練模型方式詳解fit和fit_generator(節(jié)省內(nèi)存)
第一種,fit
import keras from keras.models import Sequential from keras.layers import Dense import numpy as np from sklearn.preprocessing import LabelEncoder from sklearn.preprocessing import OneHotEncoder from sklearn.model_selection import train_test_split #讀取數(shù)據(jù) x_train = np.load("D:\\machineTest\\testmulPE_win7\\data_sprase.npy")[()] y_train = np.load("D:\\machineTest\\testmulPE_win7\\lable_sprase.npy") # 獲取分類類別總數(shù) classes = len(np.unique(y_train)) #對(duì)label進(jìn)行one-hot編碼,必須的 label_encoder = LabelEncoder() integer_encoded = label_encoder.fit_transform(y_train) onehot_encoder = OneHotEncoder(sparse=False) integer_encoded = integer_encoded.reshape(len(integer_encoded), 1) y_train = onehot_encoder.fit_transform(integer_encoded) #shuffle X_train, X_test, y_train, y_test = train_test_split(x_train, y_train, test_size=0.3, random_state=0) model = Sequential() model.add(Dense(units=1000, activation='relu', input_dim=784)) model.add(Dense(units=classes, activation='softmax')) model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy']) model.fit(X_train, y_train, epochs=50, batch_size=128) score = model.evaluate(X_test, y_test, batch_size=128) # #fit參數(shù)詳情 # keras.models.fit( # self, # x=None, #訓(xùn)練數(shù)據(jù) # y=None, #訓(xùn)練數(shù)據(jù)label標(biāo)簽 # batch_size=None, #每經(jīng)過多少個(gè)sample更新一次權(quán)重,defult 32 # epochs=1, #訓(xùn)練的輪數(shù)epochs # verbose=1, #0為不在標(biāo)準(zhǔn)輸出流輸出日志信息,1為輸出進(jìn)度條記錄,2為每個(gè)epoch輸出一行記錄 # callbacks=None,#list,list中的元素為keras.callbacks.Callback對(duì)象,在訓(xùn)練過程中會(huì)調(diào)用list中的回調(diào)函數(shù) # validation_split=0., #浮點(diǎn)數(shù)0-1,將訓(xùn)練集中的一部分比例作為驗(yàn)證集,然后下面的驗(yàn)證集validation_data將不會(huì)起到作用 # validation_data=None, #驗(yàn)證集 # shuffle=True, #布爾值和字符串,如果為布爾值,表示是否在每一次epoch訓(xùn)練前隨機(jī)打亂輸入樣本的順序,如果為"batch",為處理HDF5數(shù)據(jù) # class_weight=None, #dict,分類問題的時(shí)候,有的類別可能需要額外關(guān)注,分錯(cuò)的時(shí)候給的懲罰會(huì)比較大,所以權(quán)重會(huì)調(diào)高,體現(xiàn)在損失函數(shù)上面 # sample_weight=None, #array,和輸入樣本對(duì)等長度,對(duì)輸入的每個(gè)特征+個(gè)權(quán)值,如果是時(shí)序的數(shù)據(jù),則采用(samples,sequence_length)的矩陣 # initial_epoch=0, #如果之前做了訓(xùn)練,則可以從指定的epoch開始訓(xùn)練 # steps_per_epoch=None, #將一個(gè)epoch分為多少個(gè)steps,也就是劃分一個(gè)batch_size多大,比如steps_per_epoch=10,則就是將訓(xùn)練集分為10份,不能和batch_size共同使用 # validation_steps=None, #當(dāng)steps_per_epoch被啟用的時(shí)候才有用,驗(yàn)證集的batch_size # **kwargs #用于和后端交互 # ) # # 返回的是一個(gè)History對(duì)象,可以通過History.history來查看訓(xùn)練過程,loss值等等
第二種,fit_generator(節(jié)省內(nèi)存)
# 第二種,可以節(jié)省內(nèi)存 ''' Created on 2018-4-11 fit_generate.txt,后面兩列為lable,已經(jīng)one-hot編碼 1 2 0 1 2 3 1 0 1 3 0 1 1 4 0 1 2 4 1 0 2 5 1 0 ''' import keras from keras.models import Sequential from keras.layers import Dense import numpy as np from sklearn.model_selection import train_test_split count =1 def generate_arrays_from_file(path): global count while 1: datas = np.loadtxt(path,delimiter=' ',dtype="int") x = datas[:,:2] y = datas[:,2:] print("count:"+str(count)) count = count+1 yield (x,y) x_valid = np.array([[1,2],[2,3]]) y_valid = np.array([[0,1],[1,0]]) model = Sequential() model.add(Dense(units=1000, activation='relu', input_dim=2)) model.add(Dense(units=2, activation='softmax')) model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy']) model.fit_generator(generate_arrays_from_file("D:\\fit_generate.txt"),steps_per_epoch=10, epochs=2,max_queue_size=1,validation_data=(x_valid, y_valid),workers=1) # steps_per_epoch 每執(zhí)行一次steps,就去執(zhí)行一次生產(chǎn)函數(shù)generate_arrays_from_file # max_queue_size 從生產(chǎn)函數(shù)中出來的數(shù)據(jù)時(shí)可以緩存在queue隊(duì)列中 # 輸出如下: # Epoch 1/2 # count:1 # count:2 # # 1/10 [==>...........................] - ETA: 2s - loss: 0.7145 - acc: 0.3333count:3 # count:4 # count:5 # count:6 # count:7 # # 7/10 [====================>.........] - ETA: 0s - loss: 0.7001 - acc: 0.4286count:8 # count:9 # count:10 # count:11 # # 10/10 [==============================] - 0s 36ms/step - loss: 0.6960 - acc: 0.4500 - val_loss: 0.6794 - val_acc: 0.5000 # Epoch 2/2 # # 1/10 [==>...........................] - ETA: 0s - loss: 0.6829 - acc: 0.5000count:12 # count:13 # count:14 # count:15 # # 5/10 [==============>...............] - ETA: 0s - loss: 0.6800 - acc: 0.5000count:16 # count:17 # count:18 # count:19 # count:20 # # 10/10 [==============================] - 0s 11ms/step - loss: 0.6766 - acc: 0.5000 - val_loss: 0.6662 - val_acc: 0.5000
補(bǔ)充知識(shí):
自動(dòng)生成數(shù)據(jù)還可以繼承keras.utils.Sequence,然后寫自己的生成數(shù)據(jù)類:
keras數(shù)據(jù)自動(dòng)生成器,繼承keras.utils.Sequence,結(jié)合fit_generator實(shí)現(xiàn)節(jié)約內(nèi)存訓(xùn)練
#coding=utf-8 ''' Created on 2018-7-10 ''' import keras import math import os import cv2 import numpy as np from keras.models import Sequential from keras.layers import Dense class DataGenerator(keras.utils.Sequence): def __init__(self, datas, batch_size=1, shuffle=True): self.batch_size = batch_size self.datas = datas self.indexes = np.arange(len(self.datas)) self.shuffle = shuffle def __len__(self): #計(jì)算每一個(gè)epoch的迭代次數(shù) return math.ceil(len(self.datas) / float(self.batch_size)) def __getitem__(self, index): #生成每個(gè)batch數(shù)據(jù),這里就根據(jù)自己對(duì)數(shù)據(jù)的讀取方式進(jìn)行發(fā)揮了 # 生成batch_size個(gè)索引 batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size] # 根據(jù)索引獲取datas集合中的數(shù)據(jù) batch_datas = [self.datas[k] for k in batch_indexs] # 生成數(shù)據(jù) X, y = self.data_generation(batch_datas) return X, y def on_epoch_end(self): #在每一次epoch結(jié)束是否需要進(jìn)行一次隨機(jī),重新隨機(jī)一下index if self.shuffle == True: np.random.shuffle(self.indexes) def data_generation(self, batch_datas): images = [] labels = [] # 生成數(shù)據(jù) for i, data in enumerate(batch_datas): #x_train數(shù)據(jù) image = cv2.imread(data) image = list(image) images.append(image) #y_train數(shù)據(jù) right = data.rfind("\\",0) left = data.rfind("\\",0,right)+1 class_name = data[left:right] if class_name=="dog": labels.append([0,1]) else: labels.append([1,0]) #如果為多輸出模型,Y的格式要變一下,外層list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3] return np.array(images), np.array(labels) # 讀取樣本名稱,然后根據(jù)樣本名稱去讀取數(shù)據(jù) class_num = 0 train_datas = [] for file in os.listdir("D:/xxx"): file_path = os.path.join("D:/xxx", file) if os.path.isdir(file_path): class_num = class_num + 1 for sub_file in os.listdir(file_path): train_datas.append(os.path.join(file_path, sub_file)) # 數(shù)據(jù)生成器 training_generator = DataGenerator(train_datas) #構(gòu)建網(wǎng)絡(luò) model = Sequential() model.add(Dense(units=64, activation='relu', input_dim=784)) model.add(Dense(units=2, activation='softmax')) model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy']) model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy']) model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)
以上這篇keras 兩種訓(xùn)練模型方式詳解fit和fit_generator(節(jié)省內(nèi)存)就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python 批量驗(yàn)證和添加手機(jī)號(hào)碼為企業(yè)微信聯(lián)系人
你是否也有過需要添加很多微信好友的時(shí)候,一個(gè)個(gè)輸入添加太麻煩了,本篇文章手把手教你用Python替我們完成這繁瑣的操作,大家可以在過程中查缺補(bǔ)漏,看看自己掌握程度怎么樣2021-10-10Python解決asyncio文件描述符最大數(shù)量限制的問題
這篇文章主要介紹了Python解決asyncio文件描述符最大數(shù)量限制的問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2024-06-06Python之兩種模式的生產(chǎn)者消費(fèi)者模型詳解
今天小編就為大家分享一篇Python之兩種模式的生產(chǎn)者消費(fèi)者模型詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-10-10用selenium解決滑塊驗(yàn)證碼的實(shí)現(xiàn)步驟
驗(yàn)證碼作為一種自然人的機(jī)器人的判別工具,被廣泛的用于各種防止程序做自動(dòng)化的場(chǎng)景中,下面這篇文章主要給大家介紹了關(guān)于用selenium解決滑塊驗(yàn)證碼的實(shí)現(xiàn)步驟,需要的朋友可以參考下2023-02-02Python 3.6打包成EXE可執(zhí)行程序的實(shí)現(xiàn)
這篇文章主要介紹了Python 3.6打包成EXE可執(zhí)行程序的實(shí)現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-10-10對(duì)pandas里的loc并列條件索引的實(shí)例講解
今天小編就為大家分享一篇對(duì)pandas里的loc并列條件索引的實(shí)例講解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-11-11Python第三方包之DingDingBot釘釘機(jī)器人
這篇文章主要介紹了Python第三方包之DingDingBot釘釘機(jī)器人,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-04-04Python基于pygame實(shí)現(xiàn)圖片代替鼠標(biāo)移動(dòng)效果
這篇文章主要介紹了Python基于pygame實(shí)現(xiàn)圖片代替鼠標(biāo)移動(dòng)效果,可實(shí)現(xiàn)將鼠標(biāo)箭頭轉(zhuǎn)換成圖形的功能,涉及pygame圖形操作的相關(guān)技巧,具有一定參考借鑒價(jià)值,需要的朋友可以參考下2015-11-11解決windows上安裝tensorflow時(shí)報(bào)錯(cuò),“DLL load failed: 找不到指定的模塊”的問題
這篇文章主要介紹了解決windows上安裝tensorflow時(shí)報(bào)錯(cuò),“DLL load failed: 找不到指定的模塊”的問題,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-05-05