SSAST Model Structure

Author: 小草_a484 | Published: 2025-05-24 21:39

Model architecture analysis

The first step in pruning a model is to understand its structure.

  • First, list the names of all non-bias parameters with the following code:
for name, _ in audio_model.named_parameters():
    if 'bias' not in name:
        print(name)
module.v.cls_token
module.v.pos_embed
module.v.dist_token
module.v.patch_embed.proj.weight
module.v.blocks.0.norm1.weight
module.v.blocks.0.attn.qkv.weight
module.v.blocks.0.attn.proj.weight
module.v.blocks.0.norm2.weight
module.v.blocks.0.mlp.fc1.weight
module.v.blocks.0.mlp.fc2.weight
module.v.blocks.1.norm1.weight
module.v.blocks.1.attn.qkv.weight
module.v.blocks.1.attn.proj.weight
module.v.blocks.1.norm2.weight
module.v.blocks.1.mlp.fc1.weight
module.v.blocks.1.mlp.fc2.weight
module.v.blocks.2.norm1.weight
module.v.blocks.2.attn.qkv.weight
module.v.blocks.2.attn.proj.weight
module.v.blocks.2.norm2.weight
module.v.blocks.2.mlp.fc1.weight
module.v.blocks.2.mlp.fc2.weight
module.v.blocks.3.norm1.weight
module.v.blocks.3.attn.qkv.weight
module.v.blocks.3.attn.proj.weight
module.v.blocks.3.norm2.weight
module.v.blocks.3.mlp.fc1.weight
module.v.blocks.3.mlp.fc2.weight
module.v.blocks.4.norm1.weight
module.v.blocks.4.attn.qkv.weight
module.v.blocks.4.attn.proj.weight
module.v.blocks.4.norm2.weight
module.v.blocks.4.mlp.fc1.weight
module.v.blocks.4.mlp.fc2.weight
module.v.blocks.5.norm1.weight
module.v.blocks.5.attn.qkv.weight
module.v.blocks.5.attn.proj.weight
module.v.blocks.5.norm2.weight
module.v.blocks.5.mlp.fc1.weight
module.v.blocks.5.mlp.fc2.weight
module.v.blocks.6.norm1.weight
module.v.blocks.6.attn.qkv.weight
module.v.blocks.6.attn.proj.weight
module.v.blocks.6.norm2.weight
module.v.blocks.6.mlp.fc1.weight
module.v.blocks.6.mlp.fc2.weight
module.v.blocks.7.norm1.weight
module.v.blocks.7.attn.qkv.weight
module.v.blocks.7.attn.proj.weight
module.v.blocks.7.norm2.weight
module.v.blocks.7.mlp.fc1.weight
module.v.blocks.7.mlp.fc2.weight
module.v.blocks.8.norm1.weight
module.v.blocks.8.attn.qkv.weight
module.v.blocks.8.attn.proj.weight
module.v.blocks.8.norm2.weight
module.v.blocks.8.mlp.fc1.weight
module.v.blocks.8.mlp.fc2.weight
module.v.blocks.9.norm1.weight
module.v.blocks.9.attn.qkv.weight
module.v.blocks.9.attn.proj.weight
module.v.blocks.9.norm2.weight
module.v.blocks.9.mlp.fc1.weight
module.v.blocks.9.mlp.fc2.weight
module.v.blocks.10.norm1.weight
module.v.blocks.10.attn.qkv.weight
module.v.blocks.10.attn.proj.weight
module.v.blocks.10.norm2.weight
module.v.blocks.10.mlp.fc1.weight
module.v.blocks.10.mlp.fc2.weight
module.v.blocks.11.norm1.weight
module.v.blocks.11.attn.qkv.weight
module.v.blocks.11.attn.proj.weight
module.v.blocks.11.norm2.weight
module.v.blocks.11.mlp.fc1.weight
module.v.blocks.11.mlp.fc2.weight
module.v.norm.weight
module.v.head.weight
module.v.head_dist.weight
module.mlp_head.0.weight
module.mlp_head.1.weight

The listing shows that the model is a ViT backbone (module.v) followed by an MLP head (module.mlp_head).
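When planning what to prune, it also helps to see how large each tensor is, not just its name. The following is a minimal sketch under the assumption that a small stand-in model is good enough to show the idea; `model` here is hypothetical and would be replaced by `audio_model` in practice.

```python
import torch.nn as nn

# Hypothetical stand-in model; substitute audio_model when running on SSAST.
model = nn.Sequential(
    nn.LayerNorm(8),
    nn.Linear(8, 4),
)

rows = []
for name, param in model.named_parameters():
    if 'bias' in name:
        continue  # same filter as the name listing: skip bias tensors
    rows.append((name, tuple(param.shape), param.numel()))

# One line per tensor: name, shape, number of elements
for name, shape, numel in rows:
    print(f"{name:20s} {str(shape):12s} {numel}")
```

The element count per tensor makes it easy to see that the qkv and mlp weights dominate a transformer block, which is usually where pruning pays off.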

  • Next, list the top-level child modules:
for name,layer in audio_model.named_children():
    print(f"name:{name},type:{type(layer)}")
name:module,type:<class 'models.ast_models.ASTModel'>

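The only top-level child is named module, which shows that the model is wrapped in DataParallel; the actual ASTModel sits one level below.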
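Because of the wrapper, every access has to go through `.module`. A small helper can hide that detail. This is a sketch, assuming the wrapper is `torch.nn.DataParallel`; the name `unwrap` and the `nn.Linear` stand-in are hypothetical.

```python
import torch.nn as nn

def unwrap(model: nn.Module) -> nn.Module:
    """Return the underlying model if it is wrapped in DataParallel."""
    if isinstance(model, nn.DataParallel):
        return model.module
    return model

inner = nn.Linear(4, 2)           # hypothetical stand-in for ASTModel
wrapped = nn.DataParallel(inner)  # adds the extra 'module' level

# The wrapper has exactly one child, named 'module'
child_names = [name for name, _ in wrapped.named_children()]
print(child_names, unwrap(wrapped) is inner)
```

With such a helper, the inspection and pruning code can be written once and used on both wrapped and unwrapped checkpoints.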

  • Go one level down into the wrapped model:
for name,layer in audio_model.module.named_children():
    print(f"name:{name},type:{type(layer)}")
name:v,type:<class 'timm.models.vision_transformer.DistilledVisionTransformer'>
name:mlp_head,type:<class 'torch.nn.modules.container.Sequential'>

This matches the parameter listing: a ViT (v) and a Sequential container (mlp_head).

  • Next, look inside the ViT:
for name,layer in audio_model.module.v.named_children():
    print(f"name:{name},type:{type(layer)}")
name:patch_embed,type:<class 'models.ast_models.PatchEmbed'>
name:pos_drop,type:<class 'torch.nn.modules.dropout.Dropout'>
name:blocks,type:<class 'torch.nn.modules.container.ModuleList'>
name:norm,type:<class 'torch.nn.modules.normalization.LayerNorm'>
name:pre_logits,type:<class 'torch.nn.modules.linear.Identity'>
name:head,type:<class 'torch.nn.modules.linear.Linear'>
name:head_dist,type:<class 'torch.nn.modules.linear.Linear'>

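Apart from patch_embed, which is a customized embedding block (models.ast_models.PatchEmbed), the direct children are all standard torch.nn modules.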

name:proj,type:<class 'torch.nn.modules.conv.Conv2d'>

The patch_embed block itself contains only a single convolution layer, proj.

  • Continue by iterating over the ModuleList:
for name,layer in audio_model.module.v.blocks.named_children():
    print(f"name:{name},type:{type(layer)}")
name:0,type:<class 'timm.models.vision_transformer.Block'>
name:1,type:<class 'timm.models.vision_transformer.Block'>
name:2,type:<class 'timm.models.vision_transformer.Block'>
name:3,type:<class 'timm.models.vision_transformer.Block'>
name:4,type:<class 'timm.models.vision_transformer.Block'>
name:5,type:<class 'timm.models.vision_transformer.Block'>
name:6,type:<class 'timm.models.vision_transformer.Block'>
name:7,type:<class 'timm.models.vision_transformer.Block'>
name:8,type:<class 'timm.models.vision_transformer.Block'>
name:9,type:<class 'timm.models.vision_transformer.Block'>
name:10,type:<class 'timm.models.vision_transformer.Block'>
name:11,type:<class 'timm.models.vision_transformer.Block'>
  • Then look at the internal structure of a single block (block 0 here):
for name,layer in audio_model.module.v.blocks[0].named_children():
    print(f"name:{name},type:{type(layer)}")
name:norm1,type:<class 'torch.nn.modules.normalization.LayerNorm'>
name:attn,type:<class 'timm.models.vision_transformer.Attention'>
name:drop_path,type:<class 'torch.nn.modules.linear.Identity'>
name:norm2,type:<class 'torch.nn.modules.normalization.LayerNorm'>
name:mlp,type:<class 'timm.models.vision_transformer.Mlp'>
  • Look inside the attention module:
for name,layer in audio_model.module.v.blocks[0].attn.named_children():
    print(f"name:{name},type:{type(layer)}")
name:qkv,type:<class 'torch.nn.modules.linear.Linear'>
name:attn_drop,type:<class 'torch.nn.modules.dropout.Dropout'>
name:proj,type:<class 'torch.nn.modules.linear.Linear'>
name:proj_drop,type:<class 'torch.nn.modules.dropout.Dropout'>
  • And inside the mlp at the same level:
for name,layer in audio_model.module.v.blocks[0].mlp.named_children():
    print(f"name:{name},type:{type(layer)}")
name:fc1,type:<class 'torch.nn.modules.linear.Linear'>
name:act,type:<class 'torch.nn.modules.activation.GELU'>
name:fc2,type:<class 'torch.nn.modules.linear.Linear'>
name:drop,type:<class 'torch.nn.modules.dropout.Dropout'>
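The level-by-level walk above can also be done in one pass with a small recursive helper. This is a sketch rather than the author's method; `module_tree` and the `demo` model are hypothetical names, and `max_depth` is an assumed knob to keep the output short.

```python
import torch.nn as nn

def module_tree(module: nn.Module, max_depth: int = 3, prefix: str = "", depth: int = 0):
    """Collect 'dotted.name: ClassName' lines for all children down to max_depth."""
    lines = []
    if depth >= max_depth:
        return lines
    for name, child in module.named_children():
        full = f"{prefix}.{name}" if prefix else name
        lines.append(f"{'  ' * depth}{full}: {type(child).__name__}")
        # Recurse into the child, indenting one level deeper
        lines.extend(module_tree(child, max_depth, full, depth + 1))
    return lines

# Hypothetical stand-in; on SSAST one would pass audio_model.module instead.
demo = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.GELU(), nn.Dropout(0.0)))
print("\n".join(module_tree(demo)))
```

The dotted names it prints are the same ones that appear in `named_parameters()`, so a layer found in the tree can be matched directly to its weights.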

Pruning with prune

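prune (torch.nn.utils.prune) is the pruning module that ships with PyTorch. Its methods fall into two groups: unstructured pruning and structured pruning.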

  1. Unstructured pruning
  • random_unstructured
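import torch.nn.utils.prune as prune

# Zero a random 50% of the weights, then make the pruning permanent
prune.random_unstructured(model.conv1, name="weight", amount=0.5)
prune.remove(model.conv1, 'weight')
print(model.conv1.weight)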
Parameter containing:
tensor([[[[-0.0000, -0.0000,  0.0000],
          [-0.1978, -0.2261, -0.0860],
          [ 0.1037,  0.0000,  0.0000]]],


        [[[-0.2053,  0.0410,  0.0000],
          [-0.0000, -0.0000, -0.0000],
          [-0.1972,  0.3187, -0.0000]]],


        [[[-0.0000, -0.0000, -0.1223],
          [ 0.1347,  0.3205,  0.0000],
          [-0.0682, -0.0000,  0.2382]]]], device='cuda:0', requires_grad=True)
  • l1_unstructured
prune.l1_unstructured(model.conv1,name="weight",amount=0.5)
prune.remove(model.conv1,'weight')
print(model.conv1.weight)
Parameter containing:
tensor([[[[-0.3320,  0.2106,  0.3279],
          [-0.3195,  0.3222,  0.0000],
          [-0.0000, -0.2038, -0.0000]]],


        [[[-0.0000,  0.3175,  0.0000],
          [-0.0000, -0.2928, -0.0000],
          [ 0.0000, -0.0000, -0.0000]]],


        [[[ 0.2634, -0.2414,  0.0000],
          [ 0.0000, -0.3128,  0.0000],
          [-0.2402, -0.2144,  0.0000]]]], device='cuda:0', requires_grad=True)
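To see what happens between the pruning call and `prune.remove`, the sketch below inspects a pruned layer directly. It assumes a freshly initialised `nn.Conv2d(1, 3, kernel_size=3)` as a stand-in for `model.conv1`, whose definition is not shown in this post.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(1, 3, kernel_size=3)   # hypothetical stand-in for model.conv1
original = conv.weight.detach().clone()

prune.l1_unstructured(conv, name="weight", amount=0.5)

# While pruning is active, the module stores the original values and a mask.
param_names = [n for n, _ in conv.named_parameters()]   # includes 'weight_orig'
buffer_names = [n for n, _ in conv.named_buffers()]     # includes 'weight_mask'
mask = conv.weight_mask.clone()

prune.remove(conv, "weight")   # bake the mask in; 'weight' is a plain Parameter again

# L1 pruning removes the smallest magnitudes: every kept weight is at least
# as large in absolute value as every removed one.
kept_min = original[mask.bool()].abs().min()
removed_max = original[~mask.bool()].abs().max()
print(param_names, buffer_names, float(kept_min), float(removed_max))
```

Until `prune.remove` is called, `weight` is recomputed as `weight_orig * weight_mask` on every forward pass, so the original values are still recoverable.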
  2. Structured pruning
  • random_structured
prune.random_structured(model.conv1,name="weight",amount=0.33,dim=2)
prune.remove(model.conv1,'weight')
print(model.conv1.weight)
Parameter containing:
tensor([[[[ 0.0048,  0.1459, -0.0502],
          [-0.0000, -0.0000,  0.0000],
          [ 0.2687, -0.1137,  0.1034]]],


        [[[ 0.1801, -0.2711, -0.0819],
          [ 0.0000,  0.0000,  0.0000],
          [-0.2404, -0.3188,  0.3194]]],


        [[[ 0.0434, -0.0618,  0.0368],
          [ 0.0000,  0.0000,  0.0000],
          [ 0.1729,  0.2978, -0.3020]]]], device='cuda:0', requires_grad=True)
  • ln_structured
prune.ln_structured(model.conv1,name="weight",amount=0.33,n=2,dim=2)
prune.remove(model.conv1,'weight')
print(model.conv1.weight)
Parameter containing:
tensor([[[[-0.1829,  0.2475,  0.2816],
          [ 0.0694, -0.1366,  0.0740],
          [-0.0000, -0.0000,  0.0000]]],


        [[[-0.2966, -0.2881,  0.2974],
          [ 0.3074,  0.2858,  0.1990],
          [-0.0000,  0.0000, -0.0000]]],


        [[[-0.0558, -0.3072, -0.0674],
          [-0.0860,  0.2881,  0.1865],
          [ 0.0000,  0.0000,  0.0000]]]], device='cuda:0', requires_grad=True)
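In the outputs above, one whole kernel row is zero in every filter, which is what pruning along `dim=2` means for a weight of shape (out_channels, in_channels, height, width). The sketch below checks this on a stand-in layer; the `nn.Conv2d(1, 3, kernel_size=3)` definition is an assumption, since `model.conv1` is not shown in this post.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(1, 3, kernel_size=3)   # hypothetical stand-in; weight shape (3, 1, 3, 3)
original = conv.weight.detach().clone()

prune.ln_structured(conv, name="weight", amount=0.33, n=2, dim=2)
prune.remove(conv, "weight")

w = conv.weight.detach()
# A slice along dim=2 is one kernel row taken across all filters and channels.
row_is_zero = [bool((w[:, :, r, :] == 0).all()) for r in range(w.shape[2])]
# L2 norm of each such slice in the original weight
row_norms = original.pow(2).sum(dim=(0, 1, 3)).sqrt()
print(row_is_zero, int(row_norms.argmin()), tuple(w.shape))
```

Note that the weight keeps its original shape: `prune` only zeroes the selected slices. Actually shrinking the tensor requires building a smaller layer and copying the surviving weights into it.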

Tools for measuring model performance metrics

Parameter counting

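  • Theoretical formulas:
  1. Convolution layer: param = out_channels * (in_channels * kernel_size^2) + out_channels
  2. Normalization layer: param = 2 * out_channels (one weight and one bias per channel)
  3. Fully connected layer: param = in_features * out_features + out_features
  4. Other layers (dropout, activation, Identity): no parameters
  • Using the thop tool
    thop reports MACs and parameter counts. One MAC (multiply-accumulate) equals two FLOPs, so FLOPs = 2 * MACs.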
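import torch
from thop import profile

# Random dummy input with batch size 48; profile() counts MACs for the whole batch
inp = torch.randn(48, 512, 128).to('cuda')
macs, params = profile(audio_model, inputs=(inp,))
print(f"MACs {macs}, Parameters: {params}")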
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
MACs 2462153674752.0, Parameters: 85258757.0
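The theoretical formulas listed earlier can be checked against PyTorch's own counts. The sketch below uses layer sizes taken from the SSAST printout in this post (the patch-embedding convolution, a 768-wide LayerNorm, and the qkv projection); the helper name `count` is hypothetical.

```python
import torch.nn as nn

def count(module: nn.Module) -> int:
    """Total number of elements over all parameters of a module."""
    return sum(p.numel() for p in module.parameters())

conv = nn.Conv2d(1, 768, kernel_size=16, stride=10)   # patch_embed.proj
norm = nn.LayerNorm(768)                              # norm1 / norm2
fc = nn.Linear(768, 2304)                             # attn.qkv

# out_channels * (in_channels * kernel_size^2) + out_channels
conv_formula = 768 * (1 * 16 ** 2) + 768
# 2 * out_channels
norm_formula = 2 * 768
# in_features * out_features + out_features
fc_formula = 768 * 2304 + 2304

print(count(conv), conv_formula)
print(count(norm), norm_formula)
print(count(fc), fc_formula)
```

The stride does not enter the convolution formula: it changes the number of output positions (and therefore the MACs), not the number of weights.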
SSAST-Base-Patch-400

Ratio   Parameters   Accuracy   Inference speed   MACs
0       85258757     87.61      1.104047          2462153674752
10      85258757     -          -                 2462153674752
30      85258757     87.60      1.096145          2462153674752
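The parameter count and MACs in the table are identical for every ratio. One plausible explanation, assuming the ratios were applied with the mask-based methods from torch.nn.utils.prune, is that those methods only set weights to zero and never change tensor shapes, so shape-based counters see no difference. The sketch below shows the effect on a hypothetical stand-in model by counting non-zero values alongside the total.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def total_and_nonzero(model: nn.Module):
    """Return (number of parameters, number of non-zero parameters)."""
    total = sum(p.numel() for p in model.parameters())
    nonzero = sum(int(torch.count_nonzero(p)) for p in model.parameters())
    return total, nonzero

torch.manual_seed(0)
# Hypothetical stand-in model, not SSAST
model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 4))
before_total, before_nonzero = total_and_nonzero(model)

# Prune 30% of the weights in every Linear layer and make it permanent
for layer in model:
    if isinstance(layer, nn.Linear):
        prune.l1_unstructured(layer, name="weight", amount=0.3)
        prune.remove(layer, "weight")

after_total, after_nonzero = total_and_nonzero(model)
print(before_total, before_nonzero, after_total, after_nonzero)
```

Reporting the non-zero count next to the total makes the effect of the pruning ratio visible even when the dense parameter count stays the same.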
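Layer-by-layer comparison of the 768-wide model and its 384-wide counterpart (left and right of each =>):

ASTModel(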
  (v): DistilledVisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10)) => (proj): Conv2d(1, 384, kernel_size=(16, 16), stride=(10, 10))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True) => (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True) => (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True) => (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True) => (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True) => (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=3072, out_features=768, bias=True) => (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True) => (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
    (pre_logits): Identity()
    (head): Linear(in_features=768, out_features=1000, bias=True)
    (head_dist): Linear(in_features=768, out_features=1000, bias=True)
  )
  (mlp_head): Sequential(
    (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True) => (0): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=768, out_features=5, bias=True) => (1): Linear(in_features=384, out_features=5, bias=True)
  )
)

MACs: 58.0542 G => 16.2705 G
Params: 87.2606 M => 23.1657 M
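The drop to roughly a quarter of the parameters follows from the formulas: a linear layer's weight count scales with in_features * out_features, so halving both halves it twice. The sketch below works this out for the twelve transformer blocks only, using the layer layout from the printout above; `linear_params`, `block_params` and `mlp_ratio` are hypothetical names introduced for illustration.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features

def block_params(dim: int, mlp_ratio: int = 4) -> int:
    """Parameters of one transformer block with the layout printed above."""
    norms = 2 * (2 * dim)                        # norm1 + norm2, weight and bias each
    qkv = linear_params(dim, 3 * dim)            # attn.qkv
    proj = linear_params(dim, dim)               # attn.proj
    fc1 = linear_params(dim, mlp_ratio * dim)    # mlp.fc1
    fc2 = linear_params(mlp_ratio * dim, dim)    # mlp.fc2
    return norms + qkv + proj + fc1 + fc2

wide = 12 * block_params(768)
narrow = 12 * block_params(384)
print(wide, narrow, narrow / wide)
```

The reported totals (87.2606 M => 23.1657 M) give a slightly higher ratio than the blocks alone because they also include the embeddings and the heads, and head and head_dist keep their 768 input features in the printout.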
