Spark mllib API- classification

本贴最后更新于 3148 天前,其中的信息可能已经渤澥桑田

Apark mllib API 的翻译 - 分类篇。 对官方文档进行翻译的同时加入了一些常识性知识。

更多分类的相关知识可以查看我的另外一篇博客数据挖掘算法初窥门庭--分类回归

Spark 当前提供 LogisticRegression、SVM、NaiveBayes。


##LogisticRegression 逻辑回归

###背景知识

LinerRegression 是使用线性方程对数据进行两分类(在线的一侧属于同一类)。而 LogisticRegression 就是一个被 logistic 方程归一化后的 LinerRegression(归一化后值域为 0-1)。LogisticRegression 一般也用于两分类,预测样本属于某个类别的概率。

LogisticRegression 的过程是典型的监督机器学习,也就是在规则化参数的同时最小化误差。最小化误差是为了让我们的模型拟合我们的训练数据,而规则化参数是防止我们的模型过分拟合我们的训练数据。
大致步骤如下:

  • 目标函数为 f(f 为未知的),我们假定目标函数为 h。(假设)
  • 构造损失函数 cost(基于最大似然估计),表示 h 的预测结果与实际结果 f 之间的偏差。(预测并评估)
  • 通过迭代,调整 h,使 h 与 f 尽可能接近。(求最优解)

LogisticRegression 有很多不同的算法版本,大多数的主要不同在于求最优解。目前,spark 提供两种 LogisticRegression 方法:SGD(随机梯度下降)和 LBFGS(改进的拟牛顿法)。

特征选择:

  • LogisticRegression 假设向量的各个维度是独立不相互影响的。
  • 由于 LogisticRegression 的终止条件是收敛或达到最大迭代次数,因此在数据预处理时进行归一化,加快收敛速度。
  • 更多具体的变量选择方法,参考华山大师兄的 Logistic Regression--逻辑回归算法汇总

###Spark API

  • 类:pyspark.mllib.classification.LogisticRegressionWithSGD
    • 方法:
      train(data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None, regParam=0.01, regType='l2', intercept=False, validateData=True, convergenceTol=0.001)
      通过给定数据训练逻辑回归模型。
      • data:训练数据,LabeledPoint 格式的 RDD 数据集。
      • iterations:迭代次数,默认为 100。
      • step:SGD 的步长,默认为 1.0。(太大容易错过最优解,太小导致迭代次数过多)。
      • miniBatchFraction:用于每次 SGD 迭代的数据,默认 1.0。(SGD 每次迭代选用随机数据)。
      • initialWeights:初始权值,默认 None。
      • regParam:规则化参数,默认 0.01。
      • regType:用于训练模型的规则化类型,可选为 l1 或 l2(默认)。
      • intercept:布尔值,表示是否使用增强表现来训练数据,默认 False。
      • validateData:布尔值,表示算法是否在训练前检验数据,默认 True。
      • convergenceTol:终止迭代的收敛值,默认 0.001。

  • 类: pyspark.mllib.classification.LogisticRegressionWithLBFGS
    • 方法:
      train(data, iterations=100, initialWeights=None, regParam=0.01, regType='l2', intercept=False, corrections=10, tolerance=0.0001, validateData=True, numClasses=2)
      通过给定数据训练逻辑回归模型。
      • data:训练数据,LabeledPoint 格式的 RDD 数据集。
      • iterations:迭代次数,默认为 100。
      • initialWeights:初始权值,默认 None。
      • regParam:规则化参数,默认 0.01。
      • regType:用于训练模型的规则化类型,可选为 l1 或 l2(默认)。
      • intercept:布尔值,表示是否使用增强表现来训练数据,默认 False。
      • corrections:用于 LBFGS 更新的修正值,默认 10。
      • tolerance:LBFGS 迭代的收敛容忍系数,默认 1e-4。
      • validateData:布尔值,表示算法是否在训练前检验数据,默认 True。
      • numClasses:多分类逻辑回归中类别的个数,默认 2。

  • 类:pyspark.mllib.classification.LogisticRegressionModel
    使用多/两逻辑分类方法训练得到的模型。
    • 属性:
      • weights:每个向量计算的权值。
      • intercept:该模型的计算截距(只用于两逻辑回归)。
      • numFeatures:向量的维度。
      • numClasses:输出类别的个数。
      • threshold:用于区分正负样本的阈值。
    • 方法: clearThreshold()
      去除阈值,直接输出预测值,只用于两分类
    • 方法: load(sc, path)
      从指定路径加载模型
    • 方法: save(sc, path)
      将模型保存到指定路径
    • 方法: predict(x)
      预测,输入可以为单个向量或整个 RDD
    • 方法: setThreshold(value)
      设置用于区分正负样本的阈值。当预测值大于该预置时,判定为正样本。

SVM 支持向量机

###背景知识

SVM 是二分类的分类模型。给定包含正负样本的数据集,SVM 的目的是寻找一个超平面(WX+b=0)对样本进行分割,且使得离超平面比较近的点能有更大的间距。

(待补充)


###Spark API

  • 类:class pyspark.mllib.classification.SVMWithSGD
    • 方法:
      train(data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None, regType='l2', intercept=False, validateData=True, convergenceTol=0.001)
      通过给定的数据训练 SVM 模型。
      • data:训练数据,LabeledPoint 格式的 RDD 数据集。
      • iterations:迭代次数,默认为 100。
      • step:SGD 的步长,默认为 1.0。
      • regParam:规则化参数,默认 0.01。
      • miniBatchFraction:用于每次 SGD 迭代的数据,默认 1.0。
      • initialWeights:初始权值,默认 None。
      • regType:用于训练模型的规则化类型,可选为 l1 或 l2(默认)。
      • intercept:布尔值,表示是否使用增强表现来训练数据,默认 False。
      • validateData:布尔值,表示算法是否在训练前检验数据,默认 True。
      • convergenceTol:终止迭代的收敛值,默认 0.001。

  • 类: pyspark.mllib.classification.SVMModel
    支持向量机模型
    • 属性:
      • weights:每个向量计算的权值。
      • intercept:该模型的计算截距。
    • 方法: clearThreshold()
      去除阈值,直接输出预测值
    • 方法: load(sc, path)
      从指定路径加载模型
    • 方法: save(sc, path)
      将模型保存到指定路径
    • 方法: predict(x)
      预测,输入可以为单个向量或整个 RDD
    • 方法: setThreshold(value)
      设置用于区分正负样本的阈值。当预测值大于该预置时,判定为正样本。

##NaiveBayes 朴素贝叶斯
###背景知识
贝叶斯概率公式:
P(B[j]|A[i])=P(A[i]|B[j])P(B[j]) / P(A[i])
朴素贝叶斯分类器是使用贝叶斯概率公式为核心的分类算法,其基本思想为:对于给出的待分类项,求解在此项出现的条件下各个类别出现的概率,哪个最大,就认为此待分类项属于哪个类别。
朴素贝叶斯假定样本的不同特征属性对样本的归类影响时相互独立的。

(待补充)


###Spark API

  • 类:pyspark.mllib.classification.NaiveBayes
    • 方法:
      train(data, lambda_=1.0)
      通过给定数据集训练贝叶斯模型
      • data:训练数据,LabeledPoint 格式的 RDD 数据集。
      • lambda:平滑参数,默认 1.0

  • 类: pyspark.mllib.classification.NaiveBayesModel
    朴素贝叶斯分类器模型
    • 属性:
      • labels:label 列表
      • pi:每个类别的 priors
      • theta:使用矩阵存储每个向量划分到每个类的条件概率
    • 方法: load(sc, path)
      从指定路径加载模型
    • 方法: save(sc, path)
      将模型保存到指定路径
    • 方法: predict(x)
      预测,输入可以为单个向量或整个 RDD
  • Spark

    Spark 是 UC Berkeley AMP lab 所开源的类 Hadoop MapReduce 的通用并行框架。Spark 拥有 Hadoop MapReduce 所具有的优点;但不同于 MapReduce 的是 Job 中间输出结果可以保存在内存中,从而不再需要读写 HDFS,因此 Spark 能更好地适用于数据挖掘与机器学习等需要迭代的 MapReduce 的算法。

    74 引用 • 46 回帖 • 552 关注
  • 数据挖掘
    17 引用 • 32 回帖 • 3 关注
  • 默认
    5 引用 • 22 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • 友情链接

    确认过眼神后的灵魂连接,站在链在!

    24 引用 • 373 回帖
  • RESTful

    一种软件架构设计风格而不是标准,提供了一组设计原则和约束条件,主要用于客户端和服务器交互类的软件。基于这个风格设计的软件可以更简洁,更有层次,更易于实现缓存等机制。

    30 引用 • 114 回帖 • 2 关注
  • 新人

    让我们欢迎这对新人。哦,不好意思说错了,让我们欢迎这位新人!
    新手上路,请谨慎驾驶!

    52 引用 • 228 回帖
  • Kafka

    Kafka 是一种高吞吐量的分布式发布订阅消息系统,它可以处理消费者规模的网站中的所有动作流数据。 这种动作(网页浏览,搜索和其他用户的行动)是现代系统中许多功能的基础。 这些数据通常是由于吞吐量的要求而通过处理日志和日志聚合来解决。

    36 引用 • 35 回帖
  • danl
    132 关注
  • 强迫症

    强迫症(OCD)属于焦虑障碍的一种类型,是一组以强迫思维和强迫行为为主要临床表现的神经精神疾病,其特点为有意识的强迫和反强迫并存,一些毫无意义、甚至违背自己意愿的想法或冲动反反复复侵入患者的日常生活。

    15 引用 • 161 回帖
  • 自由行
    10 关注
  • 博客

    记录并分享人生的经历。

    273 引用 • 2388 回帖
  • GitLab

    GitLab 是利用 Ruby 一个开源的版本管理系统,实现一个自托管的 Git 项目仓库,可通过 Web 界面操作公开或私有项目。

    46 引用 • 72 回帖
  • 工具

    子曰:“工欲善其事,必先利其器。”

    286 引用 • 729 回帖
  • Jenkins

    Jenkins 是一套开源的持续集成工具。它提供了非常丰富的插件,让构建、部署、自动化集成项目变得简单易用。

    53 引用 • 37 回帖
  • ZeroNet

    ZeroNet 是一个基于比特币加密技术和 BT 网络技术的去中心化的、开放开源的网络和交流系统。

    1 引用 • 21 回帖 • 638 关注
  • GitHub

    GitHub 于 2008 年上线,目前,除了 Git 代码仓库托管及基本的 Web 管理界面以外,还提供了订阅、讨论组、文本渲染、在线文件编辑器、协作图谱(报表)、代码片段分享(Gist)等功能。正因为这些功能所提供的便利,又经过长期的积累,GitHub 的用户活跃度很高,在开源世界里享有深远的声望,并形成了社交化编程文化(Social Coding)。

    209 引用 • 2031 回帖
  • 安装

    你若安好,便是晴天。

    132 引用 • 1184 回帖
  • HHKB

    HHKB 是富士通的 Happy Hacking 系列电容键盘。电容键盘即无接点静电电容式键盘(Capacitive Keyboard)。

    5 引用 • 74 回帖 • 471 关注
  • Dubbo

    Dubbo 是一个分布式服务框架,致力于提供高性能和透明化的 RPC 远程服务调用方案,是 [阿里巴巴] SOA 服务化治理方案的核心框架,每天为 2,000+ 个服务提供 3,000,000,000+ 次访问量支持,并被广泛应用于阿里巴巴集团的各成员站点。

    60 引用 • 82 回帖 • 595 关注
  • 黑曜石

    黑曜石是一款强大的知识库工具,支持本地 Markdown 文件编辑,支持双向链接和关系图。

    A second brain, for you, forever.

    15 引用 • 122 回帖
  • 正则表达式

    正则表达式(Regular Expression)使用单个字符串来描述、匹配一系列遵循某个句法规则的字符串。

    31 引用 • 94 回帖
  • Spring

    Spring 是一个开源框架,是于 2003 年兴起的一个轻量级的 Java 开发框架,由 Rod Johnson 在其著作《Expert One-On-One J2EE Development and Design》中阐述的部分理念和原型衍生而来。它是为了解决企业应用开发的复杂性而创建的。框架的主要优势之一就是其分层架构,分层架构允许使用者选择使用哪一个组件,同时为 JavaEE 应用程序开发提供集成的框架。

    944 引用 • 1459 回帖 • 17 关注
  • 30Seconds

    📙 前端知识精选集,包含 HTML、CSS、JavaScript、React、Node、安全等方面,每天仅需 30 秒。

    • 精选常见面试题,帮助您准备下一次面试
    • 精选常见交互,帮助您拥有简洁酷炫的站点
    • 精选有用的 React 片段,帮助你获取最佳实践
    • 精选常见代码集,帮助您提高打码效率
    • 整理前端界的最新资讯,邀您一同探索新世界
    488 引用 • 384 回帖 • 8 关注
  • 职场

    找到自己的位置,萌新烦恼少。

    127 引用 • 1705 回帖 • 1 关注
  • App

    App(应用程序,Application 的缩写)一般指手机软件。

    91 引用 • 384 回帖
  • 深度学习

    深度学习(Deep Learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。

    53 引用 • 40 回帖 • 2 关注
  • Chrome

    Chrome 又称 Google 浏览器,是一个由谷歌公司开发的网页浏览器。该浏览器是基于其他开源软件所编写,包括 WebKit,目标是提升稳定性、速度和安全性,并创造出简单且有效率的使用者界面。

    62 引用 • 289 回帖 • 1 关注
  • GitBook

    GitBook 使您的团队可以轻松编写和维护高质量的文档。 分享知识,提高团队的工作效率,让用户满意。

    3 引用 • 8 回帖 • 4 关注
  • 域名

    域名(Domain Name),简称域名、网域,是由一串用点分隔的名字组成的 Internet 上某一台计算机或计算机组的名称,用于在数据传输时标识计算机的电子方位(有时也指地理位置)。

    43 引用 • 208 回帖
  • abitmean

    有点意思就行了

    29 关注