Spark mllib API- classification

本贴最后更新于 3417 天前,其中的信息可能已经渤澥桑田

Apark mllib API 的翻译 - 分类篇。 对官方文档进行翻译的同时加入了一些常识性知识。

更多分类的相关知识可以查看我的另外一篇博客数据挖掘算法初窥门庭--分类回归

Spark 当前提供 LogisticRegression、SVM、NaiveBayes。


##LogisticRegression 逻辑回归

###背景知识

LinerRegression 是使用线性方程对数据进行两分类(在线的一侧属于同一类)。而 LogisticRegression 就是一个被 logistic 方程归一化后的 LinerRegression(归一化后值域为 0-1)。LogisticRegression 一般也用于两分类,预测样本属于某个类别的概率。

LogisticRegression 的过程是典型的监督机器学习,也就是在规则化参数的同时最小化误差。最小化误差是为了让我们的模型拟合我们的训练数据,而规则化参数是防止我们的模型过分拟合我们的训练数据。
大致步骤如下:

  • 目标函数为 f(f 为未知的),我们假定目标函数为 h。(假设)
  • 构造损失函数 cost(基于最大似然估计),表示 h 的预测结果与实际结果 f 之间的偏差。(预测并评估)
  • 通过迭代,调整 h,使 h 与 f 尽可能接近。(求最优解)

LogisticRegression 有很多不同的算法版本,大多数的主要不同在于求最优解。目前,spark 提供两种 LogisticRegression 方法:SGD(随机梯度下降)和 LBFGS(改进的拟牛顿法)。

特征选择:

  • LogisticRegression 假设向量的各个维度是独立不相互影响的。
  • 由于 LogisticRegression 的终止条件是收敛或达到最大迭代次数,因此在数据预处理时进行归一化,加快收敛速度。
  • 更多具体的变量选择方法,参考华山大师兄的 Logistic Regression--逻辑回归算法汇总

###Spark API

  • 类:pyspark.mllib.classification.LogisticRegressionWithSGD
    • 方法:
      train(data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=0.01, regType='l2', intercept=False, validateData=True, convergenceTol=0.001)
      通过给定数据训练逻辑回归模型。
      • data:训练数据,LabeledPoint 格式的 RDD 数据集。
      • iterations:迭代次数,默认为 100。
      • step:SGD 的步长,默认为 1.0。(太大容易错过最优解,太小导致迭代次数过多)。
      • miniBatchFraction:用于每次 SGD 迭代的数据,默认 1.0。(SGD 每次迭代选用随机数据)。
      • initialWeights:初始权值,默认 None。
      • regParam:规则化参数,默认 0.01。
      • regType:用于训练模型的规则化类型,可选为 l1 或 l2(默认)。
      • intercept:布尔值,表示是否使用增强表现来训练数据,默认 False。
      • validateData:布尔值,表示算法是否在训练前检验数据,默认 True。
      • convergenceTol:终止迭代的收敛值,默认 0.001。

  • 类: pyspark.mllib.classification.LogisticRegressionWithLBFGS
    • 方法:
      train(data, iterations=100, initialWeights=None, regParam=0.01, regType='l2', intercept=False, corrections=10, tolerance=0.0001, validateData=True, numClasses=2)
      通过给定数据训练逻辑回归模型。
      • data:训练数据,LabeledPoint 格式的 RDD 数据集。
      • iterations:迭代次数,默认为 100。
      • initialWeights:初始权值,默认 None。
      • regParam:规则化参数,默认 0.01。
      • regType:用于训练模型的规则化类型,可选为 l1 或 l2(默认)。
      • intercept:布尔值,表示是否使用增强表现来训练数据,默认 False。
      • corrections:用于 LBFGS 更新的修正值,默认 10。
      • tolerance:LBFGS 迭代的收敛容忍系数,默认 1e-4。
      • validateData:布尔值,表示算法是否在训练前检验数据,默认 True。
      • numClasses:多分类逻辑回归中类别的个数,默认 2。

  • 类:pyspark.mllib.classification.LogisticRegressionModel
    使用多/两逻辑分类方法训练得到的模型。
    • 属性:
      • weights:每个向量计算的权值。
      • intercept:该模型的计算截距(只用于两逻辑回归)。
      • numFeatures:向量的维度。
      • numClasses:输出类别的个数。
      • threshold:用于区分正负样本的阈值。
    • 方法: clearThreshold()
      去除阈值,直接输出预测值,只用于两分类
    • 方法: load(sc, path)
      从指定路径加载模型
    • 方法: save(sc, path)
      将模型保存到指定路径
    • 方法: predict(x)
      预测,输入可以为单个向量或整个 RDD
    • 方法: setThreshold(value)
      设置用于区分正负样本的阈值。当预测值大于该预置时,判定为正样本。

SVM 支持向量机

###背景知识

SVM 是二分类的分类模型。给定包含正负样本的数据集,SVM 的目的是寻找一个超平面(WX+b=0)对样本进行分割,且使得离超平面比较近的点能有更大的间距。

(待补充)


###Spark API

  • 类:class pyspark.mllib.classification.SVMWithSGD
    • 方法:
      train(data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, regType='l2', intercept=False, validateData=True, convergenceTol=0.001)
      通过给定的数据训练 SVM 模型。
      • data:训练数据,LabeledPoint 格式的 RDD 数据集。
      • iterations:迭代次数,默认为 100。
      • step:SGD 的步长,默认为 1.0。
      • regParam:规则化参数,默认 0.01。
      • miniBatchFraction:用于每次 SGD 迭代的数据,默认 1.0。
      • initialWeights:初始权值,默认 None。
      • regType:用于训练模型的规则化类型,可选为 l1 或 l2(默认)。
      • intercept:布尔值,表示是否使用增强表现来训练数据,默认 False。
      • validateData:布尔值,表示算法是否在训练前检验数据,默认 True。
      • convergenceTol:终止迭代的收敛值,默认 0.001。

  • 类: pyspark.mllib.classification.SVMModel
    支持向量机模型
    • 属性:
      • weights:每个向量计算的权值。
      • intercept:该模型的计算截距。
    • 方法: clearThreshold()
      去除阈值,直接输出预测值
    • 方法: load(sc, path)
      从指定路径加载模型
    • 方法: save(sc, path)
      将模型保存到指定路径
    • 方法: predict(x)
      预测,输入可以为单个向量或整个 RDD
    • 方法: setThreshold(value)
      设置用于区分正负样本的阈值。当预测值大于该预置时,判定为正样本。

##NaiveBayes 朴素贝叶斯
###背景知识
贝叶斯概率公式:
P(B[j]|A[i])=P(A[i]|B[j])P(B[j]) / P(A[i])
朴素贝叶斯分类器是使用贝叶斯概率公式为核心的分类算法,其基本思想为:对于给出的待分类项,求解在此项出现的条件下各个类别出现的概率,哪个最大,就认为此待分类项属于哪个类别。
朴素贝叶斯假定样本的不同特征属性对样本的归类影响时相互独立的。

(待补充)


###Spark API

  • 类:pyspark.mllib.classification.NaiveBayes
    • 方法:
      train(data, lambda_=1.0)
      通过给定数据集训练贝叶斯模型
      • data:训练数据,LabeledPoint 格式的 RDD 数据集。
      • lambda:平滑参数,默认 1.0

  • 类: pyspark.mllib.classification.NaiveBayesModel
    朴素贝叶斯分类器模型
    • 属性:
      • labels:label 列表
      • pi:每个类别的 priors
      • theta:使用矩阵存储每个向量划分到每个类的条件概率
    • 方法: load(sc, path)
      从指定路径加载模型
    • 方法: save(sc, path)
      将模型保存到指定路径
    • 方法: predict(x)
      预测,输入可以为单个向量或整个 RDD
  • Spark

    Spark 是 UC Berkeley AMP lab 所开源的类 Hadoop MapReduce 的通用并行框架。Spark 拥有 Hadoop MapReduce 所具有的优点;但不同于 MapReduce 的是 Job 中间输出结果可以保存在内存中,从而不再需要读写 HDFS,因此 Spark 能更好地适用于数据挖掘与机器学习等需要迭代的 MapReduce 的算法。

    74 引用 • 46 回帖 • 565 关注
  • 数据挖掘
    17 引用 • 32 回帖 • 3 关注
  • 默认
    5 引用 • 22 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...
  • zempty via macOS

    ????1???😇

推荐标签 标签

  • 支付宝

    支付宝是全球领先的独立第三方支付平台,致力于为广大用户提供安全快速的电子支付/网上支付/安全支付/手机支付体验,及转账收款/水电煤缴费/信用卡还款/AA 收款等生活服务应用。

    29 引用 • 347 回帖
  • NGINX

    NGINX 是一个高性能的 HTTP 和反向代理服务器,也是一个 IMAP/POP3/SMTP 代理服务器。 NGINX 是由 Igor Sysoev 为俄罗斯访问量第二的 Rambler.ru 站点开发的,第一个公开版本 0.1.0 发布于 2004 年 10 月 4 日。

    315 引用 • 547 回帖
  • OpenResty

    OpenResty 是一个基于 NGINX 与 Lua 的高性能 Web 平台,其内部集成了大量精良的 Lua 库、第三方模块以及大多数的依赖项。用于方便地搭建能够处理超高并发、扩展性极高的动态 Web 应用、Web 服务和动态网关。

    17 引用 • 50 关注
  • Openfire

    Openfire 是开源的、基于可拓展通讯和表示协议 (XMPP)、采用 Java 编程语言开发的实时协作服务器。Openfire 的效率很高,单台服务器可支持上万并发用户。

    6 引用 • 7 回帖 • 118 关注
  • Hprose

    Hprose 是一款先进的轻量级、跨语言、跨平台、无侵入式、高性能动态远程对象调用引擎库。它不仅简单易用,而且功能强大。你无需专门学习,只需看上几眼,就能用它轻松构建分布式应用系统。

    9 引用 • 17 回帖 • 639 关注
  • 音乐

    你听到信仰的声音了么?

    62 引用 • 512 回帖
  • Unity

    Unity 是由 Unity Technologies 开发的一个让开发者可以轻松创建诸如 2D、3D 多平台的综合型游戏开发工具,是一个全面整合的专业游戏引擎。

    25 引用 • 7 回帖 • 114 关注
  • HBase

    HBase 是一个分布式的、面向列的开源数据库,该技术来源于 Fay Chang 所撰写的 Google 论文 “Bigtable:一个结构化数据的分布式存储系统”。就像 Bigtable 利用了 Google 文件系统所提供的分布式数据存储一样,HBase 在 Hadoop 之上提供了类似于 Bigtable 的能力。

    17 引用 • 6 回帖 • 73 关注
  • MySQL

    MySQL 是一个关系型数据库管理系统,由瑞典 MySQL AB 公司开发,目前属于 Oracle 公司。MySQL 是最流行的关系型数据库管理系统之一。

    694 引用 • 537 回帖
  • GAE

    Google App Engine(GAE)是 Google 管理的数据中心中用于 WEB 应用程序的开发和托管的平台。2008 年 4 月 发布第一个测试版本。目前支持 Python、Java 和 Go 开发部署。全球已有数十万的开发者在其上开发了众多的应用。

    14 引用 • 42 回帖 • 838 关注
  • JSON

    JSON (JavaScript Object Notation)是一种轻量级的数据交换格式。易于人类阅读和编写。同时也易于机器解析和生成。

    53 引用 • 190 回帖
  • FFmpeg

    FFmpeg 是一套可以用来记录、转换数字音频、视频,并能将其转化为流的开源计算机程序。

    23 引用 • 32 回帖 • 1 关注
  • JVM

    JVM(Java Virtual Machine)Java 虚拟机是一个微型操作系统,有自己的硬件构架体系,还有相应的指令系统。能够识别 Java 独特的 .class 文件(字节码),能够将这些文件中的信息读取出来,使得 Java 程序只需要生成 Java 虚拟机上的字节码后就能在不同操作系统平台上进行运行。

    180 引用 • 120 回帖 • 2 关注
  • 爬虫

    网络爬虫(Spider、Crawler),是一种按照一定的规则,自动地抓取万维网信息的程序。

    106 引用 • 275 回帖
  • Swift

    Swift 是苹果于 2014 年 WWDC(苹果开发者大会)发布的开发语言,可与 Objective-C 共同运行于 Mac OS 和 iOS 平台,用于搭建基于苹果平台的应用程序。

    34 引用 • 37 回帖 • 554 关注
  • Love2D

    Love2D 是一个开源的, 跨平台的 2D 游戏引擎。使用纯 Lua 脚本来进行游戏开发。目前支持的平台有 Windows, Mac OS X, Linux, Android 和 iOS。

    14 引用 • 53 回帖 • 560 关注
  • flomo

    flomo 是新一代 「卡片笔记」 ,专注在碎片化时代,促进你的记录,帮你积累更多知识资产。

    6 引用 • 143 回帖
  • Oracle

    Oracle(甲骨文)公司,全称甲骨文股份有限公司(甲骨文软件系统有限公司),是全球最大的企业级软件公司,总部位于美国加利福尼亚州的红木滩。1989 年正式进入中国市场。2013 年,甲骨文已超越 IBM,成为继 Microsoft 后全球第二大软件公司。

    107 引用 • 127 回帖 • 340 关注
  • Spark

    Spark 是 UC Berkeley AMP lab 所开源的类 Hadoop MapReduce 的通用并行框架。Spark 拥有 Hadoop MapReduce 所具有的优点;但不同于 MapReduce 的是 Job 中间输出结果可以保存在内存中,从而不再需要读写 HDFS,因此 Spark 能更好地适用于数据挖掘与机器学习等需要迭代的 MapReduce 的算法。

    74 引用 • 46 回帖 • 565 关注
  • OpenStack

    OpenStack 是一个云操作系统,通过数据中心可控制大型的计算、存储、网络等资源池。所有的管理通过前端界面管理员就可以完成,同样也可以通过 Web 接口让最终用户部署资源。

    10 引用 • 4 关注
  • CodeMirror
    2 引用 • 17 回帖 • 177 关注
  • Wide

    Wide 是一款基于 Web 的 Go 语言 IDE。通过浏览器就可以进行 Go 开发,并有代码自动完成、查看表达式、编译反馈、Lint、实时结果输出等功能。

    欢迎访问我们运维的实例: https://wide.b3log.org

    30 引用 • 218 回帖 • 643 关注
  • Rust

    Rust 是一门赋予每个人构建可靠且高效软件能力的语言。Rust 由 Mozilla 开发,最早发布于 2014 年 9 月。

    59 引用 • 22 回帖 • 1 关注
  • OnlyOffice
    4 引用 • 18 关注
  • 996
    13 引用 • 200 回帖
  • 运维

    互联网运维工作,以服务为中心,以稳定、安全、高效为三个基本点,确保公司的互联网业务能够 7×24 小时为用户提供高质量的服务。

    151 引用 • 257 回帖
  • 大数据

    大数据(big data)是指无法在一定时间范围内用常规软件工具进行捕捉、管理和处理的数据集合,是需要新处理模式才能具有更强的决策力、洞察发现力和流程优化能力的海量、高增长率和多样化的信息资产。

    89 引用 • 113 回帖 • 1 关注