layout | title | date | summary | categories |
---|---|---|---|---|
post |
RoPE 位置编码算法详解 |
2024-10-24 13:00:00 -0700 |
旋转位置编码(Rotary Position Embedding,RoPE)是论文 Roformer Enhanced Transformer With Rotray Position Embedding 提出的一种能够将相对位置信息依赖集成到 self-attention 中并提升 transformer 架构性能的位置编码方式。 |
Transformer |
旋转位置编码(Rotary Position Embedding,RoPE
)是论文 Roformer: Enhanced Transformer With Rotray Position Embedding 提出的一种能够将相对位置信息依赖集成到 self-attention 中并提升 transformer 架构性能的位置编码方式。
和相对位置编码相比,RoPE 具有更好的外推性,目前是大模型相对位置编码中应用最广的方式之一。这里的外推性实质是一个训练和预测的文本长度不一致的问题。具体来说,不一致的地方有两点:
- 预测的时候用到了没训练过的位置编码(不管绝对还是相对);
- 预测的时候注意力机制所处理的 token 数量远超训练时的数量。
RoPE 的核心思想是将位置编码与词向量通过旋转矩阵相乘,使得词向量不仅包含词汇的语义信息,还融入了位置信息,其具有以下优点:
- 相对位置感知:RoPE 能够自然地捕捉词汇之间的相对位置关系。
- 无需额外的计算:位置编码与词向量的结合在计算上是高效的。
- 适应不同长度的序列:RoPE 可以灵活处理不同长度的输入序列。
三角函数、旋转矩阵、欧拉公式、复数等数学背景知识可以参考这篇文章学习。
1,torch.outer
函数作用:torch.outer(a, b) 计算两个 1D 向量 a 和 b 的外积,生成一个二维矩阵,其中每个元素的计算方式为:
即,矩阵的第 i 行、第 j 列的元素等于向量 a 的第 i 个元素与向量 b 的第 j 个元素的乘积。
外积(outer product)是指两个向量 a 和 b 通过外积操作生成的矩阵:
其中
>>> a = torch.tensor([2,3,1,1,2], dtype=torch.int8)
>>> b = torch.tensor([4,2,3], dtype=torch.int8)
>>> c = torch.outer(a, b)
>>> c.shape
torch.Size([5, 3])
>>> c
tensor([[ 8, 4, 6],
[12, 6, 9],
[ 4, 2, 3],
[ 4, 2, 3],
[ 8, 4, 6]], dtype=torch.int8)
2,torch.matmul
可以处理更高维的张量。当输入张量的维度大于 2 时,它将执行批量矩阵乘法。
>>> A = torch.randn(10, 3, 4)
>>> B = torch.randn(10, 4, 7)
>>> C = torch.matmul(A, B)
>>> D = torch.bmm(A, B)
>>> assert C.shape == D.shape # shape is torch.Size([10, 3, 7])
>>> True
3,torch.polar
# 第一个参数是绝对值(模),第二个参数是角度
torch.polar(abs, angle, *, out=None) → Tensor
构造一个复数张量,其元素是极坐标对应的笛卡尔坐标,绝对值为 abs,角度为 angle。
# 假设 freqs = [x, y], 则 torch.polar(torch.ones_like(freqs), freqs)
# = [cos(x) + sin(x)j, cos(y) + sin(y)j]
>>> angle = torch.tensor([np.pi / 2, 5 * np.pi / 4], dtype=torch.float64)
>>> z = torch.polar(torch.ones_like(angle), angle)
>>> z
tensor([ 6.1232e-17+1.0000j, -7.0711e-01-0.7071j], dtype=torch.complex128)
>>> a = torch.tensor([np.pi / 2], dtype=torch.float64) # 数据类型必须和前面一样
>>> torch.cos(a)
tensor([6.1232e-17], dtype=torch.float64)
4,torch.repeat_interleave
# 第一个参数是输入张量
# 第二个参数是重复次数
# dim: 沿着该维度重复元素。如果未指定维度,默认会将输入数组展平成一维,并返回一个平坦的输出数组。
torch.repeat_interleave(input, repeats, dim=None, *, output_size=None) → Tensor
返回一个具有与输入相同维度的重复张量
>>> keys = torch.randn([2, 12, 8, 512])
>>> keys2 = torch.repeat_interleave(keys, 8, dim = 2)
>>> keys2.shape
torch.Size([2, 12, 64, 512])
>>> x
tensor([[1, 2],
[3, 4]])
>>> torch.repeat_interleave(x, 3, dim=1)
tensor([[1, 1, 1, 2, 2, 2],
[3, 3, 3, 4, 4, 4]])
>>> torch.repeat_interleave(x, 3)
tensor([1, 1, 1, 3, 3, 3, 4, 4, 4, 5, 5, 5])
注意重复后元素的顺序,以简单的一维为例 x = [a,b,c,d]
,torch.repeat_interleave(x, 3)
后,结果是 [a,a,a,b,b,b,c,c,c,d,d,d]
。
设 token
对应的词向量 token
)之后的 key
和 value
向量,$q_m、k_n、v_n$ 的表达用如下公式:
注意,这里的
$f_q$ 其实是把$\text{embedding}_\text{vector} \times W_q$ 的矩阵乘法过程包含进去了,至于为什么要这样构造,下文会讲。
其中函数
方程 (1) 的一种常见选择是:
其中,$p_i \in \mathbb{R}^d$ 是与 token
其中,
RoPE 论文提出为了能利用 token 之间的相对位置信息($m-n$),假定 query 向量
注意,这里只有
$f_q(x_m, m)$ ,$f_k(x_n, n)$ 是需要求解的函数,$\langle \rangle$ 表示内积操作,而对于$g$ ,我们要求是表达式中有$x_m, x_n, (m-n)$ ,也可以说是$q_m, k_n$ 的内积会受相对位置$m-n$ 影响。
接下来的目标就是找到一个等价的位置编码方式
假设现在词嵌入向量的维度是两维
其中 ( Re ) 表示复数的实部,( (W_k x_n)^* ) 表示 ( (W_k x_n) ) 的共轭复数。
1,query
向量乘以了一个旋转矩阵,即:
2,key
向量乘以了一个旋转矩阵,即:
3,同样可得
公式(9)的证明可通过旋转矩阵性质得到,先将公式 (9) 抽象成
上述推导过程分别应用了:展开内积、矩阵乘法的结合律、旋转矩阵性质1、旋转矩阵性质2。
前面的公式推导,是假设的词嵌入维度是 2 维向量,将二维推广到任意维度,$f_{{q,k}}$ 可以表示如下:
其中,$R_{\Theta, m}^d$ 为
[sqe_len, dim//2]
。$可以看出,对于
将 RoPE 应用到前面公式(2)的 Self-Attention 计算,可以得到包含相对位置信息的Self-Attetion:
其中,
Rotary Position Embedding(RoPE) 实现的可视化如下图所示:
最后总结结合 RoPE 的 self-attention 操作的流程如下:
- 首先,对于
token
序列中的每个词嵌入向量,都计算其对应的 query 和 key 向量; - 然后在得到 query 和 key 向量的基础上,应用公式(7)和(8)对每个
token
位置都计算对应的旋转位置编码; - 接着对每个
token
位置的 query 和 key 向量的元素按照两两一组应用旋转变换; - 最后再计算
query
和key
之间的内积得到 self-attention 的计算结果。
先看旋转矩阵用于旋转一个二维向量过程示例:
但是 Llama 模型的嵌入维度高达 4096,比二维复杂得多,如何在更高维度的嵌入上应用旋转操作呢?通过 RoPE 算法原理我们知道,嵌入向量的旋转实际是将每个嵌入向量元素位置
RoPE 通过实现旋转矩阵,是既捕获绝对位置信息,又结合相对位置信息的方式(论文公式有更详细体现)。
图中每组的旋转角度计算方式如下:
在实现 RoPE 算法之前,需要注意:为了方便代码实现,在进行旋转之前,需要将旋转矩阵转换为极坐标形式,嵌入向量($q$、$k$)需要转换为复数形式。完成旋转后,旋转后的嵌入需要转换回实数形式,以便进行注意力计算。此外,RoPE 仅应用于查询(Query)和键(Key)的嵌入,不适用于值(Value)的嵌入。
通过仔细阅读和一步步分析了 llama
官方代码后,会发现作者直接转化为复数相乘形式来计算
所以,作者在 RoPE
算法实现中,没有使用矩阵相乘的形式,而是把旋转角度张量和 RoPE
算法实现的流程和代码理解的难点如下:
- 如何生成旋转角度
$\theta$ 向量,$\Theta = { \theta_i = 10000^{-2(i-1)/d}, i \in \left [1, 2, \dots, d/2 \right ] }$ ; - 如何将旋转角度和
token
位置索引相乘,并构造一个矩阵,该矩阵包含了每个位置和每个维度对应的旋转角度。 - 得到所有
token
位置和其对应旋转角度后,如何转换为复数形式$e^{im\theta}$ 的旋转矩阵; - 如何对
RoPE
函数的输入参数x_q
做形状变换,并将实数张量转换为复数张量形式; - 两个复数张量相乘(应用旋转操作)后,转换回实数域,并恢复原始形状。
最后,如果你直接看 pytorch
代码,其实很难理解 rope
是如何应用相对位置信息的,这个只能通过前面的公式推导才能理解。
结合 llama 官方实现代码,下述是我修改优化和添加注释后的代码,更容易看懂:
def compute_theta(dim: int, base: float = 10000.0, device: torch.device = torch.device('cpu')) -> torch.Tensor:
"""
计算旋转位置编码中的 Theta 角度值。
参数:
- d (int): 嵌入向量的维度(必须为偶数)。
- base (float): 基础频率参数, 默认为10000.0。
- device (torch.device): 计算设备, 默认为CPU。
返回:
- torch.Tensor: 包含Theta值的1D张量, 形状为 [d/2]。
"""
if dim % 2 != 0:
print("嵌入维度 dim 必须为偶数")
i = torch.arange(1, (dim//2) + 1, dtype=torch.float32, device=device)
theta_i = base ** (-2*(i - 1) / dim)
return theta_i
def precompute_freqs_cis(dim: int, seq_len: int, base: float = 10000.0, device: torch.device = torch.device('cpu')):
theta = compute_theta(dim, base, device) # theta 角度值序列,向量, 大小为 dim // 2
m = torch.arange(seq_len, device=device) # # token 位置值序列,向量,大小为 seq_len
m_theta = torch.outer(m, theta) # 所有 token 位置的所有 Theta 值范围, 矩阵,尺寸为 [seq_len, dim // 2]
freqs_cis = torch.polar(torch.ones_like(m_theta), m_theta) # e^{i*m*\theta},本质上是旋转矩阵
return freqs_cis
def reshape_for_broadcast(freqs_cis, x):
ndim = x.ndim
assert ndim >= 2
assert freqs_cis.shape == (x.shape[1],x.shape[-1]), "the last two dimension of freqs_cis, x must match"
shape = [d if i==1 or i==ndim-1 else 1 for i,d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, device: torch.device = torch.device('cpu')):
"""
参数:
- x_q(torch.Tensor): 实际上是权重 W_q * 词嵌入向量值, 来自上一个线性层的输出, 形状为 [batch_size, seq_len, n_heads, head_dim]
- x_k(torch.Tensor): 实际上是权重 W_k * 词嵌入向量值, 来自上一个线性层的输出, 形状为 [batch_size, seq_len, n_heads, head_dim]
- freqs_cis (torch.Tensor): 频率复数张量, 形状为 [max_seq_len, head_dim]
返回:
- Tuple[torch.Tensor, torch.Tensor]: 旋转编码后的查询和键张量
"""
# 实数域张量转为复数域张量
xq_reshape = xq.reshape(*xq.shape[:-1], -1, 2) # [batch_size, seq_len, dim] -> [batch_size, seq_len, dim//2, 2]
xk_reshape = xk.reshape(*xk.shape[:-1], -1, 2) # [batch_size, seq_len, dim] -> [batch_size, seq_len, dim//2, 2]
xq_complex = torch.view_as_complex(xq_reshape) # 复数形式张量
xk_complex = torch.view_as_complex(xk_reshape) # 复数形式张量
# 旋转矩阵(freqs_cis)的维度在序列长度(seq_len,维度 1)和头部维度(head_dim,维度 3)上需要与嵌入的维度一致。
# 此外,freqs_cis 的形状必须与 xq 和 xk 相匹配,因此我们需要将 freqs_cis 的形状从 [max_seq_len, head_dim] 调整为 [1, max_seq_len, 1, head_dim]。
freqs_cis = reshape_for_broadcast(freqs_cis, xq_complex) # [max_seq_len, 1, 1, dim // 2]
# 应用旋转操作,并将结果转回实数域。# flatten(2) 将后面两个维度压成一个维度
xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_complex * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
结合 rope 位置编码的 attention 结构的完整代码在这里,运行后,单元测试结果如下所示:
test_compute_theta passed.
test_precompute_freqs_cis passed.
test_apply_rotary_emb passed, xq_out and xq [0][0][0][0]: -1.3532123565673828 -1.3532123565673828.
test_attention passed.
transformers 库提供的 llama rope 实现在这里