K近邻算法,不具有显示的学习过程,利用训练数据集对特征向量空间进行划分。算法关键在于使用特殊的数据结构来存储训练数据,减少测试阶段的计算次数。这个特殊的数据结构就是kd树。
本文主要是根据这份代码来讲述如何构建kd树并找到测试数据的k个近邻点。主要看576行的create()函数和203行的KDNode类,其中KDNode包含6个属性,分别是data(样本数据)、left(左孩子)、right(右孩子)、axis(当前节点用于划分数据的特征index)、sel_axis(子节点用于划分数据的特征index)和dimensions(特征总数量)。
kd树的构建主要是递归调用create()函数,函数返回的是kd树的头节点(KDNode)。具体步骤如下:
- 输入数据集point_list和当前阶段用于划分数据的特征序号axis后,create()会先确定下一阶段用于划分数据的特征序号sel_axis。
sel_axis = sel_axis or (lambda prev_axis: (prev_axis+1) % dimensions)
可见sel_axis要么是直接指定,要么是当前特征序号+1。当然如果想要得到平衡的kd树,可以计算划分后数据中有着最大方差的特征,将其对应序号作为下一阶段的划分特征序号直接传给sel_axis。
- 由于kd树的创建是是递归调用create()函数,所以需要加上create()的终止条件,那就是训练数据使用完了,此时返回kd树的叶子节点(也是KDNode,但是没有属性data、left和right的输入)。
if not point_list:
return KDNode(sel_axis=sel_axis, axis=axis, dimensions=dimensions)
- 将训练数据按照划分特征的对应值进行排序。
point_list.sort(key=lambda point: point[axis])
- 取中位数位置的样本,将其保存到该阶段创建的KDNode中。同时将输入训练数据划分成两部分,并分别调用create()函数创建新的子kd树。
median = len(point_list)
loc = point_list[median]
left = create(point_list[:median], dimensions, sel_axis(axis))
right = create(point_list[median + 1:], dimensions, sel_axis(axis))
return KDNode(loc, left, right, axis=axis, sel_axis=sel_axis, dimensions=dimensions)
kd树构建的完整代码如下:
def create(point_list=None, dimensions=None, axis=0, sel_axis=None):
if not point_list and not dimensions:
raise ValueError("either point_list or dimensions must be provided")
elif point_list:
dimensions = check_dimensionality(point_list, dimensions)
if not point_list:
return KDNode(sel_axis=sel_axis, axis=axis, dimensions=dimensions)
point_list = list(point_list)
point_list.sort(key=lambda point: point[axis])
median = len(point_list) // 2
loc = point_list[median]
left = create(point_list[:median], dimensions, sel_axis(axis))
right = create(point_list[median+1:], dimensions, sel_axis(axis))
return KDNode(loc, left, right, axis=axis, sel_axis=sel_axis, dimensions=dimensions)
kd树构建好后,就可以进行基于kd树的最近邻搜索:
- 将输入测试样本不停递归下移,直到子节点是叶节点为止(不能移到叶节点,因为叶节点是不存数据的)。此时利用当前节点的值与输入样本求一个距离,将这个距离定位“最近距离”。
- 递归往上回退,在每个节点上进行如下操作:
- 计算父节点与输入样本的距离,如果这个距离小于“最近距离”,那么就将这个新的距离设为“最近距离”;
- 判断该节点的另外一个子节点中是否存在“最近距离”。如果存在,就在另外一个子节点中递归计算并更新“最近距离”,如果不存在则继续往上回退。
- 直到回退到根节点,搜索结束,返回“最近距离”。
这一过程的实现可以看源代码431行的_search_node()函数。_search_node()函数的输入参数包括point(测试样本)、k(KNN中的K)、results(保存的是KDNode)、get_dist(计算距离的方式,是函数名)和counter(用于计数)。具体步骤如下:
- 首先计算当前遍历节点的距离。
nodeDist = get_dist(self)
因为_search_node()是类KDNode的方法,所以上面代码中的self就表示了当前遍历节点。
- 将1步中计算的距离nodeDist插入results中。如果results没有填满(数量少于k个),则直接插入;如果results已经填满,则需要判断nodeDist是否小于results中的最大距离,如果小于就将results中的最大距离弹出,并将nodeDist插入results中。
item = (-nodeDist, next(counter), self)
if len(results) >= k:
if -nodeDist > results[0][0]:
heapq.heapreplace(results, item)
else:
heapq.heaprush(results, item)
- 然后计算一个距离plane_dist2,用于判断后面是否需要对节点的另外一个子节点(样本不在的子节点)进行计算nodeDist。代码中plane_dist2的实现就是将输入样本point和节点数据self.data在切分维度上的值求平方差。实现代码如下:
split_plane = self.data[self.axis]
plane_dist = point[self.axis] - split_plane
plane_dist2 = plane_dist * plane_dist
- 将当前节点移动到和输入样本point在同一区域的子节点上,递归调用_search_node()函数。代码实现如下:
if point[self.axis] < split_plane:
if self.left is not None:
self.left._search_node(point, k, results, get_dist, counter)
else:
if self.right is not None:
self.right._search_node(point, k, results, get_dist, counter)
- 根据3步中计算的plane_dist2,实现在另外一个子节点的nodeDist的计算。
if -plane_dist2 > results[0][0] or len(results) < k:
if point[self.axis] < self.data[self.axis]:
if self.right is not None:
self.right._search_node(point, k, results, get_dist, counter)
else:
if self.left is not None:
self.left._search_node(point, k, results, get_dist, counter)
代码实现的是判断plane_dist2是否比results中的最大值小,如果小的话,就需要在另外一个节点中继续计算nodeDist。其实这个判断的依据就是看“以目标点为球心,以最小距离results[0][0]为半径的超球体是否和另外一个子节点所在区域相交”,也就是看“这个超球体是否和该节点的切分超平面相交”,也就是看“目标点到切分超平面的距离是否小于最小距离results[0][0]”。
我们注意到判断条件还有一个就是,也就是说最开始results没有填满时,节点的两个子节点都要计算最小距离nodeDist的,因为此时results[0][0]并不一定是最小距离(也许另外一个子节点的最小距离nodeDist比父节点的最小距离nodeDist还要小,但是比results[0][0]要大)。可见4步和5步是不能调换顺序的。
补充对KD树的插入和删除的思考。
删除:《C++数据结构与算法(第4版)》中223页对kd树的删除操作进行了详细的讲解。首先要认识到,kd树也是二叉查找树,只不过保存在树中的项有k个键值(k-dimensions),而一般的二叉查找树的删除算法中删除具有两个后代的节点时,转向右子树并一直向左找到后继节点,或者转向左子树并一直向左找到前驱节点。但是这一策略并不适用于kd树,因为父子节点的分裂键值是不同的,此时只能对左右子树都进行查找。所以kd树的删除算法总结如下:
- 如果删除的是叶子节点的父节点(这里需要说明一下,由于在生成kd树的时候叶子节点并没有保存数据信息,也就是可以看做是红黑树中的NIL节点,所以这里强调的是叶子节点的父节点),那么就直接删除该节点。
- 如果没有右子树,那么可以查找左子树来定位上移的节点;如果没有左子树,就查找右子树来定位上移的节点。
- 如果左右子树都有,那就需要考虑更多。可以选择从删除节点的左子树找前驱节点,此时要注意在下移寻找过程中,如果节点分裂键值和删除节点分裂键值不同,就需要对左右子树都进行寻找前驱节点,如果相同就对右子树寻找前驱节点;或者可以选择从删除节点的右子树找后继节点,如果节点分裂键值和删除节点分裂键值不同,就需要对左右子树都进行寻找后继节点,如果相同就对左子树寻找后继节点。
这里将《C++数据结构与算法(第4版)》中225页的删除算法伪代码抄下:
delete(el)
p=包含el的节点;
delete(p, p的识别字符索引i);
delete(p)
if p是叶节点
删除p;
else if p->right!=0
q=smallest(p->right, i, (i+1) mod k);
else q=smallest(p->left, i, (i+1) mod k);
p->right=p->left;
p->left=0;
p->el=q->el;
delete(q, i);
smallest(q, i, j)
qq=q;
if i==j
if q->left!=0
qq=q=q->left
else return q;
if q->left!=0
lt=smallest(q->left, i, (j+1) mod k);
if qq->el.keys[i]>=lt->el.keys[i]
qq=lt;
if q->right!=0
rt=smallest(q->right, i, (j+1) mod k);
if qq->el.keys[i]>=rt->el.keys[i]
qq=rt;
return qq;
插入:我暂时还没在网上找到kd树的插入算法,我个人有两个思路。一个就是不插入,对新数据积累一段时间后直接将其加到原数据中重新构建kd树;另一个方法就是对每个kd树节点再多保存一个值。(和同学讨论了下,这种方式还是不好,因为加入新的信息后每个维度上的中间值都会发生变化,我的第二种方法会导致树很不平衡,从而影响测试时的查找效率,所以还是得老老实实重新建树!kd树不适合训练集不停变化的情景),就是比该kd树节点分裂值稍小的值,当新数据下移到该kd树节点时判断对应值是否在
和分裂值之间,不在就接着下移,在就将以该kd树节点为根节点的子树中保存的所有信息拿出来并拿出来重新构建新树
网友评论