Spark mllib API- tree

本贴最后更新于 3156 天前,其中的信息可能已经天翻地覆

spark 在 tree 这个模块中提供了 DecisionTree、RandomForest、GradientBoostedTrees 三种算法。均属于分类/回归 树模型。
三种算法均可用于回归预测。其中决策树和决策森林可用于二元或多元分类,GBT 只能用于二元分类。

随机森林和 GBT 均属于组合模型,解决模型过拟合问题。

##DecisionTree 决策树

  • 类:pyspark.mllib.tree.DecisionTree
    决策树算法,训练决策树模型,提供分类和回归。

    • 方法:
      trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity='gini', maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0)
      训练用于分类的二叉树模型。

      • data:训练数据集,格式为 LabeledPoint 的 RDD,LabeledPoint 中的 Label 是整数。
      • numClasses:分类的个数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • impurity:纯度计算,支持“entropy”和“gini”(默认)
      • maxDepth:决策树的最大深度,默认 5
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
      • minInstancesPerNode:需要保证节点分割出的左右子节点的最少的样本数量达到这个值,默认 1
      • minInfoGain:当前节点的所有属性分割带来的信息增益都比这个值要小,默认 0.0
    • 方法:
      trainRegressor(data, categoricalFeaturesInfo, impurity='variance', maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0)
      训练用于回归的二叉树模型。

      • data:训练数据集,格式为 LabeledPoint 的 RDD,LabeledPoint 中的 Label 是实数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • impurity:纯度计算,支持“variance”(默认)
      • maxDepth:决策树的最大深度,默认 5
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
      • minInstancesPerNode:需要保证节点分割出的左右子节点的最少的样本数量达到这个值,默认 1
      • minInfoGain:当前节点的所有属性分割带来的信息增益都比这个值要小,默认 0.0

  • 类:pyspark.mllib.tree.DecisionTreeModel(java_model)

    • 方法: call(name, *a)
      调用 java 模型
    • 方法: depth()
      获取决策树的深度
    • 方法: load(sc, path)
      从指定 path 加载决策树模型
    • 方法: numNodes()
      获取决策树的节点数量,包括叶子节点
    • 方法: predict(x)
      预测一个或多个样本的 label 值
    • 方法: save(sc, path)
      将决策树模型持久化到指定 path
    • 方法: toDebugString()
      以 string 输出整个模型的信息

##RandomForest 随机森林

  • 类:pyspark.mllib.tree.RandomForest
    • 方法:
      trainClassifier(data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy='auto', impurity='gini', maxDepth=4, maxBins=32, seed=None)
      训练一个用于二元或多元分类的随机森林

      • data:训练数据集,格式为 LabeledPoint 的 RDD,LabeledPoint 中的 Label 是整数。
      • numClasses:分类的个数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • numTrees:随机森林中,树的数量。
      • featureSubsetStrategy:特征子集采样策略,支持"auto"(默认),"all","aqrt","log2","onethird"
      • impurity:纯度计算,支持“entropy”和“gini”(建议)
      • maxDepth:树的最大深度。
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
      • seed:用于引导和选择特征子集的随机种子。
    • 方法:
      trainRegressor(data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy='auto', impurity='variance', maxDepth=4, maxBins=32, seed=None)
      训练一个用于回归预测的随机森林

      • data:训练数据集,格式为 LabeledPoint 的 RDD,LabeledPoint 中的 Label 是实数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • numTrees:随机森林中,树的数量。
      • featureSubsetStrategy:特征子集采样策略,支持"auto"(默认),"all","aqrt","log2","onethird"
      • impurity:纯度计算,支持“variance”
      • maxDepth:树的最大深度。
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
      • seed:用于引导和选择特征子集的随机种子。

  • 类:pyspark.mllib.tree.RandomForestModel(java_model)

    • 方法: call(name, *a)
      调用 java 模型
    • 方法: load(sc, path)
      从指定 path 加载决策树模型
    • 方法: numTrees()
      获取随机森林中树的数量
    • 方法: predict(x)
      预测一个或多个样本的 label 值
    • 方法: save(sc, path)
      将决策树模型持久化到指定 path
    • 方法: toDebugString()
      以 string 输出整个模型的信息
    • 方法: totalNumNodes()
      获得森林中所有树的节点总和

##GradientBoostedTrees(GBT) 梯度提升决策树
这是一种模型组合的方法,利用简单模型的组合克服过拟合等问题。常用于推荐系统。

  • 类:pyspark.mllib.tree.GradientBoostedTrees
    • 方法:
      trainClassifier(data, categoricalFeaturesInfo, loss='logLoss', numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32)
      训练一个用于二元分类预测的梯度提升决策树模型。

      • data:训练数据集,格式为 LabeledPoint 的 RDD。label 必须为 0 或 1.
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • loss:损失函数,梯度提升计算时需要最小化的该函数。支持“logLoss” (默认), “leastSquaresError”, “leastAbsoluteError”
      • numIterations:提升的迭代次数,默认 100.
      • learningRate:学习率,取值(0,1]
      • maxDepth:树的最大深度
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
    • 方法:
      trainRegressor(data, categoricalFeaturesInfo, loss='leastSquaresError', numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32)
      训练一个用于回归预测的梯度提升决策树模型。

      • data:训练数据集,格式为 LabeledPoint 的 RDD。label 为实数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • loss:损失函数,梯度提升计算时需要最小化的该函数。支持“logLoss” (默认), “leastSquaresError”, “leastAbsoluteError”
      • numIterations:提升的迭代次数,默认 100.
      • learningRate:学习率,取值(0,1]
      • maxDepth:树的最大深度
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32

  • 类: pyspark.mllib.tree.GradientBoostedTreesModel(java_model)

    • 方法: call(name, *a)
      调用 java 模型
    • 方法: load(sc, path)
      从指定 path 加载决策树模型
    • 方法: numTrees()
      获取随机森林中树的数量
    • 方法: predict(x)
      预测一个或多个样本的 label 值
    • 方法: save(sc, path)
      将决策树模型持久化到指定 path
    • 方法: toDebugString()
      以 string 输出整个模型的信息
    • 方法: totalNumNodes()
      获得森林中所有树的节点总和
  • Spark

    Spark 是 UC Berkeley AMP lab 所开源的类 Hadoop MapReduce 的通用并行框架。Spark 拥有 Hadoop MapReduce 所具有的优点;但不同于 MapReduce 的是 Job 中间输出结果可以保存在内存中,从而不再需要读写 HDFS,因此 Spark 能更好地适用于数据挖掘与机器学习等需要迭代的 MapReduce 的算法。

    74 引用 • 46 回帖 • 559 关注
  • 数据挖掘
    17 引用 • 32 回帖 • 3 关注

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • 尊园地产

    昆明尊园房地产经纪有限公司,即:Kunming Zunyuan Property Agency Company Limited(简称“尊园地产”)于 2007 年 6 月开始筹备,2007 年 8 月 18 日正式成立,注册资本 200 万元,公司性质为股份经纪有限公司,主营业务为:代租、代售、代办产权过户、办理银行按揭、担保、抵押、评估等。

    1 引用 • 22 回帖 • 772 关注
  • Logseq

    Logseq 是一个隐私优先、开源的知识库工具。

    Logseq is a joyful, open-source outliner that works on top of local plain-text Markdown and Org-mode files. Use it to write, organize and share your thoughts, keep your to-do list, and build your own digital garden.

    6 引用 • 63 回帖 • 5 关注
  • 心情

    心是产生任何想法的源泉,心本体会陷入到对自己本体不能理解的状态中,因为心能产生任何想法,不能分出对错,不能分出自己。

    59 引用 • 369 回帖
  • CentOS

    CentOS(Community Enterprise Operating System)是 Linux 发行版之一,它是来自于 Red Hat Enterprise Linux 依照开放源代码规定释出的源代码所编译而成。由于出自同样的源代码,因此有些要求高度稳定的服务器以 CentOS 替代商业版的 Red Hat Enterprise Linux 使用。两者的不同在于 CentOS 并不包含封闭源代码软件。

    238 引用 • 224 回帖
  • 快应用

    快应用 是基于手机硬件平台的新型应用形态;标准是由主流手机厂商组成的快应用联盟联合制定;快应用标准的诞生将在研发接口、能力接入、开发者服务等层面建设标准平台;以平台化的生态模式对个人开发者和企业开发者全品类开放。

    15 引用 • 127 回帖
  • abitmean

    有点意思就行了

    27 关注
  • 大数据

    大数据(big data)是指无法在一定时间范围内用常规软件工具进行捕捉、管理和处理的数据集合,是需要新处理模式才能具有更强的决策力、洞察发现力和流程优化能力的海量、高增长率和多样化的信息资产。

    93 引用 • 113 回帖
  • 服务器

    服务器,也称伺服器,是提供计算服务的设备。由于服务器需要响应服务请求,并进行处理,因此一般来说服务器应具备承担服务并且保障服务的能力。

    125 引用 • 588 回帖
  • GitBook

    GitBook 使您的团队可以轻松编写和维护高质量的文档。 分享知识,提高团队的工作效率,让用户满意。

    3 引用 • 8 回帖
  • 又拍云

    又拍云是国内领先的 CDN 服务提供商,国家工信部认证通过的“可信云”,乌云众测平台认证的“安全云”,为移动时代的创业者提供新一代的 CDN 加速服务。

    21 引用 • 37 回帖 • 548 关注
  • Google

    Google(Google Inc.,NASDAQ:GOOG)是一家美国上市公司(公有股份公司),于 1998 年 9 月 7 日以私有股份公司的形式创立,设计并管理一个互联网搜索引擎。Google 公司的总部称作“Googleplex”,它位于加利福尼亚山景城。Google 目前被公认为是全球规模最大的搜索引擎,它提供了简单易用的免费服务。不作恶(Don't be evil)是谷歌公司的一项非正式的公司口号。

    49 引用 • 192 回帖
  • 智能合约

    智能合约(Smart contract)是一种旨在以信息化方式传播、验证或执行合同的计算机协议。智能合约允许在没有第三方的情况下进行可信交易,这些交易可追踪且不可逆转。智能合约概念于 1994 年由 Nick Szabo 首次提出。

    1 引用 • 11 回帖 • 2 关注
  • 倾城之链
    23 引用 • 66 回帖 • 138 关注
  • 人工智能

    人工智能(Artificial Intelligence)是研究、开发用于模拟、延伸和扩展人的智能的理论、方法、技术及应用系统的一门技术科学。

    135 引用 • 190 回帖
  • Android

    Android 是一种以 Linux 为基础的开放源码操作系统,主要使用于便携设备。2005 年由 Google 收购注资,并拉拢多家制造商组成开放手机联盟开发改良,逐渐扩展到到平板电脑及其他领域上。

    334 引用 • 323 回帖 • 4 关注
  • 域名

    域名(Domain Name),简称域名、网域,是由一串用点分隔的名字组成的 Internet 上某一台计算机或计算机组的名称,用于在数据传输时标识计算机的电子方位(有时也指地理位置)。

    43 引用 • 208 回帖
  • 脑图

    脑图又叫思维导图,是表达发散性思维的有效图形思维工具 ,它简单却又很有效,是一种实用性的思维工具。

    30 引用 • 96 回帖 • 1 关注
  • 互联网

    互联网(Internet),又称网际网络,或音译因特网、英特网。互联网始于 1969 年美国的阿帕网,是网络与网络之间所串连成的庞大网络,这些网络以一组通用的协议相连,形成逻辑上的单一巨大国际网络。

    98 引用 • 344 回帖
  • wolai

    我来 wolai:不仅仅是未来的云端笔记!

    2 引用 • 14 回帖
  • HHKB

    HHKB 是富士通的 Happy Hacking 系列电容键盘。电容键盘即无接点静电电容式键盘(Capacitive Keyboard)。

    5 引用 • 74 回帖 • 478 关注
  • Sublime

    Sublime Text 是一款可以用来写代码、写文章的文本编辑器。支持代码高亮、自动完成,还支持通过插件进行扩展。

    10 引用 • 5 回帖
  • Lute

    Lute 是一款结构化的 Markdown 引擎,支持 Go 和 JavaScript。

    26 引用 • 196 回帖 • 17 关注
  • 资讯

    资讯是用户因为及时地获得它并利用它而能够在相对短的时间内给自己带来价值的信息,资讯有时效性和地域性。

    55 引用 • 85 回帖
  • NGINX

    NGINX 是一个高性能的 HTTP 和反向代理服务器,也是一个 IMAP/POP3/SMTP 代理服务器。 NGINX 是由 Igor Sysoev 为俄罗斯访问量第二的 Rambler.ru 站点开发的,第一个公开版本 0.1.0 发布于 2004 年 10 月 4 日。

    313 引用 • 547 回帖
  • 禅道

    禅道是一款国产的开源项目管理软件,她的核心管理思想基于敏捷方法 scrum,内置了产品管理和项目管理,同时又根据国内研发现状补充了测试管理、计划管理、发布管理、文档管理、事务管理等功能,在一个软件中就可以将软件研发中的需求、任务、bug、用例、计划、发布等要素有序的跟踪管理起来,完整地覆盖了项目管理的核心流程。

    5 引用 • 15 回帖 • 102 关注
  • BookxNote

    BookxNote 是一款全新的电子书学习工具,助力您的学习与思考,让您的大脑更高效的记忆。

    笔记整理交给我,一心只读圣贤书。

    1 引用 • 1 回帖
  • Wide

    Wide 是一款基于 Web 的 Go 语言 IDE。通过浏览器就可以进行 Go 开发,并有代码自动完成、查看表达式、编译反馈、Lint、实时结果输出等功能。

    欢迎访问我们运维的实例: https://wide.b3log.org

    30 引用 • 218 回帖 • 635 关注