关键词搜索

源码搜索 ×
×

pytorch 中 torch.no_grad()、requires_grad、eval()

发布2021-09-23浏览1198次

详情内容

辅助视频教程:Python基础教程|xin3721自学网ul li id=itemtitlePython3 从入门到精通视频教程/li /ul ul li class=description Python是一种跨平台的计算机程序设计语言。是一种面向对象的动态类型语言,最初被设计用于编写自动化脚本(shell),icon-default.png?t=L892https://www.xin3721.com/eschool/pythonxin3721/

requires_grad

requires_grad=True 要求计算梯度;
requires_grad=False 不要求计算梯度;
pytorch中,tensor有一个 requires_grad参数,如果设置为True,则反向传播时,该tensor就会自动求导。 tensor的requires_grad的属性默认为False,若一个节点(叶子变量:自己创建的tensor)requires_grad被设置为True,那么 所有依赖它的节点requires_grad都为True (即使其他相依赖的tensor的requires_grad = False)

  1. x = torch.randn(10, 5, requires_grad = True)
  2. y = torch.randn(10, 5, requires_grad = False)
  3. z = torch.randn(10, 5, requires_grad = False)
  4. w = x + y + z
  5. w.requires_grad

输出:

True

volatile

volatile是Variable的另一个重要的标识,它能够将所有依赖它的节点全部设为volatile=True,优先级比requires_grad=True高。
而volatile=True的节点不会求导,即使requires_grad=True,也不会进行反向传播,对于不需要反向传播的情景(inference,测试阶段推断阶段),该参数可以实现一定速度的提升,并节省一半的显存,因为其不需要保存梯度。
但是, 注意 volatile已经取消了,使用with torch.no_grad()来替代 。

torch.no_grad()

是一个上下文管理器,被该语句 wrap 起来的部分将不会track 梯度。
with torch.no_grad()或者@torch.no_grad()中的数据不需要计算梯度,也不会进行反向传播。
(torch.no_grad()是新版本pytorch中volatile的替代)

  1. x = torch.randn(2, 3, requires_grad = True)
  2. y = torch.randn(2, 3, requires_grad = False)
  3. z = torch.randn(2, 3, requires_grad = False)
  4. m=x+y+z
  5. with torch.no_grad():
  6. w = x + y + z
  7. print(w)
  8. print(m)
  9. print(w.requires_grad)
  10. print(w.grad_fn)
  11. print(w.requires_grad)

输出:

  1. tensor([[-2.7066, -0.7406, 0.5740],
  2. [-0.7071, -1.6057, 1.9732]])
  3. tensor([[-2.7066, -0.7406, 0.5740],
  4. [-0.7071, -1.6057, 1.9732]], grad_fn=<AddBackward0>)
  5. False
  6. None
  7. False

model.eval()与with torch.no_grad()

共同点:

在PyTorch中进行validation时,使用这两者均可切换到测试模式。

如用于通知dropout层和batchnorm层在train和val模式间切换。
在train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); batchnorm层会继续计算数据的mean和var等参数并更新。
在val模式下,dropout层会让所有的激活单元都通过,而batchnorm层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。

不同点:

model.eval()会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反传。

with torch.zero_grad()则停止autograd模块的工作,也就是停止gradient计算,以起到加速和节省显存的作用,从而节省了GPU算力和显存,但是并不会影响dropout和batchnorm层的行为。

也就是说,如果不在意显存大小和计算时间的话,仅使用model.eval()已足够得到正确的validation的结果;而with torch.zero_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储gradient),从而可以更快计算,也可以跑更大的batch来测试。

 

相关技术文章

点击QQ咨询
开通会员
返回顶部
×
微信扫码支付
微信扫码支付
确定支付下载
请使用微信描二维码支付
×

提示信息

×

选择支付方式

  • 微信支付
  • 支付宝付款
确定支付下载