亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

Pytorch固定隨機(jī)數(shù)種子的方法小結(jié)

 更新時(shí)間:2023年12月08日 11:15:18   作者:lgc0208  
在對(duì)神經(jīng)網(wǎng)絡(luò)模型進(jìn)行訓(xùn)練時(shí),有時(shí)候會(huì)存在對(duì)訓(xùn)練過程進(jìn)行復(fù)現(xiàn)的需求,然而,每次運(yùn)行時(shí) Pytorch、Numpy 中的隨機(jī)性將使得該目的變得困難重重,基于此,本文記錄了 Pytorch 中的固定隨機(jī)數(shù)種子的方法,需要的朋友可以參考下

引言

在對(duì)神經(jīng)網(wǎng)絡(luò)模型進(jìn)行訓(xùn)練時(shí),有時(shí)候會(huì)存在對(duì)訓(xùn)練過程進(jìn)行復(fù)現(xiàn)的需求。然而,每次運(yùn)行時(shí) Pytorch、Numpy 中的隨機(jī)性將使得該目的變得困難重重。在程序運(yùn)行前固定所有隨機(jī)數(shù)的種子有望解決這一問題?;诖?,本文記錄了 Pytorch 中的固定隨機(jī)數(shù)種子的方法。

在使用 Pytorch 對(duì)模型進(jìn)行訓(xùn)練時(shí),通常涉及到隨機(jī)數(shù)的模塊包括:Python、Pytorch、Numpy、Cudnn。因此,在開始訓(xùn)練前,需要針對(duì)這些涉及隨機(jī)數(shù)的模塊進(jìn)行隨機(jī)數(shù)種子的固定。

1. Python

Python 本身涉及到的隨機(jī)性主要是 Python 自帶的 random 庫(kù)隨機(jī)化和 Hash 隨機(jī)化問題,需要通過 os 庫(kù)對(duì)其進(jìn)行限制:

import os, random
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)

2. Numpy

在使用 Numpy 庫(kù)取隨機(jī)數(shù)時(shí),需要對(duì)其隨機(jī)數(shù)種子進(jìn)行限制:

import numpy as np
np.random.seed(seed)

3. Pytorch

當(dāng) Pytorch 使用 CPU 進(jìn)行運(yùn)算時(shí),需要設(shè)定 CPU 支撐下的 Pytorch 隨機(jī)數(shù)種子:

import torch
torch.manual_seed(seed)

當(dāng) Pytorch 使用 GPU 進(jìn)行運(yùn)算時(shí),需要設(shè)定 GPU 支撐下的 Pytorch 隨機(jī)數(shù)種子:

import torch
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # 使用多 GPU 時(shí)使用

需要特別注意的是:目前很多博客和知乎回答提出 torch.cuda.manual_seed(seed) 和 torch.cuda.manual_seed_all(seed) 具有相同的作用。這個(gè)結(jié)論需要注意 Pytorch 版本。在筆者所用的 Pytorch 2.1 版本下,這兩個(gè)函數(shù)的作用完全不同。參考官方文檔:torch.cuda.manual_seed 和 torch.cuda.manual_seed_all(seed)

當(dāng) Pytorch 使用 Cudnn 進(jìn)行加速運(yùn)算時(shí),還需要限制 Cudnn 在加速過程中涉及到的隨機(jī)策略:

import torch
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

總結(jié)

基于上述庫(kù)的固定隨機(jī)數(shù)方法總結(jié)為:

def set_random_seed(seed: int) -> None:
	random.seed(seed)
	os.environ['PYTHONHASHSEED'] = str(seed)
	np.random.seed(seed)
	torch.manual_seed(seed)
	torch.cuda.manual_seed_all(seed)
	torch.backends.cudnn.benchmark = False
	torch.backends.cudnn.deterministic = True

seed = 114514
set_torch_seed(seed)

如果在實(shí)踐中還調(diào)用了其他涉及隨機(jī)性的第三方庫(kù),則需要根據(jù)上述思路對(duì)該固定隨機(jī)數(shù)方法進(jìn)行動(dòng)態(tài)補(bǔ)充。

以上就是Pytorch固定隨機(jī)數(shù)種子的方法小結(jié)的詳細(xì)內(nèi)容,更多關(guān)于Pytorch固定隨機(jī)數(shù)種子的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

最新評(píng)論