美文网首页
PyTorch中如何加载子模块的权重

PyTorch中如何加载子模块的权重

作者: WritingHere | 来源:发表于2021-01-02 19:28 被阅读0次

假设我们在深度学习模型中有一个这样的需求:主要模型A中包含子模块B,而模型B可以通过一定的方式得到一个预训练的权重,模型A需要利用B模型的权重,在此基础上继续训练。
首先我们到官网上去寻找,PyTorch官网上给出了一些保存和加载模型的示例,可以说非常全面总结了模型保存和加载的方法和主义事项,https://pytorch.org/tutorials/beginner/saving_loading_models.html。但是这里的方案都是针对一个完整模型的保存和加载的,不能满足我们这个需求。
因此需要基于此做一些改进,具体如代码所示:

import torch
import torch.nn as nn

class ModelA(nn.Module):
    def __init__(self):
        super(ModelA, self).__init__()
        self.A = nn.Linear(2, 3)
    
    def forward(self, A):
        pass


class ModelB(nn.Module):
    def __init__(self):
        super(ModelB, self).__init__()
        self.model_a = ModelA()
        self.A = nn.Linear(2, 3)
    
    def forward(self, x):
        pass

print("Model")
modelA = ModelA()
modelA_dict = modelA.state_dict()
print('-' * 80)
for key in sorted(modelA_dict.keys()):
    parameter = modelA_dict[key]
    print(key)
    print(parameter.size())
    print(parameter)
modelB = ModelB()
modelB_dict = modelB.state_dict()
print('-'*80)
for key in sorted(modelB_dict.keys()):
    print('-'*20)
    parameter = modelB_dict[key]
    print(type(key), key)
    print(parameter.size())
    print(parameter)
    print('-'*20)
print('-'*80)

pretrained_dict = modelA_dict
model_dict = modelB_dict

pretrained_dict = {'model_a.' + k: v for k, v in pretrained_dict.items() if 'model_a.' + k in model_dict}
model_dict.update(pretrained_dict)

modelB.load_state_dict(model_dict)
modelB_dict = modelB.state_dict()
for key in sorted(modelB_dict.keys()):
    parameter = modelB_dict[key]
    print(key)
    print(parameter.size())
    print(parameter)

相关文章

  • PyTorch中如何加载子模块的权重

    假设我们在深度学习模型中有一个这样的需求:主要模型A中包含子模块B,而模型B可以通过一定的方式得到一个预训练的权重...

  • Pytorch: 加载部分权重

    直接加载所有权重 适用于直接使用别人的模型和权重 加载部分权重 适用于对别人的模型做了适当的修改

  • node.js中的require,exports使用说明

    模块是一门语言编写大项目的基石,因此,了解如何组织、编写、编译、加载模块很重要。这里主要谈谈Node中的模块加载。...

  • Pytorch中transform的常见用法

    transforms模块详解 transforms是torchvision中的一个重要模块,它是Pytorch的图...

  • 关于ngx lua require路径

    在nginx中可以直接设定模块的加载路径 在lua中也可以设定模块的加载路径--这是加载

  • PyTorch如何恢复指定权重

    1. 如何从已训练好的网络模型中提取指定层权重 2. 如何加载模型部分参数并更新 可以发现classifier.2...

  • Node.js 核心模块概述

    模块加载原理与加载方式 Node 中的模块:核心模块/原生模块:Node提供的模块。文件模块:用户编写的模块。 N...

  • 如何在Pytorch中正确设计并加载数据集

    本教程属于Pytorch基础教学的一部分 ————《如何在Pytorch中正确设计并加载数据集》 教程所适合的Py...

  • Webpack模块异步加载

    按需加载 1 .当执行到另一个文件中的模块时,通过网络请求即时加载对应的异步模块代码,再继续下面的流程 如何判断哪...

  • pytorch之保存与加载模型

    pytorch之保存与加载模型 本篇笔记译自pytorch官网tutorial,用于方便查看。pytorch与保存...

网友评论

      本文标题:PyTorch中如何加载子模块的权重

      本文链接:https://www.haomeiwen.com/subject/yxtyoktx.html