≡
  • 网络编程
  • 数据库
  • CMS技巧
  • 软件编程
  • PHP笔记
  • JavaScript
  • MySQL
位置:首页 > 网络编程 > Python

Python之PyTorch 设置随机数种子使结果可复现的简单示例

人气:360 时间:2021-06-07

这篇文章主要为大家详细介绍了Python之PyTorch 设置随机数种子使结果可复现的简单示例,具有一定的参考价值,可以用来参考一下。

感兴趣的小伙伴,下面一起跟随四海网的雯雯来看看吧!

由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。

因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。

CUDNN

cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:

代码如下:


from torch.backends import cudnn
cudnn.benchmark = False            # if benchmark=True, deterministic will be False
cudnn.deterministic = True

PyTorch 如何设置随机数种子使结果可复现

不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。

Pytorch

代码如下:


torch.manual_seed(seed)            # 为CPU设置随机种子
torch.cuda.manual_seed(seed)       # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed)   # 为所有GPU设置随机种子

PyTorch 如何设置随机数种子使结果可复现

Python & Numpy

如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。

代码如下:


import random
import numpy as np
random.seed(seed)
np.random.seed(seed)

PyTorch 如何设置随机数种子使结果可复现

Dataloader

如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。

也就是说,改变num_workers参数,也会对实验结果产生影响。

目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。

 

补充:pytorch 固定随机数种子踩过的坑

 

1.初步固定

代码如下:


 def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.enabled = False
     torch.backends.cudnn.benchmark = False
     #torch.backends.cudnn.benchmark = True #for accelerating the running
 setup_seed(2019)

PyTorch 如何设置随机数种子使结果可复现

2.继续添加如下代码:

代码如下:


tensor_dataset = ImageList(opt.training_list,transform)
def _init_fn(worker_id): 
    random.seed(10 + worker_id)
    np.random.seed(10 + worker_id)
    torch.manual_seed(10 + worker_id)
    torch.cuda.manual_seed(10 + worker_id)
    torch.cuda.manual_seed_all(10 + worker_id)
dataloader = DataLoader(tensor_dataset,                        
                    batch_size=opt.batchSize,     
                    shuffle=True,     
                    num_workers=opt.workers,
                    worker_init_fn=_init_fn)

PyTorch 如何设置随机数种子使结果可复现

3.在上面的操作之后发现加载的数据多次试验大部分一致了

但是仍然有些数据是不一致的,后来发现是pytorch版本的问题,将原先的0.3.1版本升级到1.1.0版本,问题解决

4.按照上面的操作后虽然解决了问题

但是由于将cudnn.benchmark设置为False,运行速度降低到原来的1/3,所以继续探索,最终解决方案是把第1步变为如下,同时将该部分代码尽可能放在主程序最开始的部分,例如:

代码如下:


import torch
import torch.nn as nn
from torch.nn import init
import pdb
import torch.nn.parallel
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.utils.data import DataLoader, Dataset
import sys
gpu_id = "3,2"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
print('GPU: ',gpu_id)
def setup_seed(seed):
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     torch.cuda.manual_seed(seed)
     np.random.seed(seed)
     random.seed(seed)
     cudnn.deterministic = True
     #cudnn.benchmark = False
     #cudnn.enabled = False

setup_seed(2019)

 

以上为个人经验,希望能给大家一个参考,也希望大家多多支持四海网。如有错误或未考虑完全的地方,望不吝赐教。

本文来自:http://www.q1010.com/181/18819-0.html

注:关于Python之PyTorch 设置随机数种子使结果可复现的简单示例的内容就先介绍到这里,更多相关文章的可以留意四海网的其他信息。

关键词:python

您可能感兴趣的文章

  • Python之使用Tensorflow训练BP神经网络实现鸢尾花分类的简单示例
  • Python之Django如何创作一个简单的最小程序的简单示例
  • 解决Python之pytorch损失反向传播后梯度为none的问题
  • Python之Django中Pyecharts生成图表的简单示例
  • 解决python2中unicode()函数在python3中报错的问题
  • Python之pytorch MSELoss计算平均的简单示例
  • Python之对B站收藏夹按照视频发布时间进行排序的问题
  • Python之将带小数的浮点型字符串转换为整数的简单示例
  • Python之绘制棒棒糖图表的简单示例
  • Python异步爬虫实现原理与知识总结的简单示例
上一篇:Python之Pytorch中TensorBoard及torchsummary用法示例
下一篇:解决Python之pytorch损失反向传播后梯度为none的问题
热门文章
  • Python 处理Cookie的菜鸟教程(一)Cookie库
  • python之pandas取dataframe特定行列的简单示例
  • Python解决json.dumps错误::‘utf8’ codec can‘t decode byte
  • Python通过pythony连接Hive执行Hql的脚本
  • Python 三种方法删除列表中重复元素的简单示例
  • python爬虫代码示例
  • Python 中英文标点转换示例
  • Python 不得不知的开源项目解析
  • Python urlencode编码和url拼接实现方法
  • python按中文拆分中英文混合字符串的简单示例
  • 最新文章
    • Python利用numpy三层神经网络的简单示例
    • pygame可视化幸运大转盘的简单示例
    • Python爬虫之爬取二手房信息的简单示例
    • Python之time库的简单示例
    • OpenCV灰度、高斯模糊、边缘检测的简单示例
    • Python安装Bs4及使用的简单示例
    • django自定义manage.py管理命令的简单示例
    • Python之matplotlib 向任意位置添加一个子图(axes)的简单示例
    • Python图像标签标注软件labelme分析的简单示例
    • python调用摄像头并拍照发邮箱的简单示例

四海网收集整理一些常用的php代码,JS代码,数据库mysql等技术文章。