pytorch中如何設(shè)置隨機(jī)種子
pytorch設(shè)置隨機(jī)種子
pytorch設(shè)置隨機(jī)種子 - 保證復(fù)現(xiàn)模型所有的訓(xùn)練過程
在使用 PyTorch 時(shí),如果希望通過設(shè)置隨機(jī)數(shù)種子,在 GPU 或 CPU 上固定每一次的訓(xùn)練結(jié)果,則需要在程序執(zhí)行的開始處添加以下代碼:
def seed_everything():
'''
設(shè)置整個(gè)開發(fā)環(huán)境的seed
:param seed:
:param device:
:return:
'''
import os
import random
import numpy as np
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# some cudnn methods can be random even after fixing the seed
# unless you tell it to be deterministic
torch.backends.cudnn.deterministic = Truepytorch/tensorflow設(shè)置隨機(jī)種子 ,保證結(jié)果復(fù)現(xiàn)
Pytorch隨機(jī)種子設(shè)置
import numpy as np import random import os import torch def seed_torch(seed=2021): ? ? random.seed(seed) ? ? os.environ['PYTHONHASHSEED'] = str(seed) ? ? np.random.seed(seed) ? ? torch.manual_seed(seed) ? ? torch.cuda.manual_seed(seed) ? ? torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. ? ? torch.backends.cudnn.benchmark = False ? ? torch.backends.cudnn.deterministic = True ? ? torch.backends.cudnn.enabled = False seed_torch()
Tensorflow設(shè)置隨機(jī)種子
第一步 僅導(dǎo)入設(shè)置種子和初始化種子值所需的那些庫
import tensorflow as tf import os import numpy as np import random SEED = 0
第二步 為所有可能具有隨機(jī)行為的庫初始化種子的函數(shù)
def set_seeds(seed=SEED): ? ? os.environ['PYTHONHASHSEED'] = str(seed) ? ? random.seed(seed) ? ? tf.random.set_seed(seed) ? ? np.random.seed(seed)
第三步 激活 Tensorflow 確定性功能
def set_global_determinism(seed=SEED): ? ? set_seeds(seed=seed) ? ? os.environ['TF_DETERMINISTIC_OPS'] = '1' ? ? os.environ['TF_CUDNN_DETERMINISTIC'] = '1' ? ?? ? ? tf.config.threading.set_inter_op_parallelism_threads(1) ? ? tf.config.threading.set_intra_op_parallelism_threads(1) # Call the above function with seed value set_global_determinism(seed=SEED)
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python機(jī)器學(xué)習(xí)之實(shí)現(xiàn)模型持久化與加載
在實(shí)際的機(jī)器學(xué)習(xí)項(xiàng)目中,我們通常需要將訓(xùn)練好的模型保存到磁盤,本文我們會(huì)介紹如何在Python中使用pickle和joblib庫將訓(xùn)練好的模型持久化到磁盤,需要的可以參考一下2023-05-05
Python常用數(shù)據(jù)庫接口sqlite3和MySQLdb學(xué)習(xí)指南
在本章節(jié)中,我們將學(xué)習(xí) Python 中常用的數(shù)據(jù)庫接口,包括 sqlite3用于SQLite數(shù)據(jù)庫和MySQLdb用于 MySQL 數(shù)據(jù)庫,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-06-06
Matplotlib實(shí)戰(zhàn)之百分比柱狀圖繪制詳解
百分比堆疊式柱狀圖是一種特殊的柱狀圖,可以用于可視化比較不同類別或組的百分比或比例的圖表,下面我們就來介紹一下如何使用Matplotlib繪制百分比柱狀圖,需要的可以參考下2023-08-08

