美文网首页
pytorch 拆分已定义的网络结构(slicing netwo

pytorch 拆分已定义的网络结构(slicing netwo

作者: Zeke_Wang | 来源:发表于2019-03-24 20:17 被阅读0次

在pytorch中,常会load已有模型甚至pretrained的模型,用其中几层作为特征提取(feature extraction)。比如用pytorch内置的pretrained ResNet作为特征提取器,需要把fully connected layer去掉。可以用children()方法提出需要的层

import torch.nn as nn
from torchvision import models

model = models.resnet50(pretrained=True)
truncated_model = nn.Sequential(*list(model.children())[:8])
print(truncated_model)

truncted_model可作为feature extractor,需要注意输入输出大小即可。
PS: *list可以达到以下效果

l = ["./foo", "bar", "quux"]

funcXXX(*l)
# 等价于
funcXXX("./foo", "bar", "quux")

也即是,iterate 提取list中的内容,并以逗号分隔。满足nn.Sequential()的输入条件

相关文章

网友评论

      本文标题:pytorch 拆分已定义的网络结构(slicing netwo

      本文链接:https://www.haomeiwen.com/subject/swhpvqtx.html