支持向量机 (SVM),序列最小优化算法 (SMO)

本贴最后更新于 1521 天前,其中的信息可能已经事过境迁

支持向量机(Support Vector Machine)由V.N. Vapnik,A.Y. Chervonenkis,C. Cortes 等在1964年提出。序列最小优化算法(Sequential minimal optimization)是一种用于解决支持向量机训练过程中所产生优化问题的算法。由John C. Platt于1998年提出。

支持向量机的推导在西瓜书,各大网站已经有详细的介绍。本文主要依据 John C. Platt 发表的文章《Sequential Minimal Optimization: A Fast Algorithm for Training Support Vector Machines》来实现 SVM 与 SMO 算法。


算法的流程:

QQ 截图 20210306083526.png


import numpy as np from sklearn import datasets import matplotlib.pyplot as plt

定义需要的数据,包含数据样本,数据标签,偏置 b,拉格朗日乘子 α,容忍系数 C 等。

class Par: def __init__(self,n,D,C,eps,tol): self.X=datasets.make_blobs(n_samples=n,n_features=D,centers=2,cluster_std=1.0,shuffle=True,random_state=None) self.point=self.X[0] self.target=self.X[1] self.target[np.nonzero(self.target==0)[0]]=-1 self.w=np.zeros((1,D))[0] self.b=0 self.E=-self.target self.alpha=np.zeros((1,n))[0] self.n=n self.C=C self.eps=eps self.tol=tol

定义核函数,预测公式。

def kernel(x,y): return np.dot(x,y.T) def f(x): s=0 for i in range(n): s+=P.alpha[i]*P.target[i]*kernel(P.point[i],x) return s-P.b

被选中的一对 α 更新细节:

def takeStep(i1,i2): if i1==i2: return 0 alph2=P.alpha[i2] alph1=P.alpha[i1] y1=P.target[i1] y2=P.target[i2] s=y1*y2 #Compute L,H via equations (13) and (14) if y1!=y2: L=max(0,alph2-alph1) H=min(P.C,P.C+alph2-alph1) else: L=max(0,alph2+alph1-P.C) H=min(P.C,alph2+alph1) if L==H: return 0 k11=kernel(P.point[i1],P.point[i1]) k12=kernel(P.point[i1],P.point[i2]) k22=kernel(P.point[i2],P.point[i2]) eta=k11+k22-2*k12 if eta>0: a2=alph2+y2*(P.E[i1]-P.E[i2])/eta if a2<L: a2=L elif a2>H: a2=H else: f1=y1*(P.E[i1]+b)-alph1*k11-s*alph2*k12 f2=y2*(P.E[i2]+b)-s*alph1*k12-alph2*k22 L1=alph1+s*(alph2-L) H1=alph1+s*(alph2+H) psiL=L1*f1+L*f2+0.5*L1**2*k11+0.5*L**2*k22+s*L*L1*k12 psiH=H1*f1+H*f2+0.5*H1**2*k11+0.5*H**2*k22+s*H*H1*k12 Lobj = psiL Hobj = psiH if Lobj<Hobj-eps: a2=L elif Lobj>Hobj+eps: a2=H else: a2=alph2 if abs(a2-alph2)<P.eps*(a2+alph2+P.eps): return 0 a1=alph1+s*(alph2-a2) #Update threshold to reflect change in Lagrange multipliers b1=P.E[i1]+y1*(a1-alph1)*k11+y2*(a2-alph2)*k12+P.b b2=P.E[i2]+y1*(a1-alph1)*k12+y2*(a2-alph2)*k22+P.b if a1>0 and a1<P.C: P.b=b1 elif a2>0 and a2<P.C: P.b=b2 else: P.b=(b1+b2)/2 #Update weight vector to reflect change in a1 & a2, if SVM is linear P.w=P.w+y1*(a1-alph1)*P.point[i1]+y2*(a2-alph2)*P.point[i2] #Store a1 in the alpha array P.alpha[i1]=a1 #Store a2 in the alpha array P.alpha[i2]=a2 #Update error cache using new Lagrange multipliers P.E[i1]=f(P.point[i1])-P.target[i1] P.E[i2]=f(P.point[i2])-P.target[i2] return 1

内循环选择第二个 α:

def examineExample(i2): global valid alph2=P.alpha[i2] y2=P.target[i2] r2=P.E[i2]*y2 if (r2<-P.tol and alph2<P.C) or (r2>P.tol and alph2>0): valid=np.where((P.alpha!=0) & (P.alpha!=C))[0] Long=len(valid) if Long > 1: #i1 = result of second choice heuristic (section 2.2) best=-1 if len(valid)>1: for k in valid: deltaE=abs(P.E[i2]-P.E[k]) if deltaE>best: best=deltaE i1=k if takeStep(i1,i2): return 1 #loop over all non-zero and non-C alpha, starting at a random point if Long>0: random_index=np.random.randint(0,Long) for i in np.hstack((valid[random_index:Long],valid[0:random_index])): i1=i if takeStep(i1,i2): return 1 #loop over all possible i1, starting at a random point random_index=np.random.randint(0,n) for i in np.hstack((np.arange(random_index,n),np.arange(0,random_index))): #i1=loop variable i1=i if takeStep(i1,i2): return 1 return 0

外循环选择第一个 α:

def SMO(): global valid numChanged=0 examineAll=1 while numChanged>0 or examineAll: numChanged=0 if examineAll: for i in range(n): numChanged+=examineExample(i) else: #loop I over examples where alpha is not 0 & not C for i in valid: numChanged+=examineExample(i) if examineAll==1: examineAll=0 elif numChanged==0: examineAll=1

主函数入口:

if __name__ == '__main__': n=100 #样本个数 C=10 eps=0.001 #停止精度 tol=0.001 #分类容错率 D=2 #样本维度 P=Par(n,D,C,eps,tol) SMO() #绘制图像 plt.scatter(P.point[:,0],P.point[:,1],c=P.target) x=np.arange(-10,10,0.1) y=(P.b-P.w[0]*x)/P.w[1] plt.plot(x,y) plt.show() Y=kernel(P.point,P.w)-P.b count=0 for i in range(n): if Y[i]*P.target[i]<0: count+=1 print('Error Point num:',count)

单次测试结果:

Figure1.png

  • 分类
    8 引用 • 10 回帖
  • 机器学习

    机器学习(Machine Learning)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。

    83 引用 • 37 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • 强迫症

    强迫症(OCD)属于焦虑障碍的一种类型,是一组以强迫思维和强迫行为为主要临床表现的神经精神疾病,其特点为有意识的强迫和反强迫并存,一些毫无意义、甚至违背自己意愿的想法或冲动反反复复侵入患者的日常生活。

    15 引用 • 161 回帖
  • Gzip

    gzip (GNU zip)是 GNU 自由软件的文件压缩程序。我们在 Linux 中经常会用到后缀为 .gz 的文件,它们就是 Gzip 格式的。现今已经成为互联网上使用非常普遍的一种数据压缩格式,或者说一种文件格式。

    9 引用 • 12 回帖 • 166 关注
  • Logseq

    Logseq 是一个隐私优先、开源的知识库工具。

    Logseq is a joyful, open-source outliner that works on top of local plain-text Markdown and Org-mode files. Use it to write, organize and share your thoughts, keep your to-do list, and build your own digital garden.

    7 引用 • 69 回帖
  • GitLab

    GitLab 是利用 Ruby 一个开源的版本管理系统,实现一个自托管的 Git 项目仓库,可通过 Web 界面操作公开或私有项目。

    46 引用 • 72 回帖
  • 架构

    我们平时所说的“架构”主要是指软件架构,这是有关软件整体结构与组件的抽象描述,用于指导软件系统各个方面的设计。另外还有“业务架构”、“网络架构”、“硬件架构”等细分领域。

    143 引用 • 442 回帖 • 1 关注
  • Solidity

    Solidity 是一种智能合约高级语言,运行在 [以太坊] 虚拟机(EVM)之上。它的语法接近于 JavaScript,是一种面向对象的语言。

    3 引用 • 18 回帖 • 436 关注
  • 宕机

    宕机,多指一些网站、游戏、网络应用等服务器一种区别于正常运行的状态,也叫“Down 机”、“当机”或“死机”。宕机状态不仅仅是指服务器“挂掉了”、“死机了”状态,也包括服务器假死、停用、关闭等一些原因而导致出现的不能够正常运行的状态。

    13 引用 • 82 回帖 • 77 关注
  • Vim

    Vim 是类 UNIX 系统文本编辑器 Vi 的加强版本,加入了更多特性来帮助编辑源代码。Vim 的部分增强功能包括文件比较(vimdiff)、语法高亮、全面的帮助系统、本地脚本(Vimscript)和便于选择的可视化模式。

    29 引用 • 66 回帖 • 1 关注
  • 黑曜石

    黑曜石是一款强大的知识库工具,支持本地 Markdown 文件编辑,支持双向链接和关系图。

    A second brain, for you, forever.

    24 引用 • 241 回帖 • 1 关注
  • FreeMarker

    FreeMarker 是一款好用且功能强大的 Java 模版引擎。

    23 引用 • 20 回帖 • 467 关注
  • MyBatis

    MyBatis 本是 Apache 软件基金会 的一个开源项目 iBatis,2010 年这个项目由 Apache 软件基金会迁移到了 google code,并且改名为 MyBatis ,2013 年 11 月再次迁移到了 GitHub。

    173 引用 • 414 回帖 • 364 关注
  • 笔记

    好记性不如烂笔头。

    310 引用 • 794 回帖
  • Git

    Git 是 Linux Torvalds 为了帮助管理 Linux 内核开发而开发的一个开放源码的版本控制软件。

    211 引用 • 358 回帖 • 1 关注
  • 30Seconds

    📙 前端知识精选集,包含 HTML、CSS、JavaScript、React、Node、安全等方面,每天仅需 30 秒。

    • 精选常见面试题,帮助您准备下一次面试
    • 精选常见交互,帮助您拥有简洁酷炫的站点
    • 精选有用的 React 片段,帮助你获取最佳实践
    • 精选常见代码集,帮助您提高打码效率
    • 整理前端界的最新资讯,邀您一同探索新世界
    488 引用 • 384 回帖 • 10 关注
  • ActiveMQ

    ActiveMQ 是 Apache 旗下的一款开源消息总线系统,它完整实现了 JMS 规范,是一个企业级的消息中间件。

    19 引用 • 13 回帖 • 679 关注
  • Eclipse

    Eclipse 是一个开放源代码的、基于 Java 的可扩展开发平台。就其本身而言,它只是一个框架和一组服务,用于通过插件组件构建开发环境。

    76 引用 • 258 回帖 • 628 关注
  • Hexo

    Hexo 是一款快速、简洁且高效的博客框架,使用 Node.js 编写。

    22 引用 • 148 回帖 • 16 关注
  • 小说

    小说是以刻画人物形象为中心,通过完整的故事情节和环境描写来反映社会生活的文学体裁。

    32 引用 • 108 回帖
  • OkHttp

    OkHttp 是一款 HTTP & HTTP/2 客户端库,专为 Android 和 Java 应用打造。

    16 引用 • 6 回帖 • 85 关注
  • 爬虫

    网络爬虫(Spider、Crawler),是一种按照一定的规则,自动地抓取万维网信息的程序。

    106 引用 • 275 回帖
  • Firefox

    Mozilla Firefox 中文俗称“火狐”(正式缩写为 Fx 或 fx,非正式缩写为 FF),是一个开源的网页浏览器,使用 Gecko 排版引擎,支持多种操作系统,如 Windows、OSX 及 Linux 等。

    7 引用 • 30 回帖 • 385 关注
  • SQLServer

    SQL Server 是由 [微软] 开发和推广的关系数据库管理系统(DBMS),它最初是由 微软、Sybase 和 Ashton-Tate 三家公司共同开发的,并于 1988 年推出了第一个 OS/2 版本。

    21 引用 • 31 回帖
  • Pipe

    Pipe 是一款小而美的开源博客平台。Pipe 有着非常活跃的社区,可将文章作为帖子推送到社区,来自社区的回帖将作为博客评论进行联动(具体细节请浏览 B3log 构思 - 分布式社区网络)。

    这是一种全新的网络社区体验,让热爱记录和分享的你不再感到孤单!

    133 引用 • 1124 回帖 • 111 关注
  • Openfire

    Openfire 是开源的、基于可拓展通讯和表示协议 (XMPP)、采用 Java 编程语言开发的实时协作服务器。Openfire 的效率很高,单台服务器可支持上万并发用户。

    6 引用 • 7 回帖 • 106 关注
  • JavaScript

    JavaScript 一种动态类型、弱类型、基于原型的直译式脚本语言,内置支持类型。它的解释器被称为 JavaScript 引擎,为浏览器的一部分,广泛用于客户端的脚本语言,最早是在 HTML 网页上使用,用来给 HTML 网页增加动态功能。

    730 引用 • 1280 回帖 • 5 关注
  • Kubernetes

    Kubernetes 是 Google 开源的一个容器编排引擎,它支持自动化部署、大规模可伸缩、应用容器化管理。

    116 引用 • 54 回帖 • 4 关注
  • 一些有用的避坑指南。

    69 引用 • 93 回帖 • 1 关注