原文链接:https://arxiv.org/abs/1910.02653
我个人认为非常优美的一篇论文,居然没人发过笔记,太奇怪了,那就我自己稍微写写。分段和原文不同,但公式标号基本是一样的。
1. background
这篇论文要解决的问题是训练过程中显存不够,如何节约开销。一个非常自然的想法就是把不用的显存释放掉,从而节约开销。这在推理过程中是很有用的,很自然能发现在线性网络(VGG等)中,这可以让显存开销达到的水平。然而在训练过程中,由于反向计算梯度的过程中这些中间结果还会被用到,当backward开始时,显存开销为
(n为算子个数,因此也近似是中间张量的个数)。
在Checkpoint这篇论文中,给出了一个自然的思路:只保留一部分中间结果,对于其它中间结果,当需要时再用保留到的来计算。如果计算的开销很小,这个trade-off是值得的(BTW,这些计算通常可以被fuse到其它算子上)。典型代表就是CNN中常见的Conv-Batch Norm-Noise-Activate,除卷积层外的算子都是开销低的。
Checkpoint只讨论了形如VGG的线形网络(即input依次经过算子1、2、3...没有其它路径):一言蔽之,将网络分段(segment),每个segment保存最后一个算子的输出,其输入为上一个segment的输出,在backward计算梯度的时候,segment内部的算子会重新跑一遍,用重新跑的结果来算梯度。
这样,我们至多将网络再跑了一遍,但可以将显存开销降到:每
个算子分成一个segment。实际应用中不会这么好。
结合实际需求,作者又提出了两个算法,一个是用户提供指示函数,用来表示每个算子可否被重用(虽然原文写的
的值域是自然数,但实际只能是
,否则原文中的map
还需要修正),另一个是提供显存上界
和重计算高开销的算子集合
(在
的值域修改成
后,
和
是一样的)。后者会根据显存上界贪心地确认最终的
,然后调用前者。(这个贪心也有点问题,实际上会可能超过显存上界)
这个算法可以进一步优化:一个segment可以视为一个子图,从而可以递归调用该算法。(这样最终的终于大于1了,尽管在原文alg.2里它还是只在
上)
虽然上面一段说到有一些奇怪的问题,但它已经的确在torch.util.checkpoint和MXNet中用到了。最终性能不会有明显下降,但显存开销有明显优化。
2. intro
事情没有到此结束。上文的算法在VGG形状的网络里没什么大问题,值域假了都无所谓。但在ResNet甚至ResNeXt广泛使用后,网络不再只是线形,而是有更多样的形状,这时segment的观点就不适用了。
Checkmate解决了这个问题(适用于任意图),同时修复了的值在非递归时不大于1的问题。文章中还提到它可以兼顾跨层保存的张量和层内计算中的张量,不过Checkpoint也可以做到,因此不是一个nontrivial的特性。总的来说,Checkmate将segment方法拓广为rematerialization,即有需求时再重新计算。
Checkmate的思想简单且优美,即利用线性规划。概括来说,它的主要思路是:
- 每个时间点的显存使用量是关于哪些张量目前被保存的线性函数;
- 张量在每个时间点是否重新计算、保存或释放之间是布尔表达式的关系。对应的布尔不等式可以转换为线性不等式,这是论文中最优美的一点。
- 重新计算的代价与每个时间点是否重新计算是线性关系,且需要尽可能小,因此将其作为线性规划的目标。
3. method
然后我们就可以开始列线性规划的方程了:
3.1 basic
首先引入符号:为计算图,包含了backward的部分。图上的点为
,其下标符合拓扑序。对于
的输出,其显存占用为
,计算的代价为
。假设算子都是单一输出,从而可以用
来表示
算子所输出的张量。
接着引入两组布尔变量,其中
表示在阶段
,
是否重计算,
表示
是否会从上一阶段保留到当前阶段。
由此,我们可以写出最基本的约束条件:张量的计算要满足依赖关系、张量的保存需要首先经过计算、边界条件。这会被写成:
subject to:
最优化目标为计算的总代价最少,其约束的含义依次为:
- 计算
前,考虑其依赖的
:若在之前的阶段计算,则需要保存到该阶段(
)。在本阶段计算(
),则需要在本阶段保存到
的计算,这点会体现在后续的显存约束;
- 张量为了保存要么在本阶段计算,要么在上一阶段保存至本阶段;
- 初始状态下所有张量都没有保存;
- 最终整个图都被计算过,保证有输出;
- R,S的值域限制
3.2 memory: intro of U
然后写出显存占用的表达式。显存占用要满足的是每个时刻(注意不是阶段)的占用都不大于一个给定常数。
首先还是引入符号:使用表示阶段
中,算子
的重计算被考虑后的显存使用量。由此,我们可以写出显存上限对应的约束是
且有
其中前两项是假设输入、参数和参数的梯度都留有保存空间,均为常数。
3.3 memory: intro of FREE
为了约束每个阶段中间过程的显存使用,继续引入符号:
使用表示阶段
中,计算
后,
的显存是否被释放;
使用表示
在图上的前驱,即
,使用
表示图上的后继。
因此有:
即算子计算后考虑其前驱是否可以释放。
其中判断是否释放的函数为
这个式子可以用布尔表达式的观点来看:计算后释放
的条件是
被计算(
)、且不再需要保存
(
)、且后续依赖
的张量(
)都不会计算(
)。
由于在一个stage至多计算一次、因此至多释放一次,有约束
3.4 analysis of FREE
到上一步,显存的约束已经写完了。但是需要注意到,不是线性表达式,而是由布尔与运算构成的乘法。
解决这个问题是这篇文章最优美的地方。它使用了下述两个引理:
Lemma 4.1 If , then
Lemma 4.2 If , then
if and only if
and
这两个引理都很容易证明。将4.1应用在上,有
其中
至此,乘法已经转换为加法。然而它还在示性函数的条件上,需要使用4.2进一步变换:
相当精彩。
3.5 conclusion and prune
因此,我们可以将原问题转化为:
优化1a,约束条件为1b, 1c, 1d, 1e, 1f, 2a, 2b, 3, 4, 6', 7a, 7b, 7c。
注意这是整数线性规划(ILP),属于NP问题,同时FREE的下标表明参数量至多为(当然现实场景的计算图通常会稀疏很多),因此需要一些优化:
首先是根据拓扑序,指定每个阶段计算的算子,从而将条件下三角化:
这三个条件用来替换原条件中的两个边界条件,
。
论文中实验表明在八层的网络(加上backward后共有17个算子),原问题需要9.4小时,新问题只需要0.23秒。
然后观察到时,可以直接令
。因此将
中求和范围里的
删去,假设
。这可以减少
个变量。
3.6 approximation
ILP是NP问题,所以还是要考虑近似算法。作者采用的是将原问题的整数域放到实数域
上来做,然后获得一个近似解
。将
近似为
:
然后通过来计算对应的
,即在
处填补对应的
或
,然后再考虑计算
所依赖的张量是否保存,若未保存则在该阶段添加其计算。
该近似过程可能让显存约束条件被违背,解决策略是在求实数解的时候把原本的约束改为
,
时实验效果较好。
网友评论