我们直接来看 Swin 在窗口注意力中使用的公式:
公式本身在形式上和 T5 是完全相同的,关键在于偏置矩阵 的构造上。
我们分点来展开:
2.1 直接将 RPE 推广到二维#
我们先来看看最直接的方法:
对于一个 的窗口,直接设计 ,其中 表示窗口内第 个 patch 和第 个 patch 之间的偏置值。
我们用一个简单的例子来演示为什么是 ,假设窗口大小: ,那么窗口就是:
现在,每个 token 都要和另外所有 token 建立关系。那么 计算的注意力得分矩阵形状就是这样的:
偏置矩阵必须和注意力矩阵一一对应。所以 。
这种方法当然是可以跑通的,但我们要考虑二维带来的参数问题:
如果直接学习一个 的参数矩阵,那每个注意力头就得维护 个参数。一个 Swin 有多个头和多个层,累计下来参数巨大。
因此, Swin 自然有对应的改进。
2.2 空间关系的平移不变性#
在 NLP 中,我们只针对每种相对位置设计偏置,但是在上面方案里,你会发现直接推广会带来很多无意义的参数,核心是因为:
在二维数据中相对逻辑更加凸显,窗口内大量位置对其实拥有相同的相对偏移。
比如,patch (0,0) 和 (1,0) 之间的偏移是 ,而 patch (2,0) 和 (3,0) 之间的偏移同样是 。
它们本质上描述的是同一种空间关系,理应共享同一个偏置值。
于是 Swin 的做法是:推广相对逻辑,不直接学习 ,而是学习一个小得多的偏置表,再通过二维索引从中查值。
3. 紧凑偏置表与查表逻辑#
3.1 二维相对位置的计算#
首先,对于一个 的窗口,给每个位置一个坐标 ,显然:
对于任意两个 patch ,二维相对偏移是:
那么, 的取值范围就是 ,一共 种可能。
同理,这部分的计算逻辑和 T5 是完全相同的。
现在,我们知道了:所有可能的 组合一共有 种,也就是说:
我们只需要一个 的偏置表,就能覆盖窗口内所有可能的位置关系。
这就是 Swin 的紧凑偏置表 :
建表本身的逻辑到此结束,但现在还有一个小问题:
和 大小不一,对于每组注意力计算,我要如何查表注入相应偏置?
3.2 查表过程#
其实这步可以理解为:如何将 内的值映射到总公式里的 中?
首先,前面我们已经知道了:
因此,真正参与 Attention 计算的偏置矩阵 ,也必须是 。
但我们刚刚学习的紧凑偏置表只有:
不难理解,为了让二者适配,Swin 的设计是这样的:
对于 Attention Matrix 中的每一个元素,都先计算两个 patch 的相对位移,再去 中查对应 bias。
展开来说, 中的每一个元素本质上都对应“一对 patch 的关系”,而每一对 patch 都有自己的 ,因此,我们可以计算相对位移,实现查表取值:
这就实现了相同相对位移的 patch 对,共享同一个偏置。
不过这在实现中还有一个问题:
数组索引没有负数,负偏移并不能和其索引直接对应。
而 ,因此 Swin 会先做一次平移去寻找正确索引:
现在:
于是查表过程就变成:
字母还是有些抽象,我们再举一个实例:设 ,那么 patch 网格可以就是:
此时 ,因此 ,如果当前 patch 为 ,它去关注 ,那么:
,
现在,我们需要查:
显然,数组索引不能为负数。所以进行平移:
于是:
原本的 ,就被平移成 :
这里可能容易疑惑的一点是:
中存储的并不是“偏移坐标本身”,而是“对应相对位移的偏移参数”。
展开来说:数学意义上的 会被映射到数组索引,因此, 实际存储的就是相对位移为 时对应的偏置。
这样,所有原本可能为负数的二维位移都被映射到了合法数组索引,可以稳定完成查表。
最终所有 patch 两两之间都会完成一次查表从而动态构造出完整的偏置矩阵:
随后:
即可完成二维相对位置信息的注入。
值得一提的是在具体实现中,二维紧凑表会被展平成一维,以类似“编号”的逻辑取值,根本逻辑没变,明白即可。
3.3 参数对比#
来看看两种方式的参数对比:
| 方式 | (Swin 默认) | |
|---|---|---|
| 暴力直接法 | ||
| Swin 紧凑法 | ||
| 压缩比 | 约 14 倍 | 约 53 倍 |
很明显,随着窗口增大,紧凑表的优势会更加明显。