支持向量机 (SVM),序列最小优化算法 (SMO)

本贴最后更新于 1138 天前,其中的信息可能已经事过境迁

支持向量机(Support Vector Machine)由V.N. Vapnik,A.Y. Chervonenkis,C. Cortes 等在1964年提出。序列最小优化算法(Sequential minimal optimization)是一种用于解决支持向量机训练过程中所产生优化问题的算法。由John C. Platt于1998年提出。

支持向量机的推导在西瓜书,各大网站已经有详细的介绍。本文主要依据 John C. Platt 发表的文章《Sequential Minimal Optimization: A Fast Algorithm for Training Support Vector Machines》来实现 SVM 与 SMO 算法。


算法的流程:

QQ 截图 20210306083526.png


import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt

定义需要的数据,包含数据样本,数据标签,偏置 b,拉格朗日乘子 α,容忍系数 C 等。

class Par:
	def __init__(self,n,D,C,eps,tol):
		self.X=datasets.make_blobs(n_samples=n,n_features=D,centers=2,cluster_std=1.0,shuffle=True,random_state=None)
		self.point=self.X[0]
		self.target=self.X[1]
		self.target[np.nonzero(self.target==0)[0]]=-1
		self.w=np.zeros((1,D))[0]
		self.b=0
		self.E=-self.target
		self.alpha=np.zeros((1,n))[0]
		self.n=n
		self.C=C
		self.eps=eps
		self.tol=tol

定义核函数,预测公式。

def kernel(x,y):
	return np.dot(x,y.T)

def f(x):
	s=0
	for i in range(n):
		s+=P.alpha[i]*P.target[i]*kernel(P.point[i],x)
	return s-P.b

被选中的一对 α 更新细节:

def takeStep(i1,i2):
	if i1==i2:
		return 0
	alph2=P.alpha[i2]
	alph1=P.alpha[i1]
	y1=P.target[i1]
	y2=P.target[i2]
	s=y1*y2
	#Compute L,H via equations (13) and (14)
	if y1!=y2:
		L=max(0,alph2-alph1)
		H=min(P.C,P.C+alph2-alph1)
	else:
		L=max(0,alph2+alph1-P.C)
		H=min(P.C,alph2+alph1)
	if L==H:
		return 0
	k11=kernel(P.point[i1],P.point[i1])
	k12=kernel(P.point[i1],P.point[i2])
	k22=kernel(P.point[i2],P.point[i2])
	eta=k11+k22-2*k12
	if eta>0:
		a2=alph2+y2*(P.E[i1]-P.E[i2])/eta
	if a2<L:
		a2=L
	elif a2>H:
		a2=H
	else:
		f1=y1*(P.E[i1]+b)-alph1*k11-s*alph2*k12
		f2=y2*(P.E[i2]+b)-s*alph1*k12-alph2*k22
		L1=alph1+s*(alph2-L)
		H1=alph1+s*(alph2+H)
		psiL=L1*f1+L*f2+0.5*L1**2*k11+0.5*L**2*k22+s*L*L1*k12
		psiH=H1*f1+H*f2+0.5*H1**2*k11+0.5*H**2*k22+s*H*H1*k12
		Lobj = psiL
		Hobj = psiH
	if Lobj<Hobj-eps:
		a2=L
	elif Lobj>Hobj+eps:
		a2=H
	else:
		a2=alph2
	if abs(a2-alph2)<P.eps*(a2+alph2+P.eps):
		return 0
	a1=alph1+s*(alph2-a2)
	#Update threshold to reflect change in Lagrange multipliers
	b1=P.E[i1]+y1*(a1-alph1)*k11+y2*(a2-alph2)*k12+P.b
	b2=P.E[i2]+y1*(a1-alph1)*k12+y2*(a2-alph2)*k22+P.b
	if a1>0 and a1<P.C:
		P.b=b1
	elif a2>0 and a2<P.C:
		P.b=b2
	else:
		P.b=(b1+b2)/2
	#Update weight vector to reflect change in a1 & a2, if SVM is linear
	P.w=P.w+y1*(a1-alph1)*P.point[i1]+y2*(a2-alph2)*P.point[i2]
	#Store a1 in the alpha array
	P.alpha[i1]=a1
	#Store a2 in the alpha array
	P.alpha[i2]=a2
	#Update error cache using new Lagrange multipliers
	P.E[i1]=f(P.point[i1])-P.target[i1]
	P.E[i2]=f(P.point[i2])-P.target[i2]
	return 1

内循环选择第二个 α:

def examineExample(i2):
	global valid
	alph2=P.alpha[i2]
	y2=P.target[i2]
	r2=P.E[i2]*y2
	if (r2<-P.tol and alph2<P.C) or (r2>P.tol and alph2>0):
		valid=np.where((P.alpha!=0) & (P.alpha!=C))[0]
		Long=len(valid)
		if Long > 1:
			#i1 = result of second choice heuristic (section 2.2)
			best=-1
			if len(valid)>1:
				for k in valid:
					deltaE=abs(P.E[i2]-P.E[k])
					if deltaE>best:
						best=deltaE
						i1=k
					if takeStep(i1,i2):
						return 1
		#loop over all non-zero and non-C alpha, starting at a random point
		if Long>0:
			random_index=np.random.randint(0,Long)
			for i in np.hstack((valid[random_index:Long],valid[0:random_index])):
				i1=i
				if takeStep(i1,i2):
					return 1
		#loop over all possible i1, starting at a random point
		random_index=np.random.randint(0,n)
		for i in np.hstack((np.arange(random_index,n),np.arange(0,random_index))):
			#i1=loop variable
			i1=i
			if takeStep(i1,i2):
				return 1
	return 0

外循环选择第一个 α:

def SMO():
	global valid
	numChanged=0
	examineAll=1
	while numChanged>0 or examineAll:
		numChanged=0
		if examineAll:
			for i in range(n):
				numChanged+=examineExample(i)
		else:
			#loop I over examples where alpha is not 0 & not C
			for i in valid:
				numChanged+=examineExample(i)
		if examineAll==1:
			examineAll=0
		elif numChanged==0:
			examineAll=1

主函数入口:

if __name__ == '__main__':
	n=100       #样本个数
	C=10         
	eps=0.001  #停止精度
	tol=0.001   #分类容错率
	D=2            #样本维度
	P=Par(n,D,C,eps,tol) 
	SMO()
	#绘制图像
	plt.scatter(P.point[:,0],P.point[:,1],c=P.target)
	x=np.arange(-10,10,0.1)
	y=(P.b-P.w[0]*x)/P.w[1]
	plt.plot(x,y)
	plt.show()
	Y=kernel(P.point,P.w)-P.b
	count=0
	for i in range(n):
	if Y[i]*P.target[i]<0:
	count+=1
	print('Error Point num:',count)

单次测试结果:

Figure1.png

  • 分类
    7 引用 • 10 回帖
  • 机器学习

    机器学习(Machine Learning)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。

    76 引用 • 37 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • 心情

    心是产生任何想法的源泉,心本体会陷入到对自己本体不能理解的状态中,因为心能产生任何想法,不能分出对错,不能分出自己。

    59 引用 • 369 回帖
  • uTools

    uTools 是一个极简、插件化、跨平台的现代桌面软件。通过自由选配丰富的插件,打造你得心应手的工具集合。

    5 引用 • 13 回帖
  • Hexo

    Hexo 是一款快速、简洁且高效的博客框架,使用 Node.js 编写。

    21 引用 • 140 回帖 • 25 关注
  • V2Ray
    1 引用 • 15 回帖
  • 小说

    小说是以刻画人物形象为中心,通过完整的故事情节和环境描写来反映社会生活的文学体裁。

    28 引用 • 108 回帖 • 2 关注
  • Bootstrap

    Bootstrap 是 Twitter 推出的一个用于前端开发的开源工具包。它由 Twitter 的设计师 Mark Otto 和 Jacob Thornton 合作开发,是一个 CSS / HTML 框架。

    18 引用 • 33 回帖 • 685 关注
  • OnlyOffice
    4 引用 • 27 关注
  • 房星科技

    房星网,我们不和没有钱的程序员谈理想,我们要让程序员又有理想又有钱。我们有雄厚的房地产行业线下资源,遍布昆明全城的 100 家门店、四千地产经纪人是我们坚实的后盾。

    6 引用 • 141 回帖 • 553 关注
  • Firefox

    Mozilla Firefox 中文俗称“火狐”(正式缩写为 Fx 或 fx,非正式缩写为 FF),是一个开源的网页浏览器,使用 Gecko 排版引擎,支持多种操作系统,如 Windows、OSX 及 Linux 等。

    7 引用 • 30 回帖 • 455 关注
  • Angular

    AngularAngularJS 的新版本。

    26 引用 • 66 回帖 • 511 关注
  • Unity

    Unity 是由 Unity Technologies 开发的一个让开发者可以轻松创建诸如 2D、3D 多平台的综合型游戏开发工具,是一个全面整合的专业游戏引擎。

    25 引用 • 7 回帖 • 250 关注
  • 分享

    有什么新发现就分享给大家吧!

    242 引用 • 1746 回帖 • 1 关注
  • iOS

    iOS 是由苹果公司开发的移动操作系统,最早于 2007 年 1 月 9 日的 Macworld 大会上公布这个系统,最初是设计给 iPhone 使用的,后来陆续套用到 iPod touch、iPad 以及 Apple TV 等产品上。iOS 与苹果的 Mac OS X 操作系统一样,属于类 Unix 的商业操作系统。

    84 引用 • 139 回帖 • 1 关注
  • SQLServer

    SQL Server 是由 [微软] 开发和推广的关系数据库管理系统(DBMS),它最初是由 微软、Sybase 和 Ashton-Tate 三家公司共同开发的,并于 1988 年推出了第一个 OS/2 版本。

    19 引用 • 31 回帖 • 3 关注
  • 导航

    各种网址链接、内容导航。

    37 引用 • 168 回帖
  • JSON

    JSON (JavaScript Object Notation)是一种轻量级的数据交换格式。易于人类阅读和编写。同时也易于机器解析和生成。

    51 引用 • 190 回帖 • 2 关注
  • 星云链

    星云链是一个开源公链,业内简单的将其称为区块链上的谷歌。其实它不仅仅是区块链搜索引擎,一个公链的所有功能,它基本都有,比如你可以用它来开发部署你的去中心化的 APP,你可以在上面编写智能合约,发送交易等等。3 分钟快速接入星云链 (NAS) 测试网

    3 引用 • 16 回帖 • 5 关注
  • Facebook

    Facebook 是一个联系朋友的社交工具。大家可以通过它和朋友、同事、同学以及周围的人保持互动交流,分享无限上传的图片,发布链接和视频,更可以增进对朋友的了解。

    4 引用 • 15 回帖 • 454 关注
  • 单点登录

    单点登录(Single Sign On)是目前比较流行的企业业务整合的解决方案之一。SSO 的定义是在多个应用系统中,用户只需要登录一次就可以访问所有相互信任的应用系统。

    9 引用 • 25 回帖 • 3 关注
  • 书籍

    宋真宗赵恒曾经说过:“书中自有黄金屋,书中自有颜如玉。”

    76 引用 • 390 回帖 • 1 关注
  • frp

    frp 是一个可用于内网穿透的高性能的反向代理应用,支持 TCP、UDP、 HTTP 和 HTTPS 协议。

    15 引用 • 7 回帖 • 9 关注
  • Mobi.css

    Mobi.css is a lightweight, flexible CSS framework that focus on mobile.

    1 引用 • 6 回帖 • 696 关注
  • Maven

    Maven 是基于项目对象模型(POM)、通过一小段描述信息来管理项目的构建、报告和文档的软件项目管理工具。

    185 引用 • 318 回帖 • 348 关注
  • danl
    61 关注
  • Rust

    Rust 是一门赋予每个人构建可靠且高效软件能力的语言。Rust 由 Mozilla 开发,最早发布于 2014 年 9 月。

    57 引用 • 22 回帖 • 2 关注
  • SEO

    发布对别人有帮助的原创内容是最好的 SEO 方式。

    35 引用 • 200 回帖 • 24 关注
  • AngularJS

    AngularJS 诞生于 2009 年,由 Misko Hevery 等人创建,后为 Google 所收购。是一款优秀的前端 JS 框架,已经被用于 Google 的多款产品当中。AngularJS 有着诸多特性,最为核心的是:MVC、模块化、自动化双向数据绑定、语义化标签、依赖注入等。2.0 版本后已经改名为 Angular。

    12 引用 • 50 回帖 • 422 关注