keras導(dǎo)入weights方式
keras源碼engine中toplogy.py定義了加載權(quán)重的函數(shù):
load_weights(self, filepath, by_name=False)
其中默認(rèn)by_name為False,這時候加載權(quán)重按照網(wǎng)絡(luò)拓?fù)浣Y(jié)構(gòu)加載,適合直接使用keras中自帶的網(wǎng)絡(luò)模型,如VGG16
VGG19/resnet50等,源碼描述如下:
If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.
若將by_name改為True則加載權(quán)重按照layer的name進(jìn)行,layer的name相同時加載權(quán)重,適合用于改變了
模型的相關(guān)結(jié)構(gòu)或增加了節(jié)點(diǎn)但利用了原網(wǎng)絡(luò)的主體結(jié)構(gòu)情況下使用,源碼描述如下:
If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.
在進(jìn)行邊緣檢測時,利用VGG網(wǎng)絡(luò)的主體結(jié)構(gòu),網(wǎng)絡(luò)中增加反卷積層,這時加載權(quán)重應(yīng)該使用
model.load_weights(filepath,by_name=True)
補(bǔ)充知識:Keras下實(shí)現(xiàn)mnist手寫數(shù)字
之前一直在用tensorflow,被同學(xué)推薦來用keras了,把之前文檔中的mnist手寫數(shù)字?jǐn)?shù)據(jù)集拿來練手,
代碼如下。
import struct import numpy as np import os import keras from keras.models import Sequential from keras.layers import Dense from keras.optimizers import SGD def load_mnist(path, kind): labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind) images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind) with open(labels_path, 'rb') as lbpath: magic, n = struct.unpack('>II', lbpath.read(8)) labels = np.fromfile(lbpath, dtype=np.uint8) with open(images_path, 'rb') as imgpath: magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16)) images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784 return images, labels #loading train and test data X_train, Y_train = load_mnist('.\\data', kind='train') X_test, Y_test = load_mnist('.\\data', kind='t10k') #turn labels to one_hot code Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10) #define models model = Sequential() model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh')) model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh')) model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax')) sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True) model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"]) #start training model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3) #count accuracy y_train_pred = model.predict_classes(X_train, verbose=0) train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0] print('Training accuracy: %.2f%%' % (train_acc * 100)) y_test_pred = model.predict_classes(X_test, verbose=0) test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0] print('Test accuracy: %.2f%%' % (test_acc * 100))
訓(xùn)練結(jié)果如下:
Epoch 45/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323 Epoch 46/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358 Epoch 47/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347 Epoch 48/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350 Epoch 49/50 42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359 Epoch 50/50 42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346 Training accuracy: 94.11% Test accuracy: 93.61%
以上這篇keras導(dǎo)入weights方式就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
python實(shí)現(xiàn)去除空格及tab換行符的方法
這篇文章主要為大家介紹了python實(shí)現(xiàn)去除空格及tab換行符的方法,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-06-06在Python中調(diào)用Ping命令,批量IP的方法
今天小編就為大家分享一篇在Python中調(diào)用Ping命令,批量IP的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-01-01Python實(shí)戰(zhàn)之異步獲取中國天氣信息
這篇文章主要介紹了如何利用Python爬蟲異步獲取天氣信息,用的API是中國天氣網(wǎng)。文中的示例代碼講解詳細(xì),感興趣的小伙伴可以動手試一試2022-03-03Django項目搭建之實(shí)現(xiàn)簡單的API訪問
這篇文章主要給大家介紹了關(guān)于Django項目搭建之實(shí)現(xiàn)簡單的API訪問的相關(guān)資料,文中通過圖文以及示例代碼介紹的非常詳細(xì),對大家學(xué)習(xí)或者使用Django具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2023-02-02PyQt教程之自定義組件Switch?Button的實(shí)現(xiàn)
這篇文章主要為大家詳細(xì)介紹了PyQt中如何實(shí)現(xiàn)自定義組件Switch?Button,文中的示例代碼簡潔易懂,具有一定的學(xué)習(xí)價值,感興趣的可以了解一下2023-05-05python dataframe astype 字段類型轉(zhuǎn)換方法
下面小編就為大家分享一篇python dataframe astype 字段類型轉(zhuǎn)換方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-04-04對Python subprocess.Popen子進(jìn)程管道阻塞詳解
今天小編就為大家分享一篇對Python subprocess.Popen子進(jìn)程管道阻塞詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-10-10