PyTorch提供的autograd包能够根据输入和前向传播过程自动构建计算图,并执行反向传播(链式求导)。
在创建tensor时若将其属性.requires_grad设置为True,它将开始track在其上的所有操作,这样就可以利用链式法则进行梯度传播。完成前向计算后,调用.backward()来完成所有梯度计算。Tensor的梯度将累积到.grad属性中。
- 若不想tensor被继续track,可调用.detach()将其从追踪记录中分离出去
- 若不想tensor被继续track, 也可用with torch.no_grad()将不想被track的操作代码块包裹起来
with no_grad
设置tensor的requires_grad属性值
import torch
x = torch.ones(2, 2, requires_grad=True)
print(x)
print(x.grad_fn)
"""
tensor([[1., 1.],
[1., 1.]], requires_grad=True)
None
# 说明:x是直接创建的,所以它没有grad_fn
"""
y = x + 2
print(y)
print(y.grad_fn)
"""
tensor([[3., 3.],
[3., 3.]], grad_fn=<AddBackward0>)
<AddBackward0 object at 0x00000122034832E8>
# 说明:y是通过⼀个加法操作创建的,所以它有⼀个为<AddBackward> 的 grad_fn 。
"""
如上面的例子中x是直接创建的,也称为叶子节点,叶子节点对应的grad_fn是None
print(x.is_leaf, y.is_leaf) # True, False
- 创建tensor时默认requires_grad为False
- 可通过.requires_grad_()来用inplace方式改变requires_grad属性
import torch
a = torch.randn(2, 2)
print(a.requires_grad) # False
a.requires_grad_(True)
print(a.requires_grad) # True
b = (a*a).sum()
print(b.grad_fn) # <SumBackward0 object at 0x00000122032D3940>
梯度
out公式
out梯度
import torch
x = torch.ones(2, 2, requires_grad=True)
y = x + 2
z = y * y * 3
out =z.mean()
print(z, out)
# 求梯度 backward()
out.backward()
# 查看梯度
print(x.grad)
"""
tensor([[27., 27.],
[27., 27.]], grad_fn=<MulBackward0>) tensor(27., grad_fn=<MeanBackward0>)
tensor([[4.5000, 4.5000],
[4.5000, 4.5000]])
"""
# grad在反向传播过程中是累加的(accumulated)
out2 = x.sum()
out2.backward()
print(x.grad)
"""
tensor([[5.5000, 5.5000],
[5.5000, 5.5000]])
"""
# 反向传播之前需把梯度清零
out3 = x.sum()
x.grad.data.zero_()
out3.backward()
print(x.grad)
"""
tensor([[1., 1.],
[1., 1.]])
"""
grad在反向传播过程中是累加的(accumulated),这意味着每⼀次运⾏行行反向传播,梯度都会累加之前的梯度,所以⼀般在反向传播之前需把梯度清零。
中断梯度track的例子
import torch
x = torch.tensor(1.0, requires_grad=True)
y1 = x ** 2
# y2没有被track
with torch.no_grad():
y2 = x ** 3
# 只有y1被track,所以梯度为2
y3 = y1 + y2
print(x.requires_grad) # True
print(y1, y1.requires_grad) # tensor(1., grad_fn=<PowBackward0>) True
print(y2, y2.requires_grad) # tensor(1.) False
print(y3, y3.requires_grad) # tensor(2.grad_fn=<AddBackward0>) True










网友评论