美文网首页
pyTorch学习笔记——gather函数详解

pyTorch学习笔记——gather函数详解

作者: 韧心222 | 来源:发表于2021-10-13 16:11 被阅读0次

参考文献:图解PyTorch中的torch.gather函数 - 知乎 (zhihu.com)

gather 函数的声明为:

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

这个函数我大概研究了半个小时,虽然明白了基本的运算方法,但是其具体用法还理解的不够深入,如果以后有心得的话,再和大家来交流。

根据gather函数的声明,可以看到gather函数主要由三个参数,分别是:

  • input
  • dim
  • index

下面,我结合一个具体的例子,来给出其具体的计算方法:

import torch

torch.manual_seed(100)
x=torch.randn(2,3)
index = torch.LongTensor([[0,1,1]])
a = torch.gather(x, 0, index)

此时,x的值为:

tensor([[ 0.3607, -0.2859, -0.3938],
        [ 0.2429, -1.3833, -2.3134]])

a的值为:

tensor([[ 0.3607, -1.3833, -2.3134]])

下面,我们来看看其具体的计算过程:
第一步,获得index的index
index的值为[[0,1,1]],其每个元素对应的index为:
(0, 0)、(0, 1)、(0, 2)

第二步,看dim的值
在调用gather函数的时候,需要指定dim的值,此处我们的dim值为0

第三步,根据dim用index的值来替换第一步中得到的对应维度的值
因为dim=0,因此我们用index的值,来代替第一步中得到的索引中第一个维度的值,替换后的值为:
(0, 0)、(1, 1)、(1, 2)

第四步,根据第三步得到的新的索引值,在input中进行取值
input中的(0, 0)、(1, 1)、(1, 2)分别对应了 0.3607、 -1.3833、 -2.3134,而这正是我们得到的计算结果。

相关文章

网友评论

      本文标题:pyTorch学习笔记——gather函数详解

      本文链接:https://www.haomeiwen.com/subject/qmfeoltx.html