加权框融合(Weighted Boxes Fusion, WBF)详解与实现
在目标检测任务中,常常会遇到这样一种情况:
我们使用多个模型对同一张图片进行检测,不同模型的检测框(Bounding Box)可能会有所差异。
为了得到更稳定、更鲁棒的检测结果,结果融合就显得尤为重要。
常见的框融合方法有:
- NMS(Non-Maximum Suppression,非极大值抑制)
- Soft-NMS
- WBF(Weighted Boxes Fusion,加权框融合)
其中,WBF 是近年来应用较多的一种方法,它能够有效利用不同模型检测框之间的置信度和空间信息,得到比 NMS 更优的融合结果。
一、WBF 的核心思想
与传统的 NMS 不同,WBF 不会直接丢弃重叠度高的框,而是根据置信度对候选框进行加权平均,得到一个新的、更准确的框。
其基本流程如下:
-
按类别分组
将不同类别的检测框分别处理,避免不同类别之间相互干扰。 -
置信度排序
按照检测框的置信度从高到低进行排序,优先处理高置信度的框。 -
簇形成(Clustering)
如果两个框的 IoU(交并比)大于设定阈值(例如 0.55),则它们属于同一个簇(cluster)。 -
加权融合(Weighted Fusion)
对同一簇内的框,按照置信度加权平均,得到一个新的融合框。
相比 NMS 只保留最高分的框,WBF 可以让多个模型的框“共同投票”,使结果更加稳定。
二、WBF的典型用法
加权框融合(WBF)因其独特的融合机制,在目标检测的后处理流程中扮演着日益重要的角色。它的应用场景灵活且强大,主要可以归纳为以下三个方面:
1. 作为NMS的直接替代品(单模型优化)
这是WBF最基础也是最常见的用法:在单个模型的推理流程中,直接用WBF替换掉传统的非极大值抑制(NMS)。
-
处理流程:
- 单个目标检测模型(如YOLOv8, Faster R-CNN等)完成前向传播,输出大量的原始预测框(Raw Detections)。
- 将这些原始预测框(通常经过一个较低的置信度预过滤)全部送入WBF算法。
- WBF对框进行聚类和加权融合,输出最终的、优化后的检测结果。
-
核心优势:
- 提升定位精度:对于同一个物体,模型可能输出多个位置略有偏差的框。NMS会保留得分最高的那个,而忽略其他框的定位信息。WBF则会将这些框的坐标进行加权平均,生成一个几何位置上更精确的新框,有效平滑了模型的定位误差。
- 提高召回率:当一个物体被部分遮挡或特征不明显时,模型可能会输出两个置信度都中等的框,分别对应物体的不同部分。NMS可能会因为IoU过高而丢弃其中一个,导致信息丢失。WBF则倾向于将它们融合,保留对该物体的检测。
2. 用于多模型集成(Ensemble)
这是WBF真正大放异彩的“杀手级应用”。模型集成是通过结合多个不同模型的预测来获得比任何单一模型都更好的性能,而WBF是实现检测模型集成的最佳工具之一。
-
处理流程:
- 准备多个不同的模型。这些模型可以是:
不同架构的模型(例如,YOLOv8 + DETR)。
使用不同训练数据或超参训练的同架构模型。 - 同一个模型在训练过程中的不同检查点(Checkpoints)。
让每个模型都对同一张图片进行推理,得到多组成绩(原始预测框)。 - 将所有模型输出的所有预测框收集到一个大的列表中。
- 将这个聚合了所有模型智慧的列表送入WBF算法进行最终的融合。
- 准备多个不同的模型。这些模型可以是:
-
核心优势:
- 优势互补:模型A可能擅长识别大物体,模型B擅长识别小物体。WBF能将它们的预测结果完美结合,取长补短。
- 极强的鲁棒性:如果某个模型出现漏检(False Negative),其他模型大概率仍然能检测到。WBF综合所有信息,能显著降低漏检率。
- 抑制误报:如果只有一个模型产生了错误的检测(False Positive),它在WBF中很难形成一个高置信度的簇,从而被有效抑制。这要求“多个模型达成共识”,结果才可信。
三、代码实现
下面给出一个 C++ 实现版本
#include <unordered_map>
#include <vector>
#include <algorithm>
#include <memory>
#include <string>
namespace object {
struct Rect {
float left, top, right, bottom;
};
struct DetectionBox {
Rect box;
float score;
int class_id;
std::string class_name;
};
using DetectionBoxArray = std::vector<DetectionBox>;
}
// 计算 IoU
float calculate_iou(const object::DetectionBox &a, const object::DetectionBox &b) {
float inter_left = std::max(a.box.left, b.box.left);
float inter_top = std::max(a.box.top, b.box.top);
float inter_right = std::min(a.box.right, b.box.right);
float inter_bottom = std::min(a.box.bottom, b.box.bottom);
float inter_area = std::max(0.0f, inter_right - inter_left) *
std::max(0.0f, inter_bottom - inter_top);
float area_a = (a.box.right - a.box.left) * (a.box.bottom - a.box.top);
float area_b = (b.box.right - b.box.left) * (b.box.bottom - b.box.top);
float union_area = area_a + area_b - inter_area;
if (union_area <= 0.0f) return 0.0f;
return inter_area / union_area;
}
// WBF 融合过程
object::DetectionBoxArray weighted_boxes_fusion(
const std::vector<object::DetectionBox> &boxes,
float iou_threshold = 0.55)
{
if (boxes.empty()) return {};
// 按置信度排序
auto sorted_boxes = boxes;
std::sort(sorted_boxes.begin(), sorted_boxes.end(),
[](const object::DetectionBox &a, const object::DetectionBox &b) {
return a.score > b.score;
});
std::vector<std::vector<object::DetectionBox>> clusters;
// 聚类
for (const auto &box : sorted_boxes) {
bool matched = false;
for (auto &cluster : clusters) {
if (calculate_iou(box, cluster.front()) > iou_threshold) {
cluster.push_back(box);
matched = true;
break;
}
}
if (!matched) {
clusters.push_back({box});
}
}
// 对每个簇进行加权融合
object::DetectionBoxArray fused_boxes;
for (const auto &cluster : clusters) {
if (cluster.empty()) continue;
float total_score = 0.0f;
float fused_left = 0.0f, fused_top = 0.0f, fused_right = 0.0f, fused_bottom = 0.0f;
for (const auto &box : cluster) {
total_score += box.score;
fused_left += box.box.left * box.score;
fused_top += box.box.top * box.score;
fused_right += box.box.right * box.score;
fused_bottom += box.box.bottom * box.score;
}
object::DetectionBox fused;
fused.box.left = fused_left / total_score;
fused.box.top = fused_top / total_score;
fused.box.right = fused_right / total_score;
fused.box.bottom = fused_bottom / total_score;
fused.score = total_score / cluster.size();
fused.class_id = cluster.front().class_id;
fused.class_name = cluster.front().class_name;
fused_boxes.push_back(fused);
}
return fused_boxes;
}








网友评论