美文网首页
PyTorch Gather 函数

PyTorch Gather 函数

作者: 数科每日 | 来源:发表于2022-02-14 23:52 被阅读0次

PyTorch 的 Gather 函数很实用,但是理解起来有些困难,本文试图用图例和代码给出解释。 完整代码

Gather 主要有三个参数

  • input: 源数据
  • index: 需要选取的数据的index
  • dim: 筛选数据的方式

Gather 函数返回值和 index 相同

Dim=0
Dim=0
dim = 0
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1, 2], [1, 2, 0]])
                     
output = torch.gather(input, dim, index)
output

Dim = 0 的时候, 从外层选择, 最内层的 list Tensor 会被拆开:

image.png
Dim=1
Dim=1
dim = 1
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1], [1, 2], [2, 0]])

output = torch.gather(input, dim, index)
output

Dim = 1 的时候, 从内层选择:

image.png

相关文章

网友评论

      本文标题:PyTorch Gather 函数

      本文链接:https://www.haomeiwen.com/subject/tmrrlrtx.html