Spark mllib API- classification

本贴最后更新于 3281 天前,其中的信息可能已经渤澥桑田

Apark mllib API 的翻译 - 分类篇。 对官方文档进行翻译的同时加入了一些常识性知识。

更多分类的相关知识可以查看我的另外一篇博客数据挖掘算法初窥门庭--分类回归

Spark 当前提供 LogisticRegression、SVM、NaiveBayes。


##LogisticRegression 逻辑回归

###背景知识

LinerRegression 是使用线性方程对数据进行两分类(在线的一侧属于同一类)。而 LogisticRegression 就是一个被 logistic 方程归一化后的 LinerRegression(归一化后值域为 0-1)。LogisticRegression 一般也用于两分类,预测样本属于某个类别的概率。

LogisticRegression 的过程是典型的监督机器学习,也就是在规则化参数的同时最小化误差。最小化误差是为了让我们的模型拟合我们的训练数据,而规则化参数是防止我们的模型过分拟合我们的训练数据。
大致步骤如下:

  • 目标函数为 f(f 为未知的),我们假定目标函数为 h。(假设)
  • 构造损失函数 cost(基于最大似然估计),表示 h 的预测结果与实际结果 f 之间的偏差。(预测并评估)
  • 通过迭代,调整 h,使 h 与 f 尽可能接近。(求最优解)

LogisticRegression 有很多不同的算法版本,大多数的主要不同在于求最优解。目前,spark 提供两种 LogisticRegression 方法:SGD(随机梯度下降)和 LBFGS(改进的拟牛顿法)。

特征选择:

  • LogisticRegression 假设向量的各个维度是独立不相互影响的。
  • 由于 LogisticRegression 的终止条件是收敛或达到最大迭代次数,因此在数据预处理时进行归一化,加快收敛速度。
  • 更多具体的变量选择方法,参考华山大师兄的 Logistic Regression--逻辑回归算法汇总

###Spark API

  • 类:pyspark.mllib.classification.LogisticRegressionWithSGD
    • 方法:
      train(data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=0.01, regType='l2', intercept=False, validateData=True, convergenceTol=0.001)
      通过给定数据训练逻辑回归模型。
      • data:训练数据,LabeledPoint 格式的 RDD 数据集。
      • iterations:迭代次数,默认为 100。
      • step:SGD 的步长,默认为 1.0。(太大容易错过最优解,太小导致迭代次数过多)。
      • miniBatchFraction:用于每次 SGD 迭代的数据,默认 1.0。(SGD 每次迭代选用随机数据)。
      • initialWeights:初始权值,默认 None。
      • regParam:规则化参数,默认 0.01。
      • regType:用于训练模型的规则化类型,可选为 l1 或 l2(默认)。
      • intercept:布尔值,表示是否使用增强表现来训练数据,默认 False。
      • validateData:布尔值,表示算法是否在训练前检验数据,默认 True。
      • convergenceTol:终止迭代的收敛值,默认 0.001。

  • 类: pyspark.mllib.classification.LogisticRegressionWithLBFGS
    • 方法:
      train(data, iterations=100, initialWeights=None, regParam=0.01, regType='l2', intercept=False, corrections=10, tolerance=0.0001, validateData=True, numClasses=2)
      通过给定数据训练逻辑回归模型。
      • data:训练数据,LabeledPoint 格式的 RDD 数据集。
      • iterations:迭代次数,默认为 100。
      • initialWeights:初始权值,默认 None。
      • regParam:规则化参数,默认 0.01。
      • regType:用于训练模型的规则化类型,可选为 l1 或 l2(默认)。
      • intercept:布尔值,表示是否使用增强表现来训练数据,默认 False。
      • corrections:用于 LBFGS 更新的修正值,默认 10。
      • tolerance:LBFGS 迭代的收敛容忍系数,默认 1e-4。
      • validateData:布尔值,表示算法是否在训练前检验数据,默认 True。
      • numClasses:多分类逻辑回归中类别的个数,默认 2。

  • 类:pyspark.mllib.classification.LogisticRegressionModel
    使用多/两逻辑分类方法训练得到的模型。
    • 属性:
      • weights:每个向量计算的权值。
      • intercept:该模型的计算截距(只用于两逻辑回归)。
      • numFeatures:向量的维度。
      • numClasses:输出类别的个数。
      • threshold:用于区分正负样本的阈值。
    • 方法: clearThreshold()
      去除阈值,直接输出预测值,只用于两分类
    • 方法: load(sc, path)
      从指定路径加载模型
    • 方法: save(sc, path)
      将模型保存到指定路径
    • 方法: predict(x)
      预测,输入可以为单个向量或整个 RDD
    • 方法: setThreshold(value)
      设置用于区分正负样本的阈值。当预测值大于该预置时,判定为正样本。

SVM 支持向量机

###背景知识

SVM 是二分类的分类模型。给定包含正负样本的数据集,SVM 的目的是寻找一个超平面(WX+b=0)对样本进行分割,且使得离超平面比较近的点能有更大的间距。

(待补充)


###Spark API

  • 类:class pyspark.mllib.classification.SVMWithSGD
    • 方法:
      train(data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, regType='l2', intercept=False, validateData=True, convergenceTol=0.001)
      通过给定的数据训练 SVM 模型。
      • data:训练数据,LabeledPoint 格式的 RDD 数据集。
      • iterations:迭代次数,默认为 100。
      • step:SGD 的步长,默认为 1.0。
      • regParam:规则化参数,默认 0.01。
      • miniBatchFraction:用于每次 SGD 迭代的数据,默认 1.0。
      • initialWeights:初始权值,默认 None。
      • regType:用于训练模型的规则化类型,可选为 l1 或 l2(默认)。
      • intercept:布尔值,表示是否使用增强表现来训练数据,默认 False。
      • validateData:布尔值,表示算法是否在训练前检验数据,默认 True。
      • convergenceTol:终止迭代的收敛值,默认 0.001。

  • 类: pyspark.mllib.classification.SVMModel
    支持向量机模型
    • 属性:
      • weights:每个向量计算的权值。
      • intercept:该模型的计算截距。
    • 方法: clearThreshold()
      去除阈值,直接输出预测值
    • 方法: load(sc, path)
      从指定路径加载模型
    • 方法: save(sc, path)
      将模型保存到指定路径
    • 方法: predict(x)
      预测,输入可以为单个向量或整个 RDD
    • 方法: setThreshold(value)
      设置用于区分正负样本的阈值。当预测值大于该预置时,判定为正样本。

##NaiveBayes 朴素贝叶斯
###背景知识
贝叶斯概率公式:
P(B[j]|A[i])=P(A[i]|B[j])P(B[j]) / P(A[i])
朴素贝叶斯分类器是使用贝叶斯概率公式为核心的分类算法,其基本思想为:对于给出的待分类项,求解在此项出现的条件下各个类别出现的概率,哪个最大,就认为此待分类项属于哪个类别。
朴素贝叶斯假定样本的不同特征属性对样本的归类影响时相互独立的。

(待补充)


###Spark API

  • 类:pyspark.mllib.classification.NaiveBayes
    • 方法:
      train(data, lambda_=1.0)
      通过给定数据集训练贝叶斯模型
      • data:训练数据,LabeledPoint 格式的 RDD 数据集。
      • lambda:平滑参数,默认 1.0

  • 类: pyspark.mllib.classification.NaiveBayesModel
    朴素贝叶斯分类器模型
    • 属性:
      • labels:label 列表
      • pi:每个类别的 priors
      • theta:使用矩阵存储每个向量划分到每个类的条件概率
    • 方法: load(sc, path)
      从指定路径加载模型
    • 方法: save(sc, path)
      将模型保存到指定路径
    • 方法: predict(x)
      预测,输入可以为单个向量或整个 RDD
  • Spark

    Spark 是 UC Berkeley AMP lab 所开源的类 Hadoop MapReduce 的通用并行框架。Spark 拥有 Hadoop MapReduce 所具有的优点;但不同于 MapReduce 的是 Job 中间输出结果可以保存在内存中,从而不再需要读写 HDFS,因此 Spark 能更好地适用于数据挖掘与机器学习等需要迭代的 MapReduce 的算法。

    74 引用 • 46 回帖 • 568 关注
  • 数据挖掘
    17 引用 • 32 回帖 • 3 关注
  • 默认
    5 引用 • 22 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...
  • zempty via macOS

    ????1???😇

推荐标签 标签

  • FFmpeg

    FFmpeg 是一套可以用来记录、转换数字音频、视频,并能将其转化为流的开源计算机程序。

    23 引用 • 32 回帖 • 2 关注
  • Windows

    Microsoft Windows 是美国微软公司研发的一套操作系统,它问世于 1985 年,起初仅仅是 Microsoft-DOS 模拟环境,后续的系统版本由于微软不断的更新升级,不但易用,也慢慢的成为家家户户人们最喜爱的操作系统。

    226 引用 • 476 回帖 • 1 关注
  • Logseq

    Logseq 是一个隐私优先、开源的知识库工具。

    Logseq is a joyful, open-source outliner that works on top of local plain-text Markdown and Org-mode files. Use it to write, organize and share your thoughts, keep your to-do list, and build your own digital garden.

    7 引用 • 69 回帖 • 2 关注
  • Spring

    Spring 是一个开源框架,是于 2003 年兴起的一个轻量级的 Java 开发框架,由 Rod Johnson 在其著作《Expert One-On-One J2EE Development and Design》中阐述的部分理念和原型衍生而来。它是为了解决企业应用开发的复杂性而创建的。框架的主要优势之一就是其分层架构,分层架构允许使用者选择使用哪一个组件,同时为 JavaEE 应用程序开发提供集成的框架。

    946 引用 • 1460 回帖
  • wolai

    我来 wolai:不仅仅是未来的云端笔记!

    2 引用 • 14 回帖 • 2 关注
  • Outlook
    1 引用 • 5 回帖 • 2 关注
  • Bootstrap

    Bootstrap 是 Twitter 推出的一个用于前端开发的开源工具包。它由 Twitter 的设计师 Mark Otto 和 Jacob Thornton 合作开发,是一个 CSS / HTML 框架。

    18 引用 • 33 回帖 • 653 关注
  • HTML

    HTML5 是 HTML 下一个的主要修订版本,现在仍处于发展阶段。广义论及 HTML5 时,实际指的是包括 HTML、CSS 和 JavaScript 在内的一套技术组合。

    108 引用 • 295 回帖 • 3 关注
  • Google

    Google(Google Inc.,NASDAQ:GOOG)是一家美国上市公司(公有股份公司),于 1998 年 9 月 7 日以私有股份公司的形式创立,设计并管理一个互联网搜索引擎。Google 公司的总部称作“Googleplex”,它位于加利福尼亚山景城。Google 目前被公认为是全球规模最大的搜索引擎,它提供了简单易用的免费服务。不作恶(Don't be evil)是谷歌公司的一项非正式的公司口号。

    49 引用 • 192 回帖
  • CSDN

    CSDN (Chinese Software Developer Network) 创立于 1999 年,是中国的 IT 社区和服务平台,为中国的软件开发者和 IT 从业者提供知识传播、职业发展、软件开发等全生命周期服务,满足他们在职业发展中学习及共享知识和信息、建立职业发展社交圈、通过软件开发实现技术商业化等刚性需求。

    14 引用 • 155 回帖
  • QQ

    1999 年 2 月腾讯正式推出“腾讯 QQ”,在线用户由 1999 年的 2 人(马化腾和张志东)到现在已经发展到上亿用户了,在线人数超过一亿,是目前使用最广泛的聊天软件之一。

    45 引用 • 557 回帖
  • Thymeleaf

    Thymeleaf 是一款用于渲染 XML/XHTML/HTML5 内容的模板引擎。类似 Velocity、 FreeMarker 等,它也可以轻易的与 Spring 等 Web 框架进行集成作为 Web 应用的模板引擎。与其它模板引擎相比,Thymeleaf 最大的特点是能够直接在浏览器中打开并正确显示模板页面,而不需要启动整个 Web 应用。

    11 引用 • 19 回帖 • 386 关注
  • Mac

    Mac 是苹果公司自 1984 年起以“Macintosh”开始开发的个人消费型计算机,如:iMac、Mac mini、Macbook Air、Macbook Pro、Macbook、Mac Pro 等计算机。

    168 引用 • 595 回帖
  • CongSec

    本标签主要用于分享网络空间安全专业的学习笔记

    1 引用 • 1 回帖 • 25 关注
  • Telegram

    Telegram 是一个非盈利性、基于云端的即时消息服务。它提供了支持各大操作系统平台的开源的客户端,也提供了很多强大的 APIs 给开发者创建自己的客户端和机器人。

    5 引用 • 35 回帖
  • Office

    Office 现已更名为 Microsoft 365. Microsoft 365 将高级 Office 应用(如 Word、Excel 和 PowerPoint)与 1 TB 的 OneDrive 云存储空间、高级安全性等结合在一起,可帮助你在任何设备上完成操作。

    5 引用 • 34 回帖 • 5 关注
  • C

    C 语言是一门通用计算机编程语言,应用广泛。C 语言的设计目标是提供一种能以简易的方式编译、处理低级存储器、产生少量的机器码以及不需要任何运行环境支持便能运行的编程语言。

    85 引用 • 165 回帖
  • 音乐

    你听到信仰的声音了么?

    62 引用 • 512 回帖
  • Angular

    AngularAngularJS 的新版本。

    26 引用 • 66 回帖 • 543 关注
  • IBM

    IBM(国际商业机器公司)或万国商业机器公司,简称 IBM(International Business Machines Corporation),总公司在纽约州阿蒙克市。1911 年托马斯·沃森创立于美国,是全球最大的信息技术和业务解决方案公司,拥有全球雇员 30 多万人,业务遍及 160 多个国家和地区。

    17 引用 • 53 回帖 • 147 关注
  • 运维

    互联网运维工作,以服务为中心,以稳定、安全、高效为三个基本点,确保公司的互联网业务能够 7×24 小时为用户提供高质量的服务。

    151 引用 • 257 回帖 • 2 关注
  • BND

    BND(Baidu Netdisk Downloader)是一款图形界面的百度网盘不限速下载器,支持 Windows、Linux 和 Mac,详细介绍请看这里

    107 引用 • 1281 回帖 • 33 关注
  • JetBrains

    JetBrains 是一家捷克的软件开发公司,该公司位于捷克的布拉格,并在俄国的圣彼得堡及美国麻州波士顿都设有办公室,该公司最为人所熟知的产品是 Java 编程语言开发撰写时所用的集成开发环境:IntelliJ IDEA

    18 引用 • 54 回帖 • 2 关注
  • OkHttp

    OkHttp 是一款 HTTP & HTTP/2 客户端库,专为 Android 和 Java 应用打造。

    16 引用 • 6 回帖 • 83 关注
  • PWA

    PWA(Progressive Web App)是 Google 在 2015 年提出、2016 年 6 月开始推广的项目。它结合了一系列现代 Web 技术,在网页应用中实现和原生应用相近的用户体验。

    14 引用 • 69 回帖 • 176 关注
  • Swift

    Swift 是苹果于 2014 年 WWDC(苹果开发者大会)发布的开发语言,可与 Objective-C 共同运行于 Mac OS 和 iOS 平台,用于搭建基于苹果平台的应用程序。

    36 引用 • 37 回帖 • 542 关注
  • Maven

    Maven 是基于项目对象模型(POM)、通过一小段描述信息来管理项目的构建、报告和文档的软件项目管理工具。

    186 引用 • 318 回帖 • 257 关注