Hard margin SVM
在分类问题中,决策边界决定了模型的预测结果,为了提高模型的泛化能力,我们通常期望决策边界不仅能将特征空间中不同类别的数据进行划分,还期望这条边界与与两方甚至多方数据达到一个距离上的平等,从而避免过拟合,提高模型的泛化能力。
从而问题就变为,我们要找到一条决策边界,这条边界离各个数据(特征)都要尽可能地远。
SVM尝试寻找到一个最优的决策边界,距离两个类别的样本最远,而最优边界的寻找是依赖于支持向量
因此,我们希望能找到最大的margin–>d,在解析几何中,(x,y)点到(Ax+By+C=0)直线的距离为
因而对于前面描述的SVM距离d 有如下关系式:
从而推导出了与决策边界平行的两条直线表达式:如下
可以得到距离关系式为:
可以得到距离的表达式为:
从而,我们的目标就转换为了:
这是一个有条件的优化问题。
Soft Margin和SVM的正则化
在很多时候我们希望牺牲比较小的模型精度从而提高模型的泛化能力,比如下图:
因此我们希望引入一个机制能使得模型拥有容错能力,在训练过程中可以容忍小部分的错误来使得模型泛化能力更强,或者说不同数据间的Margin更大。
而这种SVM就叫做Soft Margin SVM。在Hard Margin SVM的限制条件中稍作改动即可转化为Soft Margin,
SVM的使用
首先,因为SVM是设计到对距离的度量的,所以和KNN算法一样需要考虑将数据标准化。
下面就scikit-learn提供的SVM来进行一些使用。
scikit-learn中的SVM
加载鸢尾花的数据集
import numpy as np
import matplotlib.pyplot as plt
import warnings
from sklearn import datasets
warnings.filterwarnings("ignore")
iris = datasets.load_iris()
X = iris.data
y = iris.target
X = X[y<2,:2]
y = y[y<2]
plt.scatter(X[y==0,0],X[y==0,1],color='r')
plt.scatter(X[y==1,0],X[y==1,1],color='b')
plt.show()
- 对数据集进行标准化工作
from sklearn.preprocessing import StandardScaler
std_scaler = StandardScaler()
std_scaler.fit(X)
X_std = std_scaler.transform(X)
- 调用SVM算法
from sklearn.svm import LinearSVC
svc = LinearSVC(C=1e9) # C取得越大,对精度要求越高,对泛化能力要求越低
svc.fit(X_std,y)
LinearSVC(C=1000000000.0, class_weight=None, dual=True, fit_intercept=True,
intercept_scaling=1, loss='squared_hinge', max_iter=1000,
multi_class='ovr', penalty='l2', random_state=None, tol=0.0001,
verbose=0
- 绘制决策边界
hard margin
def plot_decision_boundary(model,axis):
x0,x1 = np.meshgrid(
np.linspace(axis[0],axis[1],int((axis[1]-axis[0])*100)).reshape(-1,1),
np.linspace(axis[2],axis[3],int((axis[3]-axis[2])*100)).reshape(-1,1),
)
X_new = np.c_[x0.ravel(),x1.ravel()]
y_predict = model.predict(X_new)
zz = y_predict.reshape(x0.shape)
from matplotlib.colors import ListedColormap
custom_cmap = ListedColormap(['#EF9A9A','#FFF59D','#90CAF9'])
plt.contourf(x0,x1,zz,linewidth=5,cmap=custom_cmap)
plot_decision_boundary(svc,axis=[-3,3,-3,3])
plt.scatter(X_std[y==0,0],X_std[y==0,1],color='r')
plt.scatter(X_std[y==1,0],X_std[y==1,1],color='b')
plt.show()













网友评论