美文网首页
[PyTorch] register_hook

[PyTorch] register_hook

作者: VanJordan | 来源:发表于2019-06-07 19:28 被阅读0次
  • 每一个tensor都有register_hook方法,每次当关于这个参数的gradient被计算出来以后都会调用这个方法,因此可以用于debug等等,下面是对一部分梯度进行mask
    def _emb_hook(self, grad):
        return grad * Variable(self.grad_mask.unsqueeze(1)).type_as(grad)

    def set_grad_mask(self, mask):
        self.grad_mask = torch.from_numpy(mask)
        self.embedding.weight.register_hook(self._emb_hook)

相关文章

网友评论

      本文标题:[PyTorch] register_hook

      本文链接:https://www.haomeiwen.com/subject/htsvxctx.html