本文介绍的剪枝算法是 Cost Complexity Pruning 又叫 Weakest Link Pruning. 本文介绍的是关于连续型变量的回归树的 Pruning 算法, 但是对于离散型变量,以及 Classification Tree 只要将相关指标进行替代即可
SSR
SSR : Sum of Square Residual
-
: 假设一棵树有
个分支。
-
: 第 i 个observation 的 reponse 值。
-
在第j 个分支上,所有预测值的均值。
SSR 是计算一棵树的数据预测能力的指标, 它代表了一棵树对于所有数据预测值与实际值只差。 SSR 越大树的预测错误越大。
SSR 会随着树变深而减少, 而且是单调减少。 如果以SSR 为Cost Fucntion 就会导致树为每个 Observation 划分一个branch, 那么整个模型就会走向过拟合
Tree Scroe
在SSR 基础上, 派生出 Tree Scroe , 用来给一棵树打分
也就是
其中
-
一个超参数,取值为 0 -
- T 是这棵树的节点个数
其中 类似惩罚项, 当
很大的时候, 就会迫使一个树只留下根节点, 而
为0的时候, 树越深Tree Scroe 就会越高。
剪枝算法
1. 获得可选
集合
- 使用数据为所有的 Training Data
从开始(
),构建出一个最深的树
。 然后不断增加
的值。 随着
增加,
也会增加, 在
等于一个值的时候, 会出现一个比较浅的树
, 它的
, 那么就把此时的
记为
。 不断提高
的值,直到只剩下一个树根。 如果
有
个节点, 此时, 我们就可以得到
颗树和每棵树对应的
。这
个
的集合, 记为
。
2. Cross Validation
将Training Data 进行 Fold Cross Validation拆分。 在每次使用中, 就得到一个新的数据集合
- 新训练集合, 由
个fold 的数据组成
- 新测试集合, 由
个fold 的数据组成
2.1 用
中每个
训练树,并获得其SSR
对于 Fold Cross Validation, 每个Fold 都进行如下操作
- 利用 “新训练集合”的数据, 以及
中的
个
, 来获得
颗树。
- 利用 “新测试集合” 测试本轮中
颗树的 SSR(注意,不是 Tree Score)
- 找到一个SSR 最低的树, 并记录其
2.2 选出一个最好的
步骤 2.1 会执行 次, 并获得
个
(可能有重复)。 表现最好(重复次数)最多的
, 将其定为最终的
。
3. 获得最终的Tree
将第1步中产生的 颗树中, 对应
的树定位 Final Model
总结
算法主要分如下几步:
- 利用 Tree Scroe 选出有限个不同的
- 利用 Cross Validation 和 SSR 从有限个
中选出最优的
- 找到
对应的树, 为最终的模型
参考资料
- https://www.youtube.com/watch?v=D0efHEJsfHo
- An Introduction to Statistical Learning














网友评论