美文网首页
手写一下正弦编码和旋转位置编码的代码?

手写一下正弦编码和旋转位置编码的代码?

作者: bd7e4a65be2b | 来源:发表于2024-11-08 16:25 被阅读0次

获取更多面试真题的集合,请移步至 https://pica.zhimg.com/80/v2-7fd6e77f69aa02c34ca8c334870b3bcd_720w.webp?source=d16d100b

面试中不太可能让写出整体的代码,只要写出核心代码就可以了。

  • 正弦位置编码
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
two_i = torch.arange(0, d_model, 2, dtype=torch.float32)
div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))
encodings = torch.zeros(max_len, d_model)
encodings[:, 0::2] = torch.sin(position * div_term)
encodings[:, 1::2] = torch.cos(position * div_term)

上述代码中的一些变量的含义这里就不作介绍了,懂得都懂。

  • 旋转位置编码
    这个部分比较复杂一些,其中重要的逻辑是在计算旋转变换矩阵那里,中间设计到复数域的变化,比较麻烦,建议调试一下代码看下整体的逻辑:
# 生成旋转矩阵
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    # 计算词向量元素两两分组之后,每组元素对应的旋转角度\theta_i
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # 生成 token 序列索引 t = [0, 1,..., seq_len-1]
    t = torch.arange(seq_len, device=freqs.device)
    # freqs.shape = [seq_len, dim // 2] 
    freqs = torch.outer(t, freqs).float()  # 计算m * \theta

    # 计算结果是个复数向量
    # 假设 freqs = [x, y]
    # 则 freqs_cis = [cos(x) + sin(x)i, cos(y) + sin(y)i]
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs) 
    return freqs_cis

# 旋转位置编码计算
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # xq.shape = [batch_size, seq_len, dim]
    # xq_.shape = [batch_size, seq_len, dim // 2, 2]
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)
    
    # 转为复数域
    xq_ = torch.view_as_complex(xq_)
    xk_ = torch.view_as_complex(xk_)
    
    # 应用旋转操作,然后将结果转回实数域
    # xq_out.shape = [batch_size, seq_len, dim]
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)
    return xq_out.type_as(xq), xk_out.type_as(xk)

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.wq = Linear(...)
        self.wk = Linear(...)
        self.wv = Linear(...)
        
        self.freqs_cis = precompute_freqs_cis(dim, max_seq_len * 2)

    def forward(self, x: torch.Tensor):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(batch_size, seq_len, dim)
        xk = xk.view(batch_size, seq_len, dim)
        xv = xv.view(batch_size, seq_len, dim)

        # attention 操作之前,应用旋转位置编码
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
        
        # scores.shape = (bs, seqlen, seqlen)
        scores = torch.matmul(xq, xk.transpose(1, 2)) / math.sqrt(dim)
        scores = F.softmax(scores.float(), dim=-1)
        output = torch.matmul(scores, xv)  # (batch_size, seq_len, dim)
  # ......

参考:
[1] https://zhuanlan.zhihu.com/p/689205140
[2] https://zhuanlan.zhihu.com/p/647109286

本文由mdnice多平台发布

相关文章

  • 树莓派基础实验26:旋转编码器实验

    一、介绍    旋转编码器是一种机电装置,可将轴或轴的角位置或运动,转换为模拟或数字代码。旋转编码器通常放置在垂直...

  • 旋转编码器安装不得不关注的三大事项

    旋转编码器是编码器的一种,旋转编码器作为一种重要的测试装置,被广泛应用于很多场合中.旋转编码器的工作原理是基于不同...

  • 地理编码与反编码

    首先我们要了解地理编码和反编码的含义和作用:<1>地理编码:把地名转换成位置信息作用:把文字描述的 位置转换成地图...

  • iOS开发之CoreLocaiton框架使用(地理编码,反地理编

    什么是地理编码和反地理编码? 地理编码 地理编码:根据给定的地名,获得具体的位置信息(比如经纬度、地址的全称等)。...

  • 地理编码与反地理编码

    使用CLGeocoder可以完成“地理编码”和“反地理编码” 地理编码:根据给定的地名,获得具体的位置信息(比如经...

  • Python风格规范

    为了便于项目的管理和代码的阅读,养成良好的编码风格以及沟通方便,编码Python代码时应遵循以下编码规范: 每行长...

  • 位置编码

    https://blog.csdn.net/qq_27590277/article/details/1062644...

  • 位置编码

    Transformer: 不可学习位置编码 可以看出不同的column(j)之间,周期不一样 不同的row(i)之...

  • python批量查看修改文件编码

    使用python批量查看文件编码,或者批量修改文件编码 代码 结果 查看文件编码 执行编码转换 再次查看转换后的编码

  • 地图-->地名VS地理坐标(根据"北京"

    地理编码 除了提供位置跟踪功能之外,在定位服务中还包含CLGeocoder类用于处理地理编码和逆地理编码(又叫反地...

网友评论

      本文标题:手写一下正弦编码和旋转位置编码的代码?

      本文链接:https://www.haomeiwen.com/subject/eaaadjtx.html