支持向量机 (SVM),序列最小优化算法 (SMO)

本贴最后更新于 1337 天前,其中的信息可能已经事过境迁

支持向量机(Support Vector Machine)由V.N. Vapnik,A.Y. Chervonenkis,C. Cortes 等在1964年提出。序列最小优化算法(Sequential minimal optimization)是一种用于解决支持向量机训练过程中所产生优化问题的算法。由John C. Platt于1998年提出。

支持向量机的推导在西瓜书,各大网站已经有详细的介绍。本文主要依据 John C. Platt 发表的文章《Sequential Minimal Optimization: A Fast Algorithm for Training Support Vector Machines》来实现 SVM 与 SMO 算法。


算法的流程:

QQ 截图 20210306083526.png


import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt

定义需要的数据,包含数据样本,数据标签,偏置 b,拉格朗日乘子 α,容忍系数 C 等。

class Par:
	def __init__(self,n,D,C,eps,tol):
		self.X=datasets.make_blobs(n_samples=n,n_features=D,centers=2,cluster_std=1.0,shuffle=True,random_state=None)
		self.point=self.X[0]
		self.target=self.X[1]
		self.target[np.nonzero(self.target==0)[0]]=-1
		self.w=np.zeros((1,D))[0]
		self.b=0
		self.E=-self.target
		self.alpha=np.zeros((1,n))[0]
		self.n=n
		self.C=C
		self.eps=eps
		self.tol=tol

定义核函数,预测公式。

def kernel(x,y):
	return np.dot(x,y.T)

def f(x):
	s=0
	for i in range(n):
		s+=P.alpha[i]*P.target[i]*kernel(P.point[i],x)
	return s-P.b

被选中的一对 α 更新细节:

def takeStep(i1,i2):
	if i1==i2:
		return 0
	alph2=P.alpha[i2]
	alph1=P.alpha[i1]
	y1=P.target[i1]
	y2=P.target[i2]
	s=y1*y2
	#Compute L,H via equations (13) and (14)
	if y1!=y2:
		L=max(0,alph2-alph1)
		H=min(P.C,P.C+alph2-alph1)
	else:
		L=max(0,alph2+alph1-P.C)
		H=min(P.C,alph2+alph1)
	if L==H:
		return 0
	k11=kernel(P.point[i1],P.point[i1])
	k12=kernel(P.point[i1],P.point[i2])
	k22=kernel(P.point[i2],P.point[i2])
	eta=k11+k22-2*k12
	if eta>0:
		a2=alph2+y2*(P.E[i1]-P.E[i2])/eta
	if a2<L:
		a2=L
	elif a2>H:
		a2=H
	else:
		f1=y1*(P.E[i1]+b)-alph1*k11-s*alph2*k12
		f2=y2*(P.E[i2]+b)-s*alph1*k12-alph2*k22
		L1=alph1+s*(alph2-L)
		H1=alph1+s*(alph2+H)
		psiL=L1*f1+L*f2+0.5*L1**2*k11+0.5*L**2*k22+s*L*L1*k12
		psiH=H1*f1+H*f2+0.5*H1**2*k11+0.5*H**2*k22+s*H*H1*k12
		Lobj = psiL
		Hobj = psiH
	if Lobj<Hobj-eps:
		a2=L
	elif Lobj>Hobj+eps:
		a2=H
	else:
		a2=alph2
	if abs(a2-alph2)<P.eps*(a2+alph2+P.eps):
		return 0
	a1=alph1+s*(alph2-a2)
	#Update threshold to reflect change in Lagrange multipliers
	b1=P.E[i1]+y1*(a1-alph1)*k11+y2*(a2-alph2)*k12+P.b
	b2=P.E[i2]+y1*(a1-alph1)*k12+y2*(a2-alph2)*k22+P.b
	if a1>0 and a1<P.C:
		P.b=b1
	elif a2>0 and a2<P.C:
		P.b=b2
	else:
		P.b=(b1+b2)/2
	#Update weight vector to reflect change in a1 & a2, if SVM is linear
	P.w=P.w+y1*(a1-alph1)*P.point[i1]+y2*(a2-alph2)*P.point[i2]
	#Store a1 in the alpha array
	P.alpha[i1]=a1
	#Store a2 in the alpha array
	P.alpha[i2]=a2
	#Update error cache using new Lagrange multipliers
	P.E[i1]=f(P.point[i1])-P.target[i1]
	P.E[i2]=f(P.point[i2])-P.target[i2]
	return 1

内循环选择第二个 α:

def examineExample(i2):
	global valid
	alph2=P.alpha[i2]
	y2=P.target[i2]
	r2=P.E[i2]*y2
	if (r2<-P.tol and alph2<P.C) or (r2>P.tol and alph2>0):
		valid=np.where((P.alpha!=0) & (P.alpha!=C))[0]
		Long=len(valid)
		if Long > 1:
			#i1 = result of second choice heuristic (section 2.2)
			best=-1
			if len(valid)>1:
				for k in valid:
					deltaE=abs(P.E[i2]-P.E[k])
					if deltaE>best:
						best=deltaE
						i1=k
					if takeStep(i1,i2):
						return 1
		#loop over all non-zero and non-C alpha, starting at a random point
		if Long>0:
			random_index=np.random.randint(0,Long)
			for i in np.hstack((valid[random_index:Long],valid[0:random_index])):
				i1=i
				if takeStep(i1,i2):
					return 1
		#loop over all possible i1, starting at a random point
		random_index=np.random.randint(0,n)
		for i in np.hstack((np.arange(random_index,n),np.arange(0,random_index))):
			#i1=loop variable
			i1=i
			if takeStep(i1,i2):
				return 1
	return 0

外循环选择第一个 α:

def SMO():
	global valid
	numChanged=0
	examineAll=1
	while numChanged>0 or examineAll:
		numChanged=0
		if examineAll:
			for i in range(n):
				numChanged+=examineExample(i)
		else:
			#loop I over examples where alpha is not 0 & not C
			for i in valid:
				numChanged+=examineExample(i)
		if examineAll==1:
			examineAll=0
		elif numChanged==0:
			examineAll=1

主函数入口:

if __name__ == '__main__':
	n=100       #样本个数
	C=10         
	eps=0.001  #停止精度
	tol=0.001   #分类容错率
	D=2            #样本维度
	P=Par(n,D,C,eps,tol) 
	SMO()
	#绘制图像
	plt.scatter(P.point[:,0],P.point[:,1],c=P.target)
	x=np.arange(-10,10,0.1)
	y=(P.b-P.w[0]*x)/P.w[1]
	plt.plot(x,y)
	plt.show()
	Y=kernel(P.point,P.w)-P.b
	count=0
	for i in range(n):
	if Y[i]*P.target[i]<0:
	count+=1
	print('Error Point num:',count)

单次测试结果:

Figure1.png

  • 分类
    8 引用 • 10 回帖
  • 机器学习

    机器学习(Machine Learning)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。

    83 引用 • 37 回帖 • 1 关注

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • GitBook

    GitBook 使您的团队可以轻松编写和维护高质量的文档。 分享知识,提高团队的工作效率,让用户满意。

    3 引用 • 8 回帖 • 4 关注
  • MyBatis

    MyBatis 本是 Apache 软件基金会 的一个开源项目 iBatis,2010 年这个项目由 Apache 软件基金会迁移到了 google code,并且改名为 MyBatis ,2013 年 11 月再次迁移到了 GitHub。

    170 引用 • 414 回帖 • 383 关注
  • Git

    Git 是 Linux Torvalds 为了帮助管理 Linux 内核开发而开发的一个开放源码的版本控制软件。

    209 引用 • 358 回帖
  • 国际化

    i18n(其来源是英文单词 internationalization 的首末字符 i 和 n,18 为中间的字符数)是“国际化”的简称。对程序来说,国际化是指在不修改代码的情况下,能根据不同语言及地区显示相应的界面。

    8 引用 • 26 回帖
  • Jenkins

    Jenkins 是一套开源的持续集成工具。它提供了非常丰富的插件,让构建、部署、自动化集成项目变得简单易用。

    53 引用 • 37 回帖
  • 数据库

    据说 99% 的性能瓶颈都在数据库。

    338 引用 • 705 回帖
  • RYMCU

    RYMCU 致力于打造一个即严谨又活泼、专业又不失有趣,为数百万人服务的开源嵌入式知识学习交流平台。

    4 引用 • 6 回帖 • 53 关注
  • Latke

    Latke 是一款以 JSON 为主的 Java Web 框架。

    70 引用 • 533 回帖 • 778 关注
  • Ant-Design

    Ant Design 是服务于企业级产品的设计体系,基于确定和自然的设计价值观上的模块化解决方案,让设计者和开发者专注于更好的用户体验。

    17 引用 • 23 回帖
  • 程序员

    程序员是从事程序开发、程序维护的专业人员。

    565 引用 • 3532 回帖
  • 开源中国

    开源中国是目前中国最大的开源技术社区。传播开源的理念,推广开源项目,为 IT 开发者提供了一个发现、使用、并交流开源技术的平台。目前开源中国社区已收录超过两万款开源软件。

    7 引用 • 86 回帖
  • 小薇

    小薇是一个用 Java 写的 QQ 聊天机器人 Web 服务,可以用于社群互动。

    由于 Smart QQ 从 2019 年 1 月 1 日起停止服务,所以该项目也已经停止维护了!

    34 引用 • 467 回帖 • 741 关注
  • 禅道

    禅道是一款国产的开源项目管理软件,她的核心管理思想基于敏捷方法 scrum,内置了产品管理和项目管理,同时又根据国内研发现状补充了测试管理、计划管理、发布管理、文档管理、事务管理等功能,在一个软件中就可以将软件研发中的需求、任务、bug、用例、计划、发布等要素有序的跟踪管理起来,完整地覆盖了项目管理的核心流程。

    6 引用 • 15 回帖 • 127 关注
  • Mobi.css

    Mobi.css is a lightweight, flexible CSS framework that focus on mobile.

    1 引用 • 6 回帖 • 733 关注
  • WebComponents

    Web Components 是 W3C 定义的标准,它给了前端开发者扩展浏览器标签的能力,可以方便地定制可复用组件,更好的进行模块化开发,解放了前端开发者的生产力。

    1 引用 • 2 关注
  • Bootstrap

    Bootstrap 是 Twitter 推出的一个用于前端开发的开源工具包。它由 Twitter 的设计师 Mark Otto 和 Jacob Thornton 合作开发,是一个 CSS / HTML 框架。

    18 引用 • 33 回帖 • 659 关注
  • Firefox

    Mozilla Firefox 中文俗称“火狐”(正式缩写为 Fx 或 fx,非正式缩写为 FF),是一个开源的网页浏览器,使用 Gecko 排版引擎,支持多种操作系统,如 Windows、OSX 及 Linux 等。

    8 引用 • 30 回帖 • 407 关注
  • 996
    13 引用 • 200 回帖 • 2 关注
  • Gitea

    Gitea 是一个开源社区驱动的轻量级代码托管解决方案,后端采用 Go 编写,采用 MIT 许可证。

    4 引用 • 16 回帖
  • Scala

    Scala 是一门多范式的编程语言,集成面向对象编程和函数式编程的各种特性。

    13 引用 • 11 回帖 • 123 关注
  • ZooKeeper

    ZooKeeper 是一个分布式的,开放源码的分布式应用程序协调服务,是 Google 的 Chubby 一个开源的实现,是 Hadoop 和 HBase 的重要组件。它是一个为分布式应用提供一致性服务的软件,提供的功能包括:配置维护、域名服务、分布式同步、组服务等。

    59 引用 • 29 回帖 • 3 关注
  • 正则表达式

    正则表达式(Regular Expression)使用单个字符串来描述、匹配一系列遵循某个句法规则的字符串。

    31 引用 • 94 回帖 • 1 关注
  • OpenShift

    红帽提供的 PaaS 云,支持多种编程语言,为开发人员提供了更为灵活的框架、存储选择。

    14 引用 • 20 回帖 • 623 关注
  • 阿里巴巴

    阿里巴巴网络技术有限公司(简称:阿里巴巴集团)是以曾担任英语教师的马云为首的 18 人,于 1999 年在中国杭州创立,他们相信互联网能够创造公平的竞争环境,让小企业通过创新与科技扩展业务,并在参与国内或全球市场竞争时处于更有利的位置。

    43 引用 • 221 回帖 • 127 关注
  • 一些有用的避坑指南。

    69 引用 • 93 回帖
  • Ubuntu

    Ubuntu(友帮拓、优般图、乌班图)是一个以桌面应用为主的 Linux 操作系统,其名称来自非洲南部祖鲁语或豪萨语的“ubuntu”一词,意思是“人性”、“我的存在是因为大家的存在”,是非洲传统的一种价值观,类似华人社会的“仁爱”思想。Ubuntu 的目标在于为一般用户提供一个最新的、同时又相当稳定的主要由自由软件构建而成的操作系统。

    124 引用 • 169 回帖
  • MongoDB

    MongoDB(来自于英文单词“Humongous”,中文含义为“庞大”)是一个基于分布式文件存储的数据库,由 C++ 语言编写。旨在为应用提供可扩展的高性能数据存储解决方案。MongoDB 是一个介于关系数据库和非关系数据库之间的产品,是非关系数据库当中功能最丰富,最像关系数据库的。它支持的数据结构非常松散,是类似 JSON 的 BSON 格式,因此可以存储比较复杂的数据类型。

    90 引用 • 59 回帖 • 4 关注