PyTorch 如何設(shè)置隨機數(shù)種子使結(jié)果可復現(xiàn)
由于在模型訓練的過程中存在大量的隨機操作,使得對于同一份代碼,重復運行后得到的結(jié)果不一致。
因此,為了得到可重復的實驗結(jié)果,我們需要對隨機數(shù)生成器設(shè)置一個固定的種子。
CUDNN
cudnn中對卷積操作進行了優(yōu)化,犧牲了精度來換取計算效率。如果需要保證可重復性,可以使用如下設(shè)置:
from torch.backends import cudnn cudnn.benchmark = False # if benchmark=True, deterministic will be False cudnn.deterministic = True
不過實際上這個設(shè)置對精度影響不大,僅僅是小數(shù)點后幾位的差別。所以如果不是對精度要求極高,其實不太建議修改,因為會使計算效率降低。
Pytorch
torch.manual_seed(seed) # 為CPU設(shè)置隨機種子 torch.cuda.manual_seed(seed) # 為當前GPU設(shè)置隨機種子 torch.cuda.manual_seed_all(seed) # 為所有GPU設(shè)置隨機種子
Python & Numpy
如果讀取數(shù)據(jù)的過程采用了隨機預處理(如RandomCrop、RandomHorizontalFlip等),那么對python、numpy的隨機數(shù)生成器也需要設(shè)置種子。
import random import numpy as np random.seed(seed) np.random.seed(seed)
Dataloader
如果dataloader采用了多線程(num_workers > 1), 那么由于讀取數(shù)據(jù)的順序不同,最終運行結(jié)果也會有差異。
也就是說,改變num_workers參數(shù),也會對實驗結(jié)果產(chǎn)生影響。
目前暫時沒有發(fā)現(xiàn)解決這個問題的方法,但是只要固定num_workers數(shù)目(線程數(shù))不變,基本上也能夠重復實驗結(jié)果。
補充:pytorch 固定隨機數(shù)種子踩過的坑
1.初步固定
def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed(seed) np.random.seed(seed) random.seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.enabled = False torch.backends.cudnn.benchmark = False #torch.backends.cudnn.benchmark = True #for accelerating the running setup_seed(2019)
2.繼續(xù)添加如下代碼:
tensor_dataset = ImageList(opt.training_list,transform) def _init_fn(worker_id): random.seed(10 + worker_id) np.random.seed(10 + worker_id) torch.manual_seed(10 + worker_id) torch.cuda.manual_seed(10 + worker_id) torch.cuda.manual_seed_all(10 + worker_id) dataloader = DataLoader(tensor_dataset, batch_size=opt.batchSize, shuffle=True, num_workers=opt.workers, worker_init_fn=_init_fn)
3.在上面的操作之后發(fā)現(xiàn)加載的數(shù)據(jù)多次試驗大部分一致了
但是仍然有些數(shù)據(jù)是不一致的,后來發(fā)現(xiàn)是pytorch版本的問題,將原先的0.3.1版本升級到1.1.0版本,問題解決
4.按照上面的操作后雖然解決了問題
但是由于將cudnn.benchmark設(shè)置為False,運行速度降低到原來的1/3,所以繼續(xù)探索,最終解決方案是把第1步變?yōu)槿缦拢瑫r將該部分代碼盡可能放在主程序最開始的部分,例如:
import torch import torch.nn as nn from torch.nn import init import pdb import torch.nn.parallel import torch.nn.functional as F import torch.backends.cudnn as cudnn import torch.optim as optim import torch.utils.data from torch.utils.data import DataLoader, Dataset import sys gpu_id = "3,2" os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id print('GPU: ',gpu_id) def setup_seed(seed): torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed(seed) np.random.seed(seed) random.seed(seed) cudnn.deterministic = True #cudnn.benchmark = False #cudnn.enabled = False setup_seed(2019)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。如有錯誤或未考慮完全的地方,望不吝賜教。
相關(guān)文章
Python OpenCV高斯金字塔與拉普拉斯金字塔的實現(xiàn)
這篇文章主要介紹了Python OpenCV高斯金字塔與拉普拉斯金字塔的實現(xiàn),文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2021-03-03用python打包exe應(yīng)用程序及PyInstaller安裝方式
PyInstaller 制作出來的執(zhí)行文件并不是跨平臺的,如果需要為不同平臺打包,就要在相應(yīng)平臺上運行PyInstaller進行打包。今天通過本文給大家介紹用python打包exe應(yīng)用程序及PyInstaller安裝方式,感興趣的朋友一起看看吧2021-12-12Python多線程編程threading模塊使用最佳實踐及常見問題解析
這篇文章主要為大家介紹了Python多線程編程threading模塊使用最佳實踐及常見問題解析,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪2024-01-01Python中處理表格數(shù)據(jù)的Tablib庫詳解
這篇文章主要介紹了Python中處理表格數(shù)據(jù)的Tablib庫詳解,Tablib 是一個 MIT 許可的格式不可知的表格數(shù)據(jù)集庫,用 Python 編寫,它允許您導入、導出和操作表格數(shù)據(jù)集,需要的朋友可以參考下2023-08-08