美文网首页
pytorch finetune模型

pytorch finetune模型

作者: jiangwenj02 | 来源:发表于2017-11-28 14:42 被阅读0次

pytorch finetune模型

文章主要讲述如何在pytorch上读取以往训练的模型参数,在模型的名字已经变更的情况下又如何读取模型的部分参数等。
                                                                                       --------作者:jiangwenj02【转载请注明】


pytorch 模型的存储与读取

其中在模型的保存过程有存储模型和参数一起的也有单独存储模型参数的

单独存储模型参数

存储时使用:

torch.save(the_model.state_dict(), PATH)

读取时:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

存储模型与参数

存储:

torch.save(the_model, PATH)

读取:

the_model = torch.load(PATH)

模型的参数

fine-tune的过程是读取原有模型的参数,但是由于模型的所要处理的数据集不同,最后的一层class的总数不同,所以需要修改模型的最后一层,这样模型读取的参数,和在大数据集上训练好下载的模型参数在形式上不一样。需要我们自己去写函数读取参数。

pytorch模型参数的形式

模型的参数是以字典的形式存储的。

model_dict = the_model.state_dict(),
for k,v in model_dict.items():
    print(k)

即可看到所有的键值
如果想修改模型的参数,给相应的键值赋值即可

model_dict[k] = new_value

最后更新模型的参数

the_model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是一样的

我们可以通过下列算法进行读取模型

model_dict = model.state_dict()

pretrained_dict = torch.load(model_path)
 # 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
            k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的

model_dict = model.state_dict()

pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
    keys.append(k)
i = 0
for k,v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
        print(k, ',', keys[i])
         model_dict[k]=pretrained_dict[keys[i]]
    i = i + 1
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的

自己找对应关系,一个key对应一个key的赋值

相关文章

  • pytorch finetune模型

    pytorch finetune模型 文章主要讲述如何在pytorch上读取以往训练的模型参数,在模型的名字已经变...

  • Pytorch预训练模型finetune

    这一块实在是因为之前没有过pytorch的经验,从0开始一步一步摸滚打爬。而且发现自己手总是处于闲置状态实在不好,...

  • Finetune预训练模型

    1、PyTorch学习笔记(1)-finetune网络的一些注意事项https://blog.csdn.net/u...

  • CV-字符识别模型

    Pytorch构建CNN模型 Pytorch中构建CNN模型只需要定义好模型的参数和正向传播就可以,Pytorch...

  • Pytorch冻结部分层的参数

    在迁移学习finetune时我们通常需要冻结前几层的参数不参与训练,在Pytorch中的实现如下: 假如我们想要冻...

  • pytorch转caffe2 之 onnx转caffe2报错的解

    目标:将 pytorch模型 转为 onnx模型 再转为 caffe2模型,得到两个.pb文件 pytorch转o...

  • UIE实体关系抽取解读

    通过UIE默认抽取关系 通过预训练模型直接抽取,数据没有返回。 先看下通过finetune预训练模型后的结果如下:...

  • PyTorch模型保存深入理解

    前面写过一篇PyTorch保存模型的文章:Pytorch模型保存与加载,并在加载的模型基础上继续训练 ,简单介绍了...

  • pytorch模型转keras模型

    1. 概述 使用pytorch建立的模型,有时想把pytorch建立好的模型装换为keras,本人使用Tensor...

  • pytorch之保存与加载模型

    pytorch之保存与加载模型 本篇笔记译自pytorch官网tutorial,用于方便查看。pytorch与保存...

网友评论

      本文标题:pytorch finetune模型

      本文链接:https://www.haomeiwen.com/subject/rntebxtx.html