美文网首页
ceres solver 06 例子:鲁棒的曲线拟合

ceres solver 06 例子:鲁棒的曲线拟合

作者: book_02 | 来源:发表于2019-12-09 19:46 被阅读0次

前篇文章给出了如何用Ceres进行曲线拟合。

但是如果数据中有离群点,如果还是用上面介绍的方法,那么拟合的曲线将受离群点的影响而跑偏。
如下图所示:

怎么减少离群点的影响呢?

1. 损失函数

为了减少离群点的影响,就要使用损失函数。

再来回顾下之前的非线性最小二乘问题通用公式的表达:
\min_x \frac{1}{2}\sum_{i} \rho_i\left(\left\|f_i\left(x_{i_1}, ... ,x_{i_k}\right)\right\|^2\right) \\ \text{s.t.} l_j \le x_j \le u_j

之前例子我们的损失函数都是采用的默认损失函数: \rho_i(x) = x,这个相当于各个点的作用都一样,没有降低离群点的影响。

如果采用不同的损失函数,可以降低高残差点(即离群点)的影响。

2. 鲁棒的曲线拟合

所以鲁棒的曲线拟合可以采用如下的损失函数:

problem.AddResidualBlock(cost_function, new CauchyLoss(0.5) , &m, &c);

替代之前的

problem.AddResidualBlock(cost_function, NULL , &m, &c);

CauchyLoss是Ceres库中写好的损失函数的一种。

如此修改之后,上面的曲线拟合结果将会优化如下:


减轻了离群点的影响,更加接近真实数据。

3. Ceres 求解过程

最后还是把鲁棒的曲线拟合的步骤记录如下:

3.1 定义代价函数

y = e^{mx + c}录入到代价函数中去,不过这时有了数据项x和y

struct ExponentialResidual {
  ExponentialResidual(double x, double y)
      : x_(x), y_(y) {}

  template <typename T> bool operator()(const T* const m,
                                        const T* const c,
                                        T* residual) const {
    residual[0] = y_ - exp(m[0] * x_ + c[0]);
    return true;
  }

 private:
  const double x_;
  const double y_;
};

3.2 构建问题,添加残差项

  double m = 0.0;
  double c = 0.0;

  Problem problem;
  for (int i = 0; i < kNumObservations; ++i) {
    CostFunction* cost_function =
        new AutoDiffCostFunction<ExponentialResidual, 1, 1, 1>(
            new ExponentialResidual(data[2 * i], data[2 * i + 1]));
    problem.AddResidualBlock(cost_function,
                             new CauchyLoss(0.5),
                             &m, &c);
  }

data数组里存的是上面我们生成的带有噪声的曲线上点的数据,且包含多个离群点

3.3 定义求解器的选项参数,并求解

  Solver::Options options;
  options.linear_solver_type = ceres::DENSE_QR;
  options.minimizer_progress_to_stdout = true;

  Solver::Summary summary;
  Solve(options, &problem, &summary);

4. 完整代码

#include "ceres/ceres.h"
#include "glog/logging.h"

// Data generated using the following octave code.
//   randn('seed', 23497);
//   m = 0.3;
//   c = 0.1;
//   x=[0:0.075:5];
//   y = exp(m * x + c);
//   noise = randn(size(x)) * 0.2;
//   outlier_noise = rand(size(x)) < 0.05;
//   y_observed = y + noise + outlier_noise;
//   data = [x', y_observed'];

const int kNumObservations = 67;
const double data[] = {
0.000000e+00, 1.133898e+00,
7.500000e-02, 1.334902e+00,
1.500000e-01, 1.213546e+00,
2.250000e-01, 1.252016e+00,
3.000000e-01, 1.392265e+00,
3.750000e-01, 1.314458e+00,
4.500000e-01, 1.472541e+00,
5.250000e-01, 1.536218e+00,
6.000000e-01, 1.355679e+00,
6.750000e-01, 1.463566e+00,

7.500000e-01, 1.490201e+00,
8.250000e-01, 1.658699e+00,
9.000000e-01, 1.067574e+00,
9.750000e-01, 1.464629e+00,
1.050000e+00, 1.402653e+00,
1.125000e+00, 1.713141e+00,
1.200000e+00, 1.527021e+00,
1.275000e+00, 1.702632e+00,
1.350000e+00, 1.423899e+00,
1.425000e+00, 5.543078e+00, // Outlier point
1.500000e+00, 5.664015e+00, // Outlier point
1.575000e+00, 1.732484e+00,
1.650000e+00, 1.543296e+00,
1.725000e+00, 1.959523e+00,
1.800000e+00, 1.685132e+00,
1.875000e+00, 1.951791e+00,
1.950000e+00, 2.095346e+00,
2.025000e+00, 2.361460e+00,
2.100000e+00, 2.169119e+00,
2.175000e+00, 2.061745e+00,
2.250000e+00, 2.178641e+00,
2.325000e+00, 2.104346e+00,
2.400000e+00, 2.584470e+00,
2.475000e+00, 1.914158e+00,
2.550000e+00, 2.368375e+00,
2.625000e+00, 2.686125e+00,
2.700000e+00, 2.712395e+00,
2.775000e+00, 2.499511e+00,
2.850000e+00, 2.558897e+00,
2.925000e+00, 2.309154e+00,
3.000000e+00, 2.869503e+00,
3.075000e+00, 3.116645e+00,
3.150000e+00, 3.094907e+00,
3.225000e+00, 2.471759e+00,
3.300000e+00, 3.017131e+00,
3.375000e+00, 3.232381e+00,
3.450000e+00, 2.944596e+00,
3.525000e+00, 3.385343e+00,
3.600000e+00, 3.199826e+00,
3.675000e+00, 3.423039e+00,
3.750000e+00, 3.621552e+00,
3.825000e+00, 3.559255e+00,
3.900000e+00, 3.530713e+00,
3.975000e+00, 3.561766e+00,
4.050000e+00, 3.544574e+00,
4.125000e+00, 3.867945e+00,
4.200000e+00, 4.049776e+00,
4.275000e+00, 3.885601e+00,
4.350000e+00, 4.110505e+00,
4.425000e+00, 4.345320e+00,
4.500000e+00, 4.161241e+00,
4.575000e+00, 4.363407e+00,
4.650000e+00, 4.161576e+00,
4.725000e+00, 4.619728e+00,
4.800000e+00, 4.737410e+00,
4.875000e+00, 4.727863e+00,
4.950000e+00, 4.669206e+00
};

using ceres::AutoDiffCostFunction;
using ceres::CostFunction;
using ceres::CauchyLoss;
using ceres::Problem;
using ceres::Solve;
using ceres::Solver;

struct ExponentialResidual {
  ExponentialResidual(double x, double y)
      : x_(x), y_(y) {}

  template <typename T> bool operator()(const T* const m,
                                        const T* const c,
                                        T* residual) const {
    residual[0] = y_ - exp(m[0] * x_ + c[0]);
    return true;
  }

 private:
  const double x_;
  const double y_;
};

int main(int argc, char** argv) {
  google::InitGoogleLogging(argv[0]);

  double m = 0.0;
  double c = 0.0;

  Problem problem;
  for (int i = 0; i < kNumObservations; ++i) {
    CostFunction* cost_function =
        new AutoDiffCostFunction<ExponentialResidual, 1, 1, 1>(
            new ExponentialResidual(data[2 * i], data[2 * i + 1]));
    problem.AddResidualBlock(cost_function,
                             new CauchyLoss(0.5),
                             &m, &c);
  }

  Solver::Options options;
  options.linear_solver_type = ceres::DENSE_QR;
  options.minimizer_progress_to_stdout = true;

  Solver::Summary summary;
  Solve(options, &problem, &summary);
  std::cout << summary.BriefReport() << "\n";
  std::cout << "Initial m: " << 0.0 << " c: " << 0.0 << "\n";
  std::cout << "Final   m: " << m << " c: " << c << "\n";
  return 0;
}

5. 总结

总结使用步骤如下:

  1. 定义代价函数。根据目标函数定义代价代价函数
  2. 构建问题。添加残差项。这里要根据有噪声的数据点,不断地添加残差项。且使用new CauchyLoss(0.5)作为损失函数替代NULL
  3. 求解问题。定义求解器的选项参数和求解器的结果报告

6. 参考

http://ceres-solver.org/nnls_tutorial.html#robust-curve-fitting

相关文章

网友评论

      本文标题:ceres solver 06 例子:鲁棒的曲线拟合

      本文链接:https://www.haomeiwen.com/subject/htyigctx.html