Spark mllib API- tree

本贴最后更新于 3125 天前,其中的信息可能已经天翻地覆

spark 在 tree 这个模块中提供了 DecisionTree、RandomForest、GradientBoostedTrees 三种算法。均属于分类/回归 树模型。
三种算法均可用于回归预测。其中决策树和决策森林可用于二元或多元分类,GBT 只能用于二元分类。

随机森林和 GBT 均属于组合模型,解决模型过拟合问题。

##DecisionTree 决策树

  • 类:pyspark.mllib.tree.DecisionTree
    决策树算法,训练决策树模型,提供分类和回归。

    • 方法:
      trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity='gini', maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0)
      训练用于分类的二叉树模型。

      • data:训练数据集,格式为 LabeledPoint 的 RDD,LabeledPoint 中的 Label 是整数。
      • numClasses:分类的个数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • impurity:纯度计算,支持“entropy”和“gini”(默认)
      • maxDepth:决策树的最大深度,默认 5
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
      • minInstancesPerNode:需要保证节点分割出的左右子节点的最少的样本数量达到这个值,默认 1
      • minInfoGain:当前节点的所有属性分割带来的信息增益都比这个值要小,默认 0.0
    • 方法:
      trainRegressor(data, categoricalFeaturesInfo, impurity='variance', maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0)
      训练用于回归的二叉树模型。

      • data:训练数据集,格式为 LabeledPoint 的 RDD,LabeledPoint 中的 Label 是实数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • impurity:纯度计算,支持“variance”(默认)
      • maxDepth:决策树的最大深度,默认 5
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
      • minInstancesPerNode:需要保证节点分割出的左右子节点的最少的样本数量达到这个值,默认 1
      • minInfoGain:当前节点的所有属性分割带来的信息增益都比这个值要小,默认 0.0

  • 类:pyspark.mllib.tree.DecisionTreeModel(java_model)

    • 方法: call(name, *a)
      调用 java 模型
    • 方法: depth()
      获取决策树的深度
    • 方法: load(sc, path)
      从指定 path 加载决策树模型
    • 方法: numNodes()
      获取决策树的节点数量,包括叶子节点
    • 方法: predict(x)
      预测一个或多个样本的 label 值
    • 方法: save(sc, path)
      将决策树模型持久化到指定 path
    • 方法: toDebugString()
      以 string 输出整个模型的信息

##RandomForest 随机森林

  • 类:pyspark.mllib.tree.RandomForest
    • 方法:
      trainClassifier(data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy='auto', impurity='gini', maxDepth=4, maxBins=32, seed=None)
      训练一个用于二元或多元分类的随机森林

      • data:训练数据集,格式为 LabeledPoint 的 RDD,LabeledPoint 中的 Label 是整数。
      • numClasses:分类的个数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • numTrees:随机森林中,树的数量。
      • featureSubsetStrategy:特征子集采样策略,支持"auto"(默认),"all","aqrt","log2","onethird"
      • impurity:纯度计算,支持“entropy”和“gini”(建议)
      • maxDepth:树的最大深度。
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
      • seed:用于引导和选择特征子集的随机种子。
    • 方法:
      trainRegressor(data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy='auto', impurity='variance', maxDepth=4, maxBins=32, seed=None)
      训练一个用于回归预测的随机森林

      • data:训练数据集,格式为 LabeledPoint 的 RDD,LabeledPoint 中的 Label 是实数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • numTrees:随机森林中,树的数量。
      • featureSubsetStrategy:特征子集采样策略,支持"auto"(默认),"all","aqrt","log2","onethird"
      • impurity:纯度计算,支持“variance”
      • maxDepth:树的最大深度。
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
      • seed:用于引导和选择特征子集的随机种子。

  • 类:pyspark.mllib.tree.RandomForestModel(java_model)

    • 方法: call(name, *a)
      调用 java 模型
    • 方法: load(sc, path)
      从指定 path 加载决策树模型
    • 方法: numTrees()
      获取随机森林中树的数量
    • 方法: predict(x)
      预测一个或多个样本的 label 值
    • 方法: save(sc, path)
      将决策树模型持久化到指定 path
    • 方法: toDebugString()
      以 string 输出整个模型的信息
    • 方法: totalNumNodes()
      获得森林中所有树的节点总和

##GradientBoostedTrees(GBT) 梯度提升决策树
这是一种模型组合的方法,利用简单模型的组合克服过拟合等问题。常用于推荐系统。

  • 类:pyspark.mllib.tree.GradientBoostedTrees
    • 方法:
      trainClassifier(data, categoricalFeaturesInfo, loss='logLoss', numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32)
      训练一个用于二元分类预测的梯度提升决策树模型。

      • data:训练数据集,格式为 LabeledPoint 的 RDD。label 必须为 0 或 1.
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • loss:损失函数,梯度提升计算时需要最小化的该函数。支持“logLoss” (默认), “leastSquaresError”, “leastAbsoluteError”
      • numIterations:提升的迭代次数,默认 100.
      • learningRate:学习率,取值(0,1]
      • maxDepth:树的最大深度
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
    • 方法:
      trainRegressor(data, categoricalFeaturesInfo, loss='leastSquaresError', numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32)
      训练一个用于回归预测的梯度提升决策树模型。

      • data:训练数据集,格式为 LabeledPoint 的 RDD。label 为实数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • loss:损失函数,梯度提升计算时需要最小化的该函数。支持“logLoss” (默认), “leastSquaresError”, “leastAbsoluteError”
      • numIterations:提升的迭代次数,默认 100.
      • learningRate:学习率,取值(0,1]
      • maxDepth:树的最大深度
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32

  • 类: pyspark.mllib.tree.GradientBoostedTreesModel(java_model)

    • 方法: call(name, *a)
      调用 java 模型
    • 方法: load(sc, path)
      从指定 path 加载决策树模型
    • 方法: numTrees()
      获取随机森林中树的数量
    • 方法: predict(x)
      预测一个或多个样本的 label 值
    • 方法: save(sc, path)
      将决策树模型持久化到指定 path
    • 方法: toDebugString()
      以 string 输出整个模型的信息
    • 方法: totalNumNodes()
      获得森林中所有树的节点总和
  • Spark

    Spark 是 UC Berkeley AMP lab 所开源的类 Hadoop MapReduce 的通用并行框架。Spark 拥有 Hadoop MapReduce 所具有的优点;但不同于 MapReduce 的是 Job 中间输出结果可以保存在内存中,从而不再需要读写 HDFS,因此 Spark 能更好地适用于数据挖掘与机器学习等需要迭代的 MapReduce 的算法。

    74 引用 • 46 回帖 • 552 关注
  • 数据挖掘
    17 引用 • 32 回帖 • 3 关注

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • React

    React 是 Facebook 开源的一个用于构建 UI 的 JavaScript 库。

    192 引用 • 291 回帖 • 384 关注
  • 深度学习

    深度学习(Deep Learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。

    53 引用 • 40 回帖
  • 链滴

    链滴是一个记录生活的地方。

    记录生活,连接点滴

    153 引用 • 3783 回帖 • 1 关注
  • Python

    Python 是一种面向对象、直译式电脑编程语言,具有近二十年的发展历史,成熟且稳定。它包含了一组完善而且容易理解的标准库,能够轻松完成很多常见的任务。它的语法简捷和清晰,尽量使用无异义的英语单词,与其它大多数程序设计语言使用大括号不一样,它使用缩进来定义语句块。

    543 引用 • 672 回帖 • 1 关注
  • Rust

    Rust 是一门赋予每个人构建可靠且高效软件能力的语言。Rust 由 Mozilla 开发,最早发布于 2014 年 9 月。

    58 引用 • 22 回帖
  • DevOps

    DevOps(Development 和 Operations 的组合词)是一组过程、方法与系统的统称,用于促进开发(应用程序/软件工程)、技术运营和质量保障(QA)部门之间的沟通、协作与整合。

    47 引用 • 25 回帖
  • Flume

    Flume 是一套分布式的、可靠的,可用于有效地收集、聚合和搬运大量日志数据的服务架构。

    9 引用 • 6 回帖 • 629 关注
  • Linux

    Linux 是一套免费使用和自由传播的类 Unix 操作系统,是一个基于 POSIX 和 Unix 的多用户、多任务、支持多线程和多 CPU 的操作系统。它能运行主要的 Unix 工具软件、应用程序和网络协议,并支持 32 位和 64 位硬件。Linux 继承了 Unix 以网络为核心的设计思想,是一个性能稳定的多用户网络操作系统。

    943 引用 • 943 回帖
  • DNSPod

    DNSPod 建立于 2006 年 3 月份,是一款免费智能 DNS 产品。 DNSPod 可以为同时有电信、网通、教育网服务器的网站提供智能的解析,让电信用户访问电信的服务器,网通的用户访问网通的服务器,教育网的用户访问教育网的服务器,达到互联互通的效果。

    6 引用 • 26 回帖 • 510 关注
  • 互联网

    互联网(Internet),又称网际网络,或音译因特网、英特网。互联网始于 1969 年美国的阿帕网,是网络与网络之间所串连成的庞大网络,这些网络以一组通用的协议相连,形成逻辑上的单一巨大国际网络。

    98 引用 • 344 回帖
  • Kafka

    Kafka 是一种高吞吐量的分布式发布订阅消息系统,它可以处理消费者规模的网站中的所有动作流数据。 这种动作(网页浏览,搜索和其他用户的行动)是现代系统中许多功能的基础。 这些数据通常是由于吞吐量的要求而通过处理日志和日志聚合来解决。

    36 引用 • 35 回帖
  • InfluxDB

    InfluxDB 是一个开源的没有外部依赖的时间序列数据库。适用于记录度量,事件及实时分析。

    2 引用 • 71 关注
  • GitBook

    GitBook 使您的团队可以轻松编写和维护高质量的文档。 分享知识,提高团队的工作效率,让用户满意。

    3 引用 • 8 回帖 • 4 关注
  • IDEA

    IDEA 全称 IntelliJ IDEA,是一款 Java 语言开发的集成环境,在业界被公认为最好的 Java 开发工具之一。IDEA 是 JetBrains 公司的产品,这家公司总部位于捷克共和国的首都布拉格,开发人员以严谨著称的东欧程序员为主。

    180 引用 • 400 回帖
  • Elasticsearch

    Elasticsearch 是一个基于 Lucene 的搜索服务器。它提供了一个分布式多用户能力的全文搜索引擎,基于 RESTful 接口。Elasticsearch 是用 Java 开发的,并作为 Apache 许可条款下的开放源码发布,是当前流行的企业级搜索引擎。设计用于云计算中,能够达到实时搜索,稳定,可靠,快速,安装使用方便。

    117 引用 • 99 回帖 • 211 关注
  • 尊园地产

    昆明尊园房地产经纪有限公司,即:Kunming Zunyuan Property Agency Company Limited(简称“尊园地产”)于 2007 年 6 月开始筹备,2007 年 8 月 18 日正式成立,注册资本 200 万元,公司性质为股份经纪有限公司,主营业务为:代租、代售、代办产权过户、办理银行按揭、担保、抵押、评估等。

    1 引用 • 22 回帖 • 762 关注
  • 数据库

    据说 99% 的性能瓶颈都在数据库。

    340 引用 • 708 回帖
  • 博客

    记录并分享人生的经历。

    273 引用 • 2388 回帖
  • SVN

    SVN 是 Subversion 的简称,是一个开放源代码的版本控制系统,相较于 RCS、CVS,它采用了分支管理系统,它的设计目标就是取代 CVS。

    29 引用 • 98 回帖 • 680 关注
  • AngularJS

    AngularJS 诞生于 2009 年,由 Misko Hevery 等人创建,后为 Google 所收购。是一款优秀的前端 JS 框架,已经被用于 Google 的多款产品当中。AngularJS 有着诸多特性,最为核心的是:MVC、模块化、自动化双向数据绑定、语义化标签、依赖注入等。2.0 版本后已经改名为 Angular。

    12 引用 • 50 回帖 • 474 关注
  • 音乐

    你听到信仰的声音了么?

    60 引用 • 511 回帖
  • 前端

    前端技术一般分为前端设计和前端开发,前端设计可以理解为网站的视觉设计,前端开发则是网站的前台代码实现,包括 HTML、CSS 以及 JavaScript 等。

    247 引用 • 1348 回帖
  • Tomcat

    Tomcat 最早是由 Sun Microsystems 开发的一个 Servlet 容器,在 1999 年被捐献给 ASF(Apache Software Foundation),隶属于 Jakarta 项目,现在已经独立为一个顶级项目。Tomcat 主要实现了 JavaEE 中的 Servlet、JSP 规范,同时也提供 HTTP 服务,是市场上非常流行的 Java Web 容器。

    162 引用 • 529 回帖
  • ZooKeeper

    ZooKeeper 是一个分布式的,开放源码的分布式应用程序协调服务,是 Google 的 Chubby 一个开源的实现,是 Hadoop 和 HBase 的重要组件。它是一个为分布式应用提供一致性服务的软件,提供的功能包括:配置维护、域名服务、分布式同步、组服务等。

    59 引用 • 29 回帖 • 5 关注
  • Ant-Design

    Ant Design 是服务于企业级产品的设计体系,基于确定和自然的设计价值观上的模块化解决方案,让设计者和开发者专注于更好的用户体验。

    17 引用 • 23 回帖
  • 强迫症

    强迫症(OCD)属于焦虑障碍的一种类型,是一组以强迫思维和强迫行为为主要临床表现的神经精神疾病,其特点为有意识的强迫和反强迫并存,一些毫无意义、甚至违背自己意愿的想法或冲动反反复复侵入患者的日常生活。

    15 引用 • 161 回帖
  • 链书

    链书(Chainbook)是 B3log 开源社区提供的区块链纸质书交易平台,通过 B3T 实现共享激励与价值链。可将你的闲置书籍上架到链书,我们共同构建这个全新的交易平台,让闲置书籍继续发挥它的价值。

    链书社

    链书目前已经下线,也许以后还有计划重制上线。

    14 引用 • 257 回帖