支持向量机 (SVM),序列最小优化算法 (SMO)

本贴最后更新于 1383 天前,其中的信息可能已经事过境迁

支持向量机(Support Vector Machine)由V.N. Vapnik,A.Y. Chervonenkis,C. Cortes 等在1964年提出。序列最小优化算法(Sequential minimal optimization)是一种用于解决支持向量机训练过程中所产生优化问题的算法。由John C. Platt于1998年提出。

支持向量机的推导在西瓜书,各大网站已经有详细的介绍。本文主要依据 John C. Platt 发表的文章《Sequential Minimal Optimization: A Fast Algorithm for Training Support Vector Machines》来实现 SVM 与 SMO 算法。


算法的流程:

QQ 截图 20210306083526.png


import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt

定义需要的数据,包含数据样本,数据标签,偏置 b,拉格朗日乘子 α,容忍系数 C 等。

class Par:
	def __init__(self,n,D,C,eps,tol):
		self.X=datasets.make_blobs(n_samples=n,n_features=D,centers=2,cluster_std=1.0,shuffle=True,random_state=None)
		self.point=self.X[0]
		self.target=self.X[1]
		self.target[np.nonzero(self.target==0)[0]]=-1
		self.w=np.zeros((1,D))[0]
		self.b=0
		self.E=-self.target
		self.alpha=np.zeros((1,n))[0]
		self.n=n
		self.C=C
		self.eps=eps
		self.tol=tol

定义核函数,预测公式。

def kernel(x,y):
	return np.dot(x,y.T)

def f(x):
	s=0
	for i in range(n):
		s+=P.alpha[i]*P.target[i]*kernel(P.point[i],x)
	return s-P.b

被选中的一对 α 更新细节:

def takeStep(i1,i2):
	if i1==i2:
		return 0
	alph2=P.alpha[i2]
	alph1=P.alpha[i1]
	y1=P.target[i1]
	y2=P.target[i2]
	s=y1*y2
	#Compute L,H via equations (13) and (14)
	if y1!=y2:
		L=max(0,alph2-alph1)
		H=min(P.C,P.C+alph2-alph1)
	else:
		L=max(0,alph2+alph1-P.C)
		H=min(P.C,alph2+alph1)
	if L==H:
		return 0
	k11=kernel(P.point[i1],P.point[i1])
	k12=kernel(P.point[i1],P.point[i2])
	k22=kernel(P.point[i2],P.point[i2])
	eta=k11+k22-2*k12
	if eta>0:
		a2=alph2+y2*(P.E[i1]-P.E[i2])/eta
	if a2<L:
		a2=L
	elif a2>H:
		a2=H
	else:
		f1=y1*(P.E[i1]+b)-alph1*k11-s*alph2*k12
		f2=y2*(P.E[i2]+b)-s*alph1*k12-alph2*k22
		L1=alph1+s*(alph2-L)
		H1=alph1+s*(alph2+H)
		psiL=L1*f1+L*f2+0.5*L1**2*k11+0.5*L**2*k22+s*L*L1*k12
		psiH=H1*f1+H*f2+0.5*H1**2*k11+0.5*H**2*k22+s*H*H1*k12
		Lobj = psiL
		Hobj = psiH
	if Lobj<Hobj-eps:
		a2=L
	elif Lobj>Hobj+eps:
		a2=H
	else:
		a2=alph2
	if abs(a2-alph2)<P.eps*(a2+alph2+P.eps):
		return 0
	a1=alph1+s*(alph2-a2)
	#Update threshold to reflect change in Lagrange multipliers
	b1=P.E[i1]+y1*(a1-alph1)*k11+y2*(a2-alph2)*k12+P.b
	b2=P.E[i2]+y1*(a1-alph1)*k12+y2*(a2-alph2)*k22+P.b
	if a1>0 and a1<P.C:
		P.b=b1
	elif a2>0 and a2<P.C:
		P.b=b2
	else:
		P.b=(b1+b2)/2
	#Update weight vector to reflect change in a1 & a2, if SVM is linear
	P.w=P.w+y1*(a1-alph1)*P.point[i1]+y2*(a2-alph2)*P.point[i2]
	#Store a1 in the alpha array
	P.alpha[i1]=a1
	#Store a2 in the alpha array
	P.alpha[i2]=a2
	#Update error cache using new Lagrange multipliers
	P.E[i1]=f(P.point[i1])-P.target[i1]
	P.E[i2]=f(P.point[i2])-P.target[i2]
	return 1

内循环选择第二个 α:

def examineExample(i2):
	global valid
	alph2=P.alpha[i2]
	y2=P.target[i2]
	r2=P.E[i2]*y2
	if (r2<-P.tol and alph2<P.C) or (r2>P.tol and alph2>0):
		valid=np.where((P.alpha!=0) & (P.alpha!=C))[0]
		Long=len(valid)
		if Long > 1:
			#i1 = result of second choice heuristic (section 2.2)
			best=-1
			if len(valid)>1:
				for k in valid:
					deltaE=abs(P.E[i2]-P.E[k])
					if deltaE>best:
						best=deltaE
						i1=k
					if takeStep(i1,i2):
						return 1
		#loop over all non-zero and non-C alpha, starting at a random point
		if Long>0:
			random_index=np.random.randint(0,Long)
			for i in np.hstack((valid[random_index:Long],valid[0:random_index])):
				i1=i
				if takeStep(i1,i2):
					return 1
		#loop over all possible i1, starting at a random point
		random_index=np.random.randint(0,n)
		for i in np.hstack((np.arange(random_index,n),np.arange(0,random_index))):
			#i1=loop variable
			i1=i
			if takeStep(i1,i2):
				return 1
	return 0

外循环选择第一个 α:

def SMO():
	global valid
	numChanged=0
	examineAll=1
	while numChanged>0 or examineAll:
		numChanged=0
		if examineAll:
			for i in range(n):
				numChanged+=examineExample(i)
		else:
			#loop I over examples where alpha is not 0 & not C
			for i in valid:
				numChanged+=examineExample(i)
		if examineAll==1:
			examineAll=0
		elif numChanged==0:
			examineAll=1

主函数入口:

if __name__ == '__main__':
	n=100       #样本个数
	C=10         
	eps=0.001  #停止精度
	tol=0.001   #分类容错率
	D=2            #样本维度
	P=Par(n,D,C,eps,tol) 
	SMO()
	#绘制图像
	plt.scatter(P.point[:,0],P.point[:,1],c=P.target)
	x=np.arange(-10,10,0.1)
	y=(P.b-P.w[0]*x)/P.w[1]
	plt.plot(x,y)
	plt.show()
	Y=kernel(P.point,P.w)-P.b
	count=0
	for i in range(n):
	if Y[i]*P.target[i]<0:
	count+=1
	print('Error Point num:',count)

单次测试结果:

Figure1.png

  • 分类
    8 引用 • 10 回帖
  • 机器学习

    机器学习(Machine Learning)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。

    83 引用 • 37 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • 游戏

    沉迷游戏伤身,强撸灰飞烟灭。

    177 引用 • 816 回帖
  • 大疆创新

    深圳市大疆创新科技有限公司(DJI-Innovations,简称 DJI),成立于 2006 年,是全球领先的无人飞行器控制系统及无人机解决方案的研发和生产商,客户遍布全球 100 多个国家。通过持续的创新,大疆致力于为无人机工业、行业用户以及专业航拍应用提供性能最强、体验最佳的革命性智能飞控产品和解决方案。

    2 引用 • 14 回帖
  • 安装

    你若安好,便是晴天。

    132 引用 • 1184 回帖 • 2 关注
  • 钉钉

    钉钉,专为中国企业打造的免费沟通协同多端平台, 阿里巴巴出品。

    15 引用 • 67 回帖 • 335 关注
  • 大数据

    大数据(big data)是指无法在一定时间范围内用常规软件工具进行捕捉、管理和处理的数据集合,是需要新处理模式才能具有更强的决策力、洞察发现力和流程优化能力的海量、高增长率和多样化的信息资产。

    93 引用 • 113 回帖 • 1 关注
  • Quicker

    Quicker 您的指尖工具箱!操作更少,收获更多!

    34 引用 • 148 回帖
  • OkHttp

    OkHttp 是一款 HTTP & HTTP/2 客户端库,专为 Android 和 Java 应用打造。

    16 引用 • 6 回帖 • 76 关注
  • Ngui

    Ngui 是一个 GUI 的排版显示引擎和跨平台的 GUI 应用程序开发框架,基于
    Node.js / OpenGL。目标是在此基础上开发 GUI 应用程序可拥有开发 WEB 应用般简单与速度同时兼顾 Native 应用程序的性能与体验。

    7 引用 • 9 回帖 • 394 关注
  • IBM

    IBM(国际商业机器公司)或万国商业机器公司,简称 IBM(International Business Machines Corporation),总公司在纽约州阿蒙克市。1911 年托马斯·沃森创立于美国,是全球最大的信息技术和业务解决方案公司,拥有全球雇员 30 多万人,业务遍及 160 多个国家和地区。

    17 引用 • 53 回帖 • 141 关注
  • Google

    Google(Google Inc.,NASDAQ:GOOG)是一家美国上市公司(公有股份公司),于 1998 年 9 月 7 日以私有股份公司的形式创立,设计并管理一个互联网搜索引擎。Google 公司的总部称作“Googleplex”,它位于加利福尼亚山景城。Google 目前被公认为是全球规模最大的搜索引擎,它提供了简单易用的免费服务。不作恶(Don't be evil)是谷歌公司的一项非正式的公司口号。

    49 引用 • 192 回帖 • 1 关注
  • Pipe

    Pipe 是一款小而美的开源博客平台。Pipe 有着非常活跃的社区,可将文章作为帖子推送到社区,来自社区的回帖将作为博客评论进行联动(具体细节请浏览 B3log 构思 - 分布式社区网络)。

    这是一种全新的网络社区体验,让热爱记录和分享的你不再感到孤单!

    132 引用 • 1114 回帖 • 126 关注
  • 正则表达式

    正则表达式(Regular Expression)使用单个字符串来描述、匹配一系列遵循某个句法规则的字符串。

    31 引用 • 94 回帖 • 2 关注
  • 机器学习

    机器学习(Machine Learning)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。

    83 引用 • 37 回帖
  • 星云链

    星云链是一个开源公链,业内简单的将其称为区块链上的谷歌。其实它不仅仅是区块链搜索引擎,一个公链的所有功能,它基本都有,比如你可以用它来开发部署你的去中心化的 APP,你可以在上面编写智能合约,发送交易等等。3 分钟快速接入星云链 (NAS) 测试网

    3 引用 • 16 回帖 • 6 关注
  • Tomcat

    Tomcat 最早是由 Sun Microsystems 开发的一个 Servlet 容器,在 1999 年被捐献给 ASF(Apache Software Foundation),隶属于 Jakarta 项目,现在已经独立为一个顶级项目。Tomcat 主要实现了 JavaEE 中的 Servlet、JSP 规范,同时也提供 HTTP 服务,是市场上非常流行的 Java Web 容器。

    162 引用 • 529 回帖 • 5 关注
  • 快应用

    快应用 是基于手机硬件平台的新型应用形态;标准是由主流手机厂商组成的快应用联盟联合制定;快应用标准的诞生将在研发接口、能力接入、开发者服务等层面建设标准平台;以平台化的生态模式对个人开发者和企业开发者全品类开放。

    15 引用 • 127 回帖
  • HTML

    HTML5 是 HTML 下一个的主要修订版本,现在仍处于发展阶段。广义论及 HTML5 时,实际指的是包括 HTML、CSS 和 JavaScript 在内的一套技术组合。

    107 引用 • 295 回帖
  • BookxNote

    BookxNote 是一款全新的电子书学习工具,助力您的学习与思考,让您的大脑更高效的记忆。

    笔记整理交给我,一心只读圣贤书。

    1 引用 • 1 回帖
  • CentOS

    CentOS(Community Enterprise Operating System)是 Linux 发行版之一,它是来自于 Red Hat Enterprise Linux 依照开放源代码规定释出的源代码所编译而成。由于出自同样的源代码,因此有些要求高度稳定的服务器以 CentOS 替代商业版的 Red Hat Enterprise Linux 使用。两者的不同在于 CentOS 并不包含封闭源代码软件。

    238 引用 • 224 回帖
  • PWL

    组织简介

    用爱发电 (Programming With Love) 是一个以开源精神为核心的民间开源爱好者技术组织,“用爱发电”象征开源与贡献精神,加入组织,代表你将遵守组织的“个人开源爱好者”的各项条款。申请加入:用爱发电组织邀请帖
    用爱发电组织官网:https://programmingwithlove.stackoverflow.wiki/

    用爱发电组织的核心驱动力:

    • 遵守开源守则,体现开源&贡献精神:以分享为目的,拒绝非法牟利。
    • 自我保护:使用适当的 License 保护自己的原创作品。
    • 尊重他人:不以各种理由、各种漏洞进行未经允许的抄袭、散播、洩露;以礼相待,尊重所有对社区做出贡献的开发者;通过他人的分享习得知识,要留下足迹,表示感谢。
    • 热爱编程、热爱学习:加入组织,热爱编程是首当其要的。我们欢迎热爱讨论、分享、提问的朋友,也同样欢迎默默成就的朋友。
    • 倾听:正确并恳切对待、处理问题与建议,及时修复开源项目的 Bug ,及时与反馈者沟通。不抬杠、不无视、不辱骂。
    • 平视:不诋毁、轻视、嘲讽其他开发者,主动提出建议、施以帮助,以和谐为本。只要他人肯努力,你也可能会被昔日小看的人所超越,所以请保持谦虚。
    • 乐观且活跃:你的努力决定了你的高度。不要放弃,多年后回头俯瞰,才会发现自己已经成就往日所仰望的水平。积极地将项目开源,帮助他人学习、改进,自己也会获得相应的提升、成就与成就感。
    1 引用 • 487 回帖 • 3 关注
  • NetBeans

    NetBeans 是一个始于 1997 年的 Xelfi 计划,本身是捷克布拉格查理大学的数学及物理学院的学生计划。此计划延伸而成立了一家公司进而发展这个商用版本的 NetBeans IDE,直到 1999 年 Sun 买下此公司。Sun 于次年(2000 年)六月将 NetBeans IDE 开源,直到现在 NetBeans 的社群依然持续增长。

    78 引用 • 102 回帖 • 683 关注
  • 知乎

    知乎是网络问答社区,连接各行各业的用户。用户分享着彼此的知识、经验和见解,为中文互联网源源不断地提供多种多样的信息。

    10 引用 • 66 回帖 • 1 关注
  • 心情

    心是产生任何想法的源泉,心本体会陷入到对自己本体不能理解的状态中,因为心能产生任何想法,不能分出对错,不能分出自己。

    59 引用 • 369 回帖
  • Python

    Python 是一种面向对象、直译式电脑编程语言,具有近二十年的发展历史,成熟且稳定。它包含了一组完善而且容易理解的标准库,能够轻松完成很多常见的任务。它的语法简捷和清晰,尽量使用无异义的英语单词,与其它大多数程序设计语言使用大括号不一样,它使用缩进来定义语句块。

    545 引用 • 672 回帖
  • Sandbox

    如果帖子标签含有 Sandbox ,则该帖子会被视为“测试帖”,主要用于测试社区功能,排查 bug 等,该标签下内容不定期进行清理。

    409 引用 • 1246 回帖 • 587 关注
  • Telegram

    Telegram 是一个非盈利性、基于云端的即时消息服务。它提供了支持各大操作系统平台的开源的客户端,也提供了很多强大的 APIs 给开发者创建自己的客户端和机器人。

    5 引用 • 35 回帖
  • JSON

    JSON (JavaScript Object Notation)是一种轻量级的数据交换格式。易于人类阅读和编写。同时也易于机器解析和生成。

    52 引用 • 190 回帖 • 1 关注