Spark mllib API- tree

本贴最后更新于 2915 天前,其中的信息可能已经天翻地覆

spark 在 tree 这个模块中提供了 DecisionTree、RandomForest、GradientBoostedTrees 三种算法。均属于分类/回归 树模型。
三种算法均可用于回归预测。其中决策树和决策森林可用于二元或多元分类,GBT 只能用于二元分类。

随机森林和 GBT 均属于组合模型,解决模型过拟合问题。

##DecisionTree 决策树

  • 类:pyspark.mllib.tree.DecisionTree
    决策树算法,训练决策树模型,提供分类和回归。

    • 方法:
      trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity='gini', maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0)
      训练用于分类的二叉树模型。

      • data:训练数据集,格式为 LabeledPoint 的 RDD,LabeledPoint 中的 Label 是整数。
      • numClasses:分类的个数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • impurity:纯度计算,支持“entropy”和“gini”(默认)
      • maxDepth:决策树的最大深度,默认 5
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
      • minInstancesPerNode:需要保证节点分割出的左右子节点的最少的样本数量达到这个值,默认 1
      • minInfoGain:当前节点的所有属性分割带来的信息增益都比这个值要小,默认 0.0
    • 方法:
      trainRegressor(data, categoricalFeaturesInfo, impurity='variance', maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0)
      训练用于回归的二叉树模型。

      • data:训练数据集,格式为 LabeledPoint 的 RDD,LabeledPoint 中的 Label 是实数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • impurity:纯度计算,支持“variance”(默认)
      • maxDepth:决策树的最大深度,默认 5
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
      • minInstancesPerNode:需要保证节点分割出的左右子节点的最少的样本数量达到这个值,默认 1
      • minInfoGain:当前节点的所有属性分割带来的信息增益都比这个值要小,默认 0.0

  • 类:pyspark.mllib.tree.DecisionTreeModel(java_model)

    • 方法: call(name, *a)
      调用 java 模型
    • 方法: depth()
      获取决策树的深度
    • 方法: load(sc, path)
      从指定 path 加载决策树模型
    • 方法: numNodes()
      获取决策树的节点数量,包括叶子节点
    • 方法: predict(x)
      预测一个或多个样本的 label 值
    • 方法: save(sc, path)
      将决策树模型持久化到指定 path
    • 方法: toDebugString()
      以 string 输出整个模型的信息

##RandomForest 随机森林

  • 类:pyspark.mllib.tree.RandomForest
    • 方法:
      trainClassifier(data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy='auto', impurity='gini', maxDepth=4, maxBins=32, seed=None)
      训练一个用于二元或多元分类的随机森林

      • data:训练数据集,格式为 LabeledPoint 的 RDD,LabeledPoint 中的 Label 是整数。
      • numClasses:分类的个数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • numTrees:随机森林中,树的数量。
      • featureSubsetStrategy:特征子集采样策略,支持"auto"(默认),"all","aqrt","log2","onethird"
      • impurity:纯度计算,支持“entropy”和“gini”(建议)
      • maxDepth:树的最大深度。
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
      • seed:用于引导和选择特征子集的随机种子。
    • 方法:
      trainRegressor(data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy='auto', impurity='variance', maxDepth=4, maxBins=32, seed=None)
      训练一个用于回归预测的随机森林

      • data:训练数据集,格式为 LabeledPoint 的 RDD,LabeledPoint 中的 Label 是实数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • numTrees:随机森林中,树的数量。
      • featureSubsetStrategy:特征子集采样策略,支持"auto"(默认),"all","aqrt","log2","onethird"
      • impurity:纯度计算,支持“variance”
      • maxDepth:树的最大深度。
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
      • seed:用于引导和选择特征子集的随机种子。

  • 类:pyspark.mllib.tree.RandomForestModel(java_model)

    • 方法: call(name, *a)
      调用 java 模型
    • 方法: load(sc, path)
      从指定 path 加载决策树模型
    • 方法: numTrees()
      获取随机森林中树的数量
    • 方法: predict(x)
      预测一个或多个样本的 label 值
    • 方法: save(sc, path)
      将决策树模型持久化到指定 path
    • 方法: toDebugString()
      以 string 输出整个模型的信息
    • 方法: totalNumNodes()
      获得森林中所有树的节点总和

##GradientBoostedTrees(GBT) 梯度提升决策树
这是一种模型组合的方法,利用简单模型的组合克服过拟合等问题。常用于推荐系统。

  • 类:pyspark.mllib.tree.GradientBoostedTrees
    • 方法:
      trainClassifier(data, categoricalFeaturesInfo, loss='logLoss', numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32)
      训练一个用于二元分类预测的梯度提升决策树模型。

      • data:训练数据集,格式为 LabeledPoint 的 RDD。label 必须为 0 或 1.
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • loss:损失函数,梯度提升计算时需要最小化的该函数。支持“logLoss” (默认), “leastSquaresError”, “leastAbsoluteError”
      • numIterations:提升的迭代次数,默认 100.
      • learningRate:学习率,取值(0,1]
      • maxDepth:树的最大深度
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
    • 方法:
      trainRegressor(data, categoricalFeaturesInfo, loss='leastSquaresError', numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32)
      训练一个用于回归预测的梯度提升决策树模型。

      • data:训练数据集,格式为 LabeledPoint 的 RDD。label 为实数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • loss:损失函数,梯度提升计算时需要最小化的该函数。支持“logLoss” (默认), “leastSquaresError”, “leastAbsoluteError”
      • numIterations:提升的迭代次数,默认 100.
      • learningRate:学习率,取值(0,1]
      • maxDepth:树的最大深度
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32

  • 类: pyspark.mllib.tree.GradientBoostedTreesModel(java_model)

    • 方法: call(name, *a)
      调用 java 模型
    • 方法: load(sc, path)
      从指定 path 加载决策树模型
    • 方法: numTrees()
      获取随机森林中树的数量
    • 方法: predict(x)
      预测一个或多个样本的 label 值
    • 方法: save(sc, path)
      将决策树模型持久化到指定 path
    • 方法: toDebugString()
      以 string 输出整个模型的信息
    • 方法: totalNumNodes()
      获得森林中所有树的节点总和
  • Spark

    Spark 是 UC Berkeley AMP lab 所开源的类 Hadoop MapReduce 的通用并行框架。Spark 拥有 Hadoop MapReduce 所具有的优点;但不同于 MapReduce 的是 Job 中间输出结果可以保存在内存中,从而不再需要读写 HDFS,因此 Spark 能更好地适用于数据挖掘与机器学习等需要迭代的 MapReduce 的算法。

    74 引用 • 46 回帖 • 548 关注
  • 数据挖掘
    17 引用 • 32 回帖 • 2 关注

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • Linux

    Linux 是一套免费使用和自由传播的类 Unix 操作系统,是一个基于 POSIX 和 Unix 的多用户、多任务、支持多线程和多 CPU 的操作系统。它能运行主要的 Unix 工具软件、应用程序和网络协议,并支持 32 位和 64 位硬件。Linux 继承了 Unix 以网络为核心的设计思想,是一个性能稳定的多用户网络操作系统。

    915 引用 • 931 回帖
  • JetBrains

    JetBrains 是一家捷克的软件开发公司,该公司位于捷克的布拉格,并在俄国的圣彼得堡及美国麻州波士顿都设有办公室,该公司最为人所熟知的产品是 Java 编程语言开发撰写时所用的集成开发环境:IntelliJ IDEA

    18 引用 • 54 回帖 • 1 关注
  • Q&A

    提问之前请先看《提问的智慧》,好的问题比好的答案更有价值。

    6546 引用 • 29416 回帖 • 244 关注
  • Windows

    Microsoft Windows 是美国微软公司研发的一套操作系统,它问世于 1985 年,起初仅仅是 Microsoft-DOS 模拟环境,后续的系统版本由于微软不断的更新升级,不但易用,也慢慢的成为家家户户人们最喜爱的操作系统。

    215 引用 • 462 回帖 • 1 关注
  • 自由行
    1 关注
  • Hadoop

    Hadoop 是由 Apache 基金会所开发的一个分布式系统基础架构。用户可以在不了解分布式底层细节的情况下,开发分布式程序。充分利用集群的威力进行高速运算和存储。

    82 引用 • 122 回帖 • 620 关注
  • Tomcat

    Tomcat 最早是由 Sun Microsystems 开发的一个 Servlet 容器,在 1999 年被捐献给 ASF(Apache Software Foundation),隶属于 Jakarta 项目,现在已经独立为一个顶级项目。Tomcat 主要实现了 JavaEE 中的 Servlet、JSP 规范,同时也提供 HTTP 服务,是市场上非常流行的 Java Web 容器。

    162 引用 • 529 回帖 • 3 关注
  • JRebel

    JRebel 是一款 Java 虚拟机插件,它使得 Java 程序员能在不进行重部署的情况下,即时看到代码的改变对一个应用程序带来的影响。

    26 引用 • 78 回帖 • 623 关注
  • NetBeans

    NetBeans 是一个始于 1997 年的 Xelfi 计划,本身是捷克布拉格查理大学的数学及物理学院的学生计划。此计划延伸而成立了一家公司进而发展这个商用版本的 NetBeans IDE,直到 1999 年 Sun 买下此公司。Sun 于次年(2000 年)六月将 NetBeans IDE 开源,直到现在 NetBeans 的社群依然持续增长。

    78 引用 • 102 回帖 • 643 关注
  • Ant-Design

    Ant Design 是服务于企业级产品的设计体系,基于确定和自然的设计价值观上的模块化解决方案,让设计者和开发者专注于更好的用户体验。

    17 引用 • 23 回帖 • 2 关注
  • 导航

    各种网址链接、内容导航。

    37 引用 • 168 回帖
  • Firefox

    Mozilla Firefox 中文俗称“火狐”(正式缩写为 Fx 或 fx,非正式缩写为 FF),是一个开源的网页浏览器,使用 Gecko 排版引擎,支持多种操作系统,如 Windows、OSX 及 Linux 等。

    7 引用 • 30 回帖 • 451 关注
  • Java

    Java 是一种可以撰写跨平台应用软件的面向对象的程序设计语言,是由 Sun Microsystems 公司于 1995 年 5 月推出的。Java 技术具有卓越的通用性、高效性、平台移植性和安全性。

    3168 引用 • 8207 回帖
  • LaTeX

    LaTeX(音译“拉泰赫”)是一种基于 ΤΕΧ 的排版系统,由美国计算机学家莱斯利·兰伯特(Leslie Lamport)在 20 世纪 80 年代初期开发,利用这种格式,即使使用者没有排版和程序设计的知识也可以充分发挥由 TeX 所提供的强大功能,能在几天,甚至几小时内生成很多具有书籍质量的印刷品。对于生成复杂表格和数学公式,这一点表现得尤为突出。因此它非常适用于生成高印刷质量的科技和数学类文档。

    9 引用 • 32 回帖 • 166 关注
  • OpenShift

    红帽提供的 PaaS 云,支持多种编程语言,为开发人员提供了更为灵活的框架、存储选择。

    14 引用 • 20 回帖 • 602 关注
  • PWA

    PWA(Progressive Web App)是 Google 在 2015 年提出、2016 年 6 月开始推广的项目。它结合了一系列现代 Web 技术,在网页应用中实现和原生应用相近的用户体验。

    14 引用 • 69 回帖 • 133 关注
  • Ubuntu

    Ubuntu(友帮拓、优般图、乌班图)是一个以桌面应用为主的 Linux 操作系统,其名称来自非洲南部祖鲁语或豪萨语的“ubuntu”一词,意思是“人性”、“我的存在是因为大家的存在”,是非洲传统的一种价值观,类似华人社会的“仁爱”思想。Ubuntu 的目标在于为一般用户提供一个最新的、同时又相当稳定的主要由自由软件构建而成的操作系统。

    123 引用 • 168 回帖
  • golang

    Go 语言是 Google 推出的一种全新的编程语言,可以在不损失应用程序性能的情况下降低代码的复杂性。谷歌首席软件工程师罗布派克(Rob Pike)说:我们之所以开发 Go,是因为过去 10 多年间软件开发的难度令人沮丧。Go 是谷歌 2009 发布的第二款编程语言。

    492 引用 • 1383 回帖 • 375 关注
  • Telegram

    Telegram 是一个非盈利性、基于云端的即时消息服务。它提供了支持各大操作系统平台的开源的客户端,也提供了很多强大的 APIs 给开发者创建自己的客户端和机器人。

    5 引用 • 35 回帖 • 1 关注
  • HBase

    HBase 是一个分布式的、面向列的开源数据库,该技术来源于 Fay Chang 所撰写的 Google 论文 “Bigtable:一个结构化数据的分布式存储系统”。就像 Bigtable 利用了 Google 文件系统所提供的分布式数据存储一样,HBase 在 Hadoop 之上提供了类似于 Bigtable 的能力。

    17 引用 • 6 回帖 • 45 关注
  • 友情链接

    确认过眼神后的灵魂连接,站在链在!

    24 引用 • 373 回帖 • 1 关注
  • 房星科技

    房星网,我们不和没有钱的程序员谈理想,我们要让程序员又有理想又有钱。我们有雄厚的房地产行业线下资源,遍布昆明全城的 100 家门店、四千地产经纪人是我们坚实的后盾。

    6 引用 • 141 回帖 • 559 关注
  • Angular

    AngularAngularJS 的新版本。

    26 引用 • 66 回帖 • 511 关注
  • 游戏

    沉迷游戏伤身,强撸灰飞烟灭。

    169 引用 • 799 回帖
  • B3log

    B3log 是一个开源组织,名字来源于“Bulletin Board Blog”缩写,目标是将独立博客与论坛结合,形成一种新的网络社区体验,详细请看 B3log 构思。目前 B3log 已经开源了多款产品:SymSoloVditor思源笔记

    1083 引用 • 3461 回帖 • 286 关注
  • Markdown

    Markdown 是一种轻量级标记语言,用户可使用纯文本编辑器来排版文档,最终通过 Markdown 引擎将文档转换为所需格式(比如 HTML、PDF 等)。

    163 引用 • 1450 回帖
  • Solidity

    Solidity 是一种智能合约高级语言,运行在 [以太坊] 虚拟机(EVM)之上。它的语法接近于 JavaScript,是一种面向对象的语言。

    3 引用 • 18 回帖 • 350 关注