支持向量机 (SVM),序列最小优化算法 (SMO)

本贴最后更新于 1353 天前,其中的信息可能已经事过境迁

支持向量机(Support Vector Machine)由V.N. Vapnik,A.Y. Chervonenkis,C. Cortes 等在1964年提出。序列最小优化算法(Sequential minimal optimization)是一种用于解决支持向量机训练过程中所产生优化问题的算法。由John C. Platt于1998年提出。

支持向量机的推导在西瓜书,各大网站已经有详细的介绍。本文主要依据 John C. Platt 发表的文章《Sequential Minimal Optimization: A Fast Algorithm for Training Support Vector Machines》来实现 SVM 与 SMO 算法。


算法的流程:

QQ 截图 20210306083526.png


import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt

定义需要的数据,包含数据样本,数据标签,偏置 b,拉格朗日乘子 α,容忍系数 C 等。

class Par:
	def __init__(self,n,D,C,eps,tol):
		self.X=datasets.make_blobs(n_samples=n,n_features=D,centers=2,cluster_std=1.0,shuffle=True,random_state=None)
		self.point=self.X[0]
		self.target=self.X[1]
		self.target[np.nonzero(self.target==0)[0]]=-1
		self.w=np.zeros((1,D))[0]
		self.b=0
		self.E=-self.target
		self.alpha=np.zeros((1,n))[0]
		self.n=n
		self.C=C
		self.eps=eps
		self.tol=tol

定义核函数,预测公式。

def kernel(x,y):
	return np.dot(x,y.T)

def f(x):
	s=0
	for i in range(n):
		s+=P.alpha[i]*P.target[i]*kernel(P.point[i],x)
	return s-P.b

被选中的一对 α 更新细节:

def takeStep(i1,i2):
	if i1==i2:
		return 0
	alph2=P.alpha[i2]
	alph1=P.alpha[i1]
	y1=P.target[i1]
	y2=P.target[i2]
	s=y1*y2
	#Compute L,H via equations (13) and (14)
	if y1!=y2:
		L=max(0,alph2-alph1)
		H=min(P.C,P.C+alph2-alph1)
	else:
		L=max(0,alph2+alph1-P.C)
		H=min(P.C,alph2+alph1)
	if L==H:
		return 0
	k11=kernel(P.point[i1],P.point[i1])
	k12=kernel(P.point[i1],P.point[i2])
	k22=kernel(P.point[i2],P.point[i2])
	eta=k11+k22-2*k12
	if eta>0:
		a2=alph2+y2*(P.E[i1]-P.E[i2])/eta
	if a2<L:
		a2=L
	elif a2>H:
		a2=H
	else:
		f1=y1*(P.E[i1]+b)-alph1*k11-s*alph2*k12
		f2=y2*(P.E[i2]+b)-s*alph1*k12-alph2*k22
		L1=alph1+s*(alph2-L)
		H1=alph1+s*(alph2+H)
		psiL=L1*f1+L*f2+0.5*L1**2*k11+0.5*L**2*k22+s*L*L1*k12
		psiH=H1*f1+H*f2+0.5*H1**2*k11+0.5*H**2*k22+s*H*H1*k12
		Lobj = psiL
		Hobj = psiH
	if Lobj<Hobj-eps:
		a2=L
	elif Lobj>Hobj+eps:
		a2=H
	else:
		a2=alph2
	if abs(a2-alph2)<P.eps*(a2+alph2+P.eps):
		return 0
	a1=alph1+s*(alph2-a2)
	#Update threshold to reflect change in Lagrange multipliers
	b1=P.E[i1]+y1*(a1-alph1)*k11+y2*(a2-alph2)*k12+P.b
	b2=P.E[i2]+y1*(a1-alph1)*k12+y2*(a2-alph2)*k22+P.b
	if a1>0 and a1<P.C:
		P.b=b1
	elif a2>0 and a2<P.C:
		P.b=b2
	else:
		P.b=(b1+b2)/2
	#Update weight vector to reflect change in a1 & a2, if SVM is linear
	P.w=P.w+y1*(a1-alph1)*P.point[i1]+y2*(a2-alph2)*P.point[i2]
	#Store a1 in the alpha array
	P.alpha[i1]=a1
	#Store a2 in the alpha array
	P.alpha[i2]=a2
	#Update error cache using new Lagrange multipliers
	P.E[i1]=f(P.point[i1])-P.target[i1]
	P.E[i2]=f(P.point[i2])-P.target[i2]
	return 1

内循环选择第二个 α:

def examineExample(i2):
	global valid
	alph2=P.alpha[i2]
	y2=P.target[i2]
	r2=P.E[i2]*y2
	if (r2<-P.tol and alph2<P.C) or (r2>P.tol and alph2>0):
		valid=np.where((P.alpha!=0) & (P.alpha!=C))[0]
		Long=len(valid)
		if Long > 1:
			#i1 = result of second choice heuristic (section 2.2)
			best=-1
			if len(valid)>1:
				for k in valid:
					deltaE=abs(P.E[i2]-P.E[k])
					if deltaE>best:
						best=deltaE
						i1=k
					if takeStep(i1,i2):
						return 1
		#loop over all non-zero and non-C alpha, starting at a random point
		if Long>0:
			random_index=np.random.randint(0,Long)
			for i in np.hstack((valid[random_index:Long],valid[0:random_index])):
				i1=i
				if takeStep(i1,i2):
					return 1
		#loop over all possible i1, starting at a random point
		random_index=np.random.randint(0,n)
		for i in np.hstack((np.arange(random_index,n),np.arange(0,random_index))):
			#i1=loop variable
			i1=i
			if takeStep(i1,i2):
				return 1
	return 0

外循环选择第一个 α:

def SMO():
	global valid
	numChanged=0
	examineAll=1
	while numChanged>0 or examineAll:
		numChanged=0
		if examineAll:
			for i in range(n):
				numChanged+=examineExample(i)
		else:
			#loop I over examples where alpha is not 0 & not C
			for i in valid:
				numChanged+=examineExample(i)
		if examineAll==1:
			examineAll=0
		elif numChanged==0:
			examineAll=1

主函数入口:

if __name__ == '__main__':
	n=100       #样本个数
	C=10         
	eps=0.001  #停止精度
	tol=0.001   #分类容错率
	D=2            #样本维度
	P=Par(n,D,C,eps,tol) 
	SMO()
	#绘制图像
	plt.scatter(P.point[:,0],P.point[:,1],c=P.target)
	x=np.arange(-10,10,0.1)
	y=(P.b-P.w[0]*x)/P.w[1]
	plt.plot(x,y)
	plt.show()
	Y=kernel(P.point,P.w)-P.b
	count=0
	for i in range(n):
	if Y[i]*P.target[i]<0:
	count+=1
	print('Error Point num:',count)

单次测试结果:

Figure1.png

  • 分类
    8 引用 • 10 回帖
  • 机器学习

    机器学习(Machine Learning)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。

    83 引用 • 37 回帖 • 1 关注

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • 倾城之链
    23 引用 • 66 回帖 • 137 关注
  • Postman

    Postman 是一款简单好用的 HTTP API 调试工具。

    4 引用 • 3 回帖 • 3 关注
  • 又拍云

    又拍云是国内领先的 CDN 服务提供商,国家工信部认证通过的“可信云”,乌云众测平台认证的“安全云”,为移动时代的创业者提供新一代的 CDN 加速服务。

    21 引用 • 37 回帖 • 545 关注
  • Git

    Git 是 Linux Torvalds 为了帮助管理 Linux 内核开发而开发的一个开放源码的版本控制软件。

    209 引用 • 358 回帖
  • Facebook

    Facebook 是一个联系朋友的社交工具。大家可以通过它和朋友、同事、同学以及周围的人保持互动交流,分享无限上传的图片,发布链接和视频,更可以增进对朋友的了解。

    4 引用 • 15 回帖 • 453 关注
  • API

    应用程序编程接口(Application Programming Interface)是一些预先定义的函数,目的是提供应用程序与开发人员基于某软件或硬件得以访问一组例程的能力,而又无需访问源码,或理解内部工作机制的细节。

    77 引用 • 430 回帖 • 2 关注
  • 单点登录

    单点登录(Single Sign On)是目前比较流行的企业业务整合的解决方案之一。SSO 的定义是在多个应用系统中,用户只需要登录一次就可以访问所有相互信任的应用系统。

    9 引用 • 25 回帖
  • WiFiDog

    WiFiDog 是一套开源的无线热点认证管理工具,主要功能包括:位置相关的内容递送;用户认证和授权;集中式网络监控。

    1 引用 • 7 回帖 • 587 关注
  • OAuth

    OAuth 协议为用户资源的授权提供了一个安全的、开放而又简易的标准。与以往的授权方式不同之处是 oAuth 的授权不会使第三方触及到用户的帐号信息(如用户名与密码),即第三方无需使用用户的用户名与密码就可以申请获得该用户资源的授权,因此 oAuth 是安全的。oAuth 是 Open Authorization 的简写。

    36 引用 • 103 回帖 • 9 关注
  • V2Ray
    1 引用 • 15 回帖 • 1 关注
  • golang

    Go 语言是 Google 推出的一种全新的编程语言,可以在不损失应用程序性能的情况下降低代码的复杂性。谷歌首席软件工程师罗布派克(Rob Pike)说:我们之所以开发 Go,是因为过去 10 多年间软件开发的难度令人沮丧。Go 是谷歌 2009 发布的第二款编程语言。

    497 引用 • 1387 回帖 • 283 关注
  • 旅游

    希望你我能在旅途中找到人生的下一站。

    90 引用 • 899 回帖
  • 周末

    星期六到星期天晚,实行五天工作制后,指每周的最后两天。再过几年可能就是三天了。

    14 引用 • 297 回帖
  • 心情

    心是产生任何想法的源泉,心本体会陷入到对自己本体不能理解的状态中,因为心能产生任何想法,不能分出对错,不能分出自己。

    59 引用 • 369 回帖
  • 房星科技

    房星网,我们不和没有钱的程序员谈理想,我们要让程序员又有理想又有钱。我们有雄厚的房地产行业线下资源,遍布昆明全城的 100 家门店、四千地产经纪人是我们坚实的后盾。

    6 引用 • 141 回帖 • 585 关注
  • 链书

    链书(Chainbook)是 B3log 开源社区提供的区块链纸质书交易平台,通过 B3T 实现共享激励与价值链。可将你的闲置书籍上架到链书,我们共同构建这个全新的交易平台,让闲置书籍继续发挥它的价值。

    链书社

    链书目前已经下线,也许以后还有计划重制上线。

    14 引用 • 257 回帖
  • V2EX

    V2EX 是创意工作者们的社区。这里目前汇聚了超过 400,000 名主要来自互联网行业、游戏行业和媒体行业的创意工作者。V2EX 希望能够成为创意工作者们的生活和事业的一部分。

    17 引用 • 236 回帖 • 325 关注
  • SSL

    SSL(Secure Sockets Layer 安全套接层),及其继任者传输层安全(Transport Layer Security,TLS)是为网络通信提供安全及数据完整性的一种安全协议。TLS 与 SSL 在传输层对网络连接进行加密。

    70 引用 • 193 回帖 • 432 关注
  • C

    C 语言是一门通用计算机编程语言,应用广泛。C 语言的设计目标是提供一种能以简易的方式编译、处理低级存储器、产生少量的机器码以及不需要任何运行环境支持便能运行的编程语言。

    85 引用 • 165 回帖 • 1 关注
  • Latke

    Latke 是一款以 JSON 为主的 Java Web 框架。

    71 引用 • 535 回帖 • 786 关注
  • Kafka

    Kafka 是一种高吞吐量的分布式发布订阅消息系统,它可以处理消费者规模的网站中的所有动作流数据。 这种动作(网页浏览,搜索和其他用户的行动)是现代系统中许多功能的基础。 这些数据通常是由于吞吐量的要求而通过处理日志和日志聚合来解决。

    36 引用 • 35 回帖
  • CloudFoundry

    Cloud Foundry 是 VMware 推出的业界第一个开源 PaaS 云平台,它支持多种框架、语言、运行时环境、云平台及应用服务,使开发人员能够在几秒钟内进行应用程序的部署和扩展,无需担心任何基础架构的问题。

    5 引用 • 18 回帖 • 167 关注
  • ngrok

    ngrok 是一个反向代理,通过在公共的端点和本地运行的 Web 服务器之间建立一个安全的通道。

    7 引用 • 63 回帖 • 624 关注
  • 创业

    你比 99% 的人都优秀么?

    84 引用 • 1399 回帖 • 1 关注
  • Java

    Java 是一种可以撰写跨平台应用软件的面向对象的程序设计语言,是由 Sun Microsystems 公司于 1995 年 5 月推出的。Java 技术具有卓越的通用性、高效性、平台移植性和安全性。

    3187 引用 • 8213 回帖
  • Lute

    Lute 是一款结构化的 Markdown 引擎,支持 Go 和 JavaScript。

    25 引用 • 191 回帖 • 16 关注
  • TextBundle

    TextBundle 文件格式旨在应用程序之间交换 Markdown 或 Fountain 之类的纯文本文件时,提供更无缝的用户体验。

    1 引用 • 2 回帖 • 47 关注