Spark mllib API- classification

本贴最后更新于 3322 天前,其中的信息可能已经渤澥桑田

Apark mllib API 的翻译 - 分类篇。 对官方文档进行翻译的同时加入了一些常识性知识。

更多分类的相关知识可以查看我的另外一篇博客数据挖掘算法初窥门庭--分类回归

Spark 当前提供 LogisticRegression、SVM、NaiveBayes。


##LogisticRegression 逻辑回归

###背景知识

LinerRegression 是使用线性方程对数据进行两分类(在线的一侧属于同一类)。而 LogisticRegression 就是一个被 logistic 方程归一化后的 LinerRegression(归一化后值域为 0-1)。LogisticRegression 一般也用于两分类,预测样本属于某个类别的概率。

LogisticRegression 的过程是典型的监督机器学习,也就是在规则化参数的同时最小化误差。最小化误差是为了让我们的模型拟合我们的训练数据,而规则化参数是防止我们的模型过分拟合我们的训练数据。
大致步骤如下:

  • 目标函数为 f(f 为未知的),我们假定目标函数为 h。(假设)
  • 构造损失函数 cost(基于最大似然估计),表示 h 的预测结果与实际结果 f 之间的偏差。(预测并评估)
  • 通过迭代,调整 h,使 h 与 f 尽可能接近。(求最优解)

LogisticRegression 有很多不同的算法版本,大多数的主要不同在于求最优解。目前,spark 提供两种 LogisticRegression 方法:SGD(随机梯度下降)和 LBFGS(改进的拟牛顿法)。

特征选择:

  • LogisticRegression 假设向量的各个维度是独立不相互影响的。
  • 由于 LogisticRegression 的终止条件是收敛或达到最大迭代次数,因此在数据预处理时进行归一化,加快收敛速度。
  • 更多具体的变量选择方法,参考华山大师兄的 Logistic Regression--逻辑回归算法汇总

###Spark API

  • 类:pyspark.mllib.classification.LogisticRegressionWithSGD
    • 方法:
      train(data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=0.01, regType='l2', intercept=False, validateData=True, convergenceTol=0.001)
      通过给定数据训练逻辑回归模型。
      • data:训练数据,LabeledPoint 格式的 RDD 数据集。
      • iterations:迭代次数,默认为 100。
      • step:SGD 的步长,默认为 1.0。(太大容易错过最优解,太小导致迭代次数过多)。
      • miniBatchFraction:用于每次 SGD 迭代的数据,默认 1.0。(SGD 每次迭代选用随机数据)。
      • initialWeights:初始权值,默认 None。
      • regParam:规则化参数,默认 0.01。
      • regType:用于训练模型的规则化类型,可选为 l1 或 l2(默认)。
      • intercept:布尔值,表示是否使用增强表现来训练数据,默认 False。
      • validateData:布尔值,表示算法是否在训练前检验数据,默认 True。
      • convergenceTol:终止迭代的收敛值,默认 0.001。

  • 类: pyspark.mllib.classification.LogisticRegressionWithLBFGS
    • 方法:
      train(data, iterations=100, initialWeights=None, regParam=0.01, regType='l2', intercept=False, corrections=10, tolerance=0.0001, validateData=True, numClasses=2)
      通过给定数据训练逻辑回归模型。
      • data:训练数据,LabeledPoint 格式的 RDD 数据集。
      • iterations:迭代次数,默认为 100。
      • initialWeights:初始权值,默认 None。
      • regParam:规则化参数,默认 0.01。
      • regType:用于训练模型的规则化类型,可选为 l1 或 l2(默认)。
      • intercept:布尔值,表示是否使用增强表现来训练数据,默认 False。
      • corrections:用于 LBFGS 更新的修正值,默认 10。
      • tolerance:LBFGS 迭代的收敛容忍系数,默认 1e-4。
      • validateData:布尔值,表示算法是否在训练前检验数据,默认 True。
      • numClasses:多分类逻辑回归中类别的个数,默认 2。

  • 类:pyspark.mllib.classification.LogisticRegressionModel
    使用多/两逻辑分类方法训练得到的模型。
    • 属性:
      • weights:每个向量计算的权值。
      • intercept:该模型的计算截距(只用于两逻辑回归)。
      • numFeatures:向量的维度。
      • numClasses:输出类别的个数。
      • threshold:用于区分正负样本的阈值。
    • 方法: clearThreshold()
      去除阈值,直接输出预测值,只用于两分类
    • 方法: load(sc, path)
      从指定路径加载模型
    • 方法: save(sc, path)
      将模型保存到指定路径
    • 方法: predict(x)
      预测,输入可以为单个向量或整个 RDD
    • 方法: setThreshold(value)
      设置用于区分正负样本的阈值。当预测值大于该预置时,判定为正样本。

SVM 支持向量机

###背景知识

SVM 是二分类的分类模型。给定包含正负样本的数据集,SVM 的目的是寻找一个超平面(WX+b=0)对样本进行分割,且使得离超平面比较近的点能有更大的间距。

(待补充)


###Spark API

  • 类:class pyspark.mllib.classification.SVMWithSGD
    • 方法:
      train(data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, regType='l2', intercept=False, validateData=True, convergenceTol=0.001)
      通过给定的数据训练 SVM 模型。
      • data:训练数据,LabeledPoint 格式的 RDD 数据集。
      • iterations:迭代次数,默认为 100。
      • step:SGD 的步长,默认为 1.0。
      • regParam:规则化参数,默认 0.01。
      • miniBatchFraction:用于每次 SGD 迭代的数据,默认 1.0。
      • initialWeights:初始权值,默认 None。
      • regType:用于训练模型的规则化类型,可选为 l1 或 l2(默认)。
      • intercept:布尔值,表示是否使用增强表现来训练数据,默认 False。
      • validateData:布尔值,表示算法是否在训练前检验数据,默认 True。
      • convergenceTol:终止迭代的收敛值,默认 0.001。

  • 类: pyspark.mllib.classification.SVMModel
    支持向量机模型
    • 属性:
      • weights:每个向量计算的权值。
      • intercept:该模型的计算截距。
    • 方法: clearThreshold()
      去除阈值,直接输出预测值
    • 方法: load(sc, path)
      从指定路径加载模型
    • 方法: save(sc, path)
      将模型保存到指定路径
    • 方法: predict(x)
      预测,输入可以为单个向量或整个 RDD
    • 方法: setThreshold(value)
      设置用于区分正负样本的阈值。当预测值大于该预置时,判定为正样本。

##NaiveBayes 朴素贝叶斯
###背景知识
贝叶斯概率公式:
P(B[j]|A[i])=P(A[i]|B[j])P(B[j]) / P(A[i])
朴素贝叶斯分类器是使用贝叶斯概率公式为核心的分类算法,其基本思想为:对于给出的待分类项,求解在此项出现的条件下各个类别出现的概率,哪个最大,就认为此待分类项属于哪个类别。
朴素贝叶斯假定样本的不同特征属性对样本的归类影响时相互独立的。

(待补充)


###Spark API

  • 类:pyspark.mllib.classification.NaiveBayes
    • 方法:
      train(data, lambda_=1.0)
      通过给定数据集训练贝叶斯模型
      • data:训练数据,LabeledPoint 格式的 RDD 数据集。
      • lambda:平滑参数,默认 1.0

  • 类: pyspark.mllib.classification.NaiveBayesModel
    朴素贝叶斯分类器模型
    • 属性:
      • labels:label 列表
      • pi:每个类别的 priors
      • theta:使用矩阵存储每个向量划分到每个类的条件概率
    • 方法: load(sc, path)
      从指定路径加载模型
    • 方法: save(sc, path)
      将模型保存到指定路径
    • 方法: predict(x)
      预测,输入可以为单个向量或整个 RDD
  • Spark

    Spark 是 UC Berkeley AMP lab 所开源的类 Hadoop MapReduce 的通用并行框架。Spark 拥有 Hadoop MapReduce 所具有的优点;但不同于 MapReduce 的是 Job 中间输出结果可以保存在内存中,从而不再需要读写 HDFS,因此 Spark 能更好地适用于数据挖掘与机器学习等需要迭代的 MapReduce 的算法。

    74 引用 • 46 回帖 • 563 关注
  • 数据挖掘
    17 引用 • 32 回帖 • 3 关注
  • 默认
    5 引用 • 22 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...
  • zempty via macOS

    ????1???😇

推荐标签 标签

  • Java

    Java 是一种可以撰写跨平台应用软件的面向对象的程序设计语言,是由 Sun Microsystems 公司于 1995 年 5 月推出的。Java 技术具有卓越的通用性、高效性、平台移植性和安全性。

    3201 引用 • 8216 回帖
  • WebComponents

    Web Components 是 W3C 定义的标准,它给了前端开发者扩展浏览器标签的能力,可以方便地定制可复用组件,更好的进行模块化开发,解放了前端开发者的生产力。

    1 引用 • 8 关注
  • 小说

    小说是以刻画人物形象为中心,通过完整的故事情节和环境描写来反映社会生活的文学体裁。

    32 引用 • 108 回帖 • 1 关注
  • etcd

    etcd 是一个分布式、高可用的 key-value 数据存储,专门用于在分布式系统中保存关键数据。

    6 引用 • 26 回帖 • 543 关注
  • Maven

    Maven 是基于项目对象模型(POM)、通过一小段描述信息来管理项目的构建、报告和文档的软件项目管理工具。

    188 引用 • 319 回帖 • 252 关注
  • Elasticsearch

    Elasticsearch 是一个基于 Lucene 的搜索服务器。它提供了一个分布式多用户能力的全文搜索引擎,基于 RESTful 接口。Elasticsearch 是用 Java 开发的,并作为 Apache 许可条款下的开放源码发布,是当前流行的企业级搜索引擎。设计用于云计算中,能够达到实时搜索,稳定,可靠,快速,安装使用方便。

    117 引用 • 99 回帖 • 201 关注
  • 房星科技

    房星网,我们不和没有钱的程序员谈理想,我们要让程序员又有理想又有钱。我们有雄厚的房地产行业线下资源,遍布昆明全城的 100 家门店、四千地产经纪人是我们坚实的后盾。

    6 引用 • 141 回帖 • 591 关注
  • CongSec

    本标签主要用于分享网络空间安全专业的学习笔记

    1 引用 • 1 回帖 • 31 关注
  • 笔记

    好记性不如烂笔头。

    310 引用 • 794 回帖
  • 书籍

    宋真宗赵恒曾经说过:“书中自有黄金屋,书中自有颜如玉。”

    78 引用 • 396 回帖
  • BND

    BND(Baidu Netdisk Downloader)是一款图形界面的百度网盘不限速下载器,支持 Windows、Linux 和 Mac,详细介绍请看这里

    107 引用 • 1281 回帖 • 35 关注
  • HHKB

    HHKB 是富士通的 Happy Hacking 系列电容键盘。电容键盘即无接点静电电容式键盘(Capacitive Keyboard)。

    5 引用 • 74 回帖 • 504 关注
  • 微服务

    微服务架构是一种架构模式,它提倡将单一应用划分成一组小的服务。服务之间互相协调,互相配合,为用户提供最终价值。每个服务运行在独立的进程中。服务于服务之间才用轻量级的通信机制互相沟通。每个服务都围绕着具体业务构建,能够被独立的部署。

    96 引用 • 155 回帖 • 4 关注
  • LeetCode

    LeetCode(力扣)是一个全球极客挚爱的高质量技术成长平台,想要学习和提升专业能力从这里开始,充足技术干货等你来啃,轻松拿下 Dream Offer!

    209 引用 • 72 回帖
  • CentOS

    CentOS(Community Enterprise Operating System)是 Linux 发行版之一,它是来自于 Red Hat Enterprise Linux 依照开放源代码规定释出的源代码所编译而成。由于出自同样的源代码,因此有些要求高度稳定的服务器以 CentOS 替代商业版的 Red Hat Enterprise Linux 使用。两者的不同在于 CentOS 并不包含封闭源代码软件。

    239 引用 • 224 回帖
  • golang

    Go 语言是 Google 推出的一种全新的编程语言,可以在不损失应用程序性能的情况下降低代码的复杂性。谷歌首席软件工程师罗布派克(Rob Pike)说:我们之所以开发 Go,是因为过去 10 多年间软件开发的难度令人沮丧。Go 是谷歌 2009 发布的第二款编程语言。

    499 引用 • 1395 回帖 • 246 关注
  • 代码片段

    代码片段分为 CSS 与 JS 两种代码,添加在 [设置 - 外观 - 代码片段] 中,这些代码会在思源笔记加载时自动执行,用于改善笔记的样式或功能。

    用户在该标签下分享代码片段时需在帖子标题前添加 [css] [js] 用于区分代码片段类型。

    170 引用 • 1150 回帖
  • Mac

    Mac 是苹果公司自 1984 年起以“Macintosh”开始开发的个人消费型计算机,如:iMac、Mac mini、Macbook Air、Macbook Pro、Macbook、Mac Pro 等计算机。

    168 引用 • 597 回帖
  • Notion

    Notion - The all-in-one workspace for your notes, tasks, wikis, and databases.

    10 引用 • 77 回帖
  • WebClipper

    Web Clipper 是一款浏览器剪藏扩展,它可以帮助你把网页内容剪藏到本地。

    3 引用 • 9 回帖 • 1 关注
  • JetBrains

    JetBrains 是一家捷克的软件开发公司,该公司位于捷克的布拉格,并在俄国的圣彼得堡及美国麻州波士顿都设有办公室,该公司最为人所熟知的产品是 Java 编程语言开发撰写时所用的集成开发环境:IntelliJ IDEA

    18 引用 • 54 回帖 • 3 关注
  • 京东

    京东是中国最大的自营式电商企业,2015 年第一季度在中国自营式 B2C 电商市场的占有率为 56.3%。2014 年 5 月,京东在美国纳斯达克证券交易所正式挂牌上市(股票代码:JD),是中国第一个成功赴美上市的大型综合型电商平台,与腾讯、百度等中国互联网巨头共同跻身全球前十大互联网公司排行榜。

    14 引用 • 102 回帖 • 317 关注
  • sts
    2 引用 • 2 回帖 • 229 关注
  • CAP

    CAP 指的是在一个分布式系统中, Consistency(一致性)、 Availability(可用性)、Partition tolerance(分区容错性),三者不可兼得。

    12 引用 • 5 回帖 • 631 关注
  • Sublime

    Sublime Text 是一款可以用来写代码、写文章的文本编辑器。支持代码高亮、自动完成,还支持通过插件进行扩展。

    10 引用 • 5 回帖 • 3 关注
  • 域名

    域名(Domain Name),简称域名、网域,是由一串用点分隔的名字组成的 Internet 上某一台计算机或计算机组的名称,用于在数据传输时标识计算机的电子方位(有时也指地理位置)。

    43 引用 • 208 回帖 • 1 关注
  • Redis

    Redis 是一个开源的使用 ANSI C 语言编写、支持网络、可基于内存亦可持久化的日志型、Key-Value 数据库,并提供多种语言的 API。从 2010 年 3 月 15 日起,Redis 的开发工作由 VMware 主持。从 2013 年 5 月开始,Redis 的开发由 Pivotal 赞助。

    286 引用 • 248 回帖