美文网首页
torch.optim优化器

torch.optim优化器

作者: 星光下的胖子 | 来源:发表于2020-05-21 19:57 被阅读0次

使用流程:

# torch.optim优化器的使用
'''
步骤:
1、自定义神经网络模型
    1)初始化模型参数
    2)重载前向传播forward()
2、定义优化器,并使用优化器实现参数优化
    1)创建优化器实例(有多种优化器类型,常用SGD和Adam)
    2)优化参数:
        1.清空梯度
        2.前向传播
        3.计算loss
        4.反向传播(根据loss来计算梯度)
        5.参数更新(根据梯度来更新参数)
'''
import torch
from torch import optim
import matplotlib.pyplot as plt

# 1、自定义模型
class Line(torch.nn.Module):
    def __init__(self):
        super(Line, self).__init__()
        # 初始化模型参数,并设置为Parameter类型
        self.w = torch.nn.Parameter(torch.rand(1))
        self.b = torch.nn.Parameter(torch.rand(1))

    # 重载前向传播:数据输入到模型中,返回结果
    def forward(self, x):
        return self.w * x + self.b


if __name__ == '__main__':
    # 构建数据
    xs = torch.arange(0, 1, 0.01)
    ys = 3 * xs + 4 + 0.01 * torch.rand(100)

    # 模型实例化
    line = Line()

    # 2、定义优化器
    # SGD继承Optimizer
    opt = optim.SGD(line.parameters(), lr=0.1)
    # opt = optim.Adam(line.parameters(), lr=0.1)

    # 模型训练
    plt.ion()
    for epoch in range(30):
        for x, y in zip(xs, ys):
            # 梯度清空
            # 优化器optim的梯度是被累积的而不是被替换掉,每次先调用zero_grad()清空梯度
            opt.zero_grad()
            # 前向传播(父类Module的__call__方法调用forward()方法)
            h = line(x)
            # 计算loss
            loss = (h - y) ** 2
            # 反向传播:计算当前张量loss的梯度,并保存到buffer中
            # loss必须是torch.tensor(张量)数据类型
            loss.backward()
            # 参数更新:根据梯度来更新参数
            # step()函数从buffer中获取梯度
            opt.step()

        # plt绘图
        print(line.w.item(), line.b.item(), loss.item())
        # plt.cla()
        plt.plot(xs, ys, ".")
        plt.plot(xs, [line.w * x + line.b for x in xs])
        plt.pause(0.01)
    plt.ioff()
    plt.show()

拟合过程可视化:


拟合过程

相关文章

  • 优化器

    优化器(optim) 优化算法模块(torch.optim) torch.optim 实现了丰富的优化算法,包括S...

  • 优化器

    torch.optim : torch.optim是实现各种优化算法的包 要使用torch.optim,必须构造一...

  • torch.optim优化器

    使用流程: 拟合过程可视化:

  • pytorch中的optim

    torch.optim torch.optim是一个实现了各种优化算法的库。大部分常用的方法得到支持,并且接口具备...

  • Pytorch学习笔记(8) Pytorch 常用的优化算法

    Pytorch中如何使用优化方法。torch.optim是一个实现了各种优化算法的库。大部分常用的方法得到支持,并...

  • PyTorch 优化

     torch.optim is a package implementing various optimizati...

  • crf

    import torch import torch.nn as nn import torch.optim as ...

  • 性能优化

    内容优化 服务器优化 Cookie优化 CSS优化 javascript优化 图像优化

  • LLVM

    一、编译器 性能优化:启动优化、界面优化、架构优化 编译型语言:OC(编译器是clang)、C(编译器可以直接执行...

  • 读《数据库索引设计与优化》以及相关知识

    一.优化器的优化方式 Oracle的优化器共有两种的优化方式,即基于规则的优化方式(Rule-Based Opti...

网友评论

      本文标题:torch.optim优化器

      本文链接:https://www.haomeiwen.com/subject/dfgtahtx.html