Spark mllib API- tree

本贴最后更新于 3251 天前,其中的信息可能已经天翻地覆

spark 在 tree 这个模块中提供了 DecisionTree、RandomForest、GradientBoostedTrees 三种算法。均属于分类/回归 树模型。
三种算法均可用于回归预测。其中决策树和决策森林可用于二元或多元分类,GBT 只能用于二元分类。

随机森林和 GBT 均属于组合模型,解决模型过拟合问题。

##DecisionTree 决策树

  • 类:pyspark.mllib.tree.DecisionTree
    决策树算法,训练决策树模型,提供分类和回归。

    • 方法:
      trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity='gini', maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0)
      训练用于分类的二叉树模型。

      • data:训练数据集,格式为 LabeledPoint 的 RDD,LabeledPoint 中的 Label 是整数。
      • numClasses:分类的个数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • impurity:纯度计算,支持“entropy”和“gini”(默认)
      • maxDepth:决策树的最大深度,默认 5
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
      • minInstancesPerNode:需要保证节点分割出的左右子节点的最少的样本数量达到这个值,默认 1
      • minInfoGain:当前节点的所有属性分割带来的信息增益都比这个值要小,默认 0.0
    • 方法:
      trainRegressor(data, categoricalFeaturesInfo, impurity='variance', maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0)
      训练用于回归的二叉树模型。

      • data:训练数据集,格式为 LabeledPoint 的 RDD,LabeledPoint 中的 Label 是实数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • impurity:纯度计算,支持“variance”(默认)
      • maxDepth:决策树的最大深度,默认 5
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
      • minInstancesPerNode:需要保证节点分割出的左右子节点的最少的样本数量达到这个值,默认 1
      • minInfoGain:当前节点的所有属性分割带来的信息增益都比这个值要小,默认 0.0

  • 类:pyspark.mllib.tree.DecisionTreeModel(java_model)

    • 方法: call(name, *a)
      调用 java 模型
    • 方法: depth()
      获取决策树的深度
    • 方法: load(sc, path)
      从指定 path 加载决策树模型
    • 方法: numNodes()
      获取决策树的节点数量,包括叶子节点
    • 方法: predict(x)
      预测一个或多个样本的 label 值
    • 方法: save(sc, path)
      将决策树模型持久化到指定 path
    • 方法: toDebugString()
      以 string 输出整个模型的信息

##RandomForest 随机森林

  • 类:pyspark.mllib.tree.RandomForest
    • 方法:
      trainClassifier(data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy='auto', impurity='gini', maxDepth=4, maxBins=32, seed=None)
      训练一个用于二元或多元分类的随机森林

      • data:训练数据集,格式为 LabeledPoint 的 RDD,LabeledPoint 中的 Label 是整数。
      • numClasses:分类的个数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • numTrees:随机森林中,树的数量。
      • featureSubsetStrategy:特征子集采样策略,支持"auto"(默认),"all","aqrt","log2","onethird"
      • impurity:纯度计算,支持“entropy”和“gini”(建议)
      • maxDepth:树的最大深度。
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
      • seed:用于引导和选择特征子集的随机种子。
    • 方法:
      trainRegressor(data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy='auto', impurity='variance', maxDepth=4, maxBins=32, seed=None)
      训练一个用于回归预测的随机森林

      • data:训练数据集,格式为 LabeledPoint 的 RDD,LabeledPoint 中的 Label 是实数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • numTrees:随机森林中,树的数量。
      • featureSubsetStrategy:特征子集采样策略,支持"auto"(默认),"all","aqrt","log2","onethird"
      • impurity:纯度计算,支持“variance”
      • maxDepth:树的最大深度。
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
      • seed:用于引导和选择特征子集的随机种子。

  • 类:pyspark.mllib.tree.RandomForestModel(java_model)

    • 方法: call(name, *a)
      调用 java 模型
    • 方法: load(sc, path)
      从指定 path 加载决策树模型
    • 方法: numTrees()
      获取随机森林中树的数量
    • 方法: predict(x)
      预测一个或多个样本的 label 值
    • 方法: save(sc, path)
      将决策树模型持久化到指定 path
    • 方法: toDebugString()
      以 string 输出整个模型的信息
    • 方法: totalNumNodes()
      获得森林中所有树的节点总和

##GradientBoostedTrees(GBT) 梯度提升决策树
这是一种模型组合的方法,利用简单模型的组合克服过拟合等问题。常用于推荐系统。

  • 类:pyspark.mllib.tree.GradientBoostedTrees
    • 方法:
      trainClassifier(data, categoricalFeaturesInfo, loss='logLoss', numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32)
      训练一个用于二元分类预测的梯度提升决策树模型。

      • data:训练数据集,格式为 LabeledPoint 的 RDD。label 必须为 0 或 1.
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • loss:损失函数,梯度提升计算时需要最小化的该函数。支持“logLoss” (默认), “leastSquaresError”, “leastAbsoluteError”
      • numIterations:提升的迭代次数,默认 100.
      • learningRate:学习率,取值(0,1]
      • maxDepth:树的最大深度
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32
    • 方法:
      trainRegressor(data, categoricalFeaturesInfo, loss='leastSquaresError', numIterations=100, learningRate=0.1, maxDepth=3, maxBins=32)
      训练一个用于回归预测的梯度提升决策树模型。

      • data:训练数据集,格式为 LabeledPoint 的 RDD。label 为实数。
      • categoricalFeaturesInfo:向量中为分类属性的索引表。任务没有出现在该列表中的特征将会以连续值处理。{n:k}表示第 n 个特征,是 0-k 的分类属性。
      • loss:损失函数,梯度提升计算时需要最小化的该函数。支持“logLoss” (默认), “leastSquaresError”, “leastAbsoluteError”
      • numIterations:提升的迭代次数,默认 100.
      • learningRate:学习率,取值(0,1]
      • maxDepth:树的最大深度
      • maxBins:每个特征分裂时,最大划分(桶)数量,默认 32

  • 类: pyspark.mllib.tree.GradientBoostedTreesModel(java_model)

    • 方法: call(name, *a)
      调用 java 模型
    • 方法: load(sc, path)
      从指定 path 加载决策树模型
    • 方法: numTrees()
      获取随机森林中树的数量
    • 方法: predict(x)
      预测一个或多个样本的 label 值
    • 方法: save(sc, path)
      将决策树模型持久化到指定 path
    • 方法: toDebugString()
      以 string 输出整个模型的信息
    • 方法: totalNumNodes()
      获得森林中所有树的节点总和
  • Spark

    Spark 是 UC Berkeley AMP lab 所开源的类 Hadoop MapReduce 的通用并行框架。Spark 拥有 Hadoop MapReduce 所具有的优点;但不同于 MapReduce 的是 Job 中间输出结果可以保存在内存中,从而不再需要读写 HDFS,因此 Spark 能更好地适用于数据挖掘与机器学习等需要迭代的 MapReduce 的算法。

    74 引用 • 46 回帖 • 567 关注
  • 数据挖掘
    17 引用 • 32 回帖 • 3 关注

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • WiFiDog

    WiFiDog 是一套开源的无线热点认证管理工具,主要功能包括:位置相关的内容递送;用户认证和授权;集中式网络监控。

    1 引用 • 7 回帖 • 608 关注
  • danl
    164 关注
  • Electron

    Electron 基于 Chromium 和 Node.js,让你可以使用 HTML、CSS 和 JavaScript 构建应用。它是一个由 GitHub 及众多贡献者组成的活跃社区共同维护的开源项目,兼容 Mac、Windows 和 Linux,它构建的应用可在这三个操作系统上面运行。

    15 引用 • 136 回帖 • 8 关注
  • 机器学习

    机器学习(Machine Learning)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。

    83 引用 • 37 回帖
  • RYMCU

    RYMCU 致力于打造一个即严谨又活泼、专业又不失有趣,为数百万人服务的开源嵌入式知识学习交流平台。

    4 引用 • 6 回帖 • 53 关注
  • Sublime

    Sublime Text 是一款可以用来写代码、写文章的文本编辑器。支持代码高亮、自动完成,还支持通过插件进行扩展。

    10 引用 • 5 回帖
  • V2EX

    V2EX 是创意工作者们的社区。这里目前汇聚了超过 400,000 名主要来自互联网行业、游戏行业和媒体行业的创意工作者。V2EX 希望能够成为创意工作者们的生活和事业的一部分。

    16 引用 • 236 回帖 • 272 关注
  • Office

    Office 现已更名为 Microsoft 365. Microsoft 365 将高级 Office 应用(如 Word、Excel 和 PowerPoint)与 1 TB 的 OneDrive 云存储空间、高级安全性等结合在一起,可帮助你在任何设备上完成操作。

    5 引用 • 34 回帖
  • sts
    2 引用 • 2 回帖 • 225 关注
  • 安全

    安全永远都不是一个小问题。

    203 引用 • 818 回帖
  • 开源中国

    开源中国是目前中国最大的开源技术社区。传播开源的理念,推广开源项目,为 IT 开发者提供了一个发现、使用、并交流开源技术的平台。目前开源中国社区已收录超过两万款开源软件。

    7 引用 • 86 回帖 • 1 关注
  • 大数据

    大数据(big data)是指无法在一定时间范围内用常规软件工具进行捕捉、管理和处理的数据集合,是需要新处理模式才能具有更强的决策力、洞察发现力和流程优化能力的海量、高增长率和多样化的信息资产。

    93 引用 • 113 回帖
  • Sandbox

    如果帖子标签含有 Sandbox ,则该帖子会被视为“测试帖”,主要用于测试社区功能,排查 bug 等,该标签下内容不定期进行清理。

    427 引用 • 1250 回帖 • 597 关注
  • Log4j

    Log4j 是 Apache 开源的一款使用广泛的 Java 日志组件。

    20 引用 • 18 回帖 • 32 关注
  • SOHO

    为成为自由职业者在家办公而努力吧!

    7 引用 • 55 回帖 • 4 关注
  • MyBatis

    MyBatis 本是 Apache 软件基金会 的一个开源项目 iBatis,2010 年这个项目由 Apache 软件基金会迁移到了 google code,并且改名为 MyBatis ,2013 年 11 月再次迁移到了 GitHub。

    173 引用 • 414 回帖 • 368 关注
  • OneDrive
    2 引用 • 3 关注
  • SendCloud

    SendCloud 由搜狐武汉研发中心孵化的项目,是致力于为开发者提供高质量的触发邮件服务的云端邮件发送平台,为开发者提供便利的 API 接口来调用服务,让邮件准确迅速到达用户收件箱并获得强大的追踪数据。

    2 引用 • 8 回帖 • 485 关注
  • Hadoop

    Hadoop 是由 Apache 基金会所开发的一个分布式系统基础架构。用户可以在不了解分布式底层细节的情况下,开发分布式程序。充分利用集群的威力进行高速运算和存储。

    87 引用 • 122 回帖 • 622 关注
  • JavaScript

    JavaScript 一种动态类型、弱类型、基于原型的直译式脚本语言,内置支持类型。它的解释器被称为 JavaScript 引擎,为浏览器的一部分,广泛用于客户端的脚本语言,最早是在 HTML 网页上使用,用来给 HTML 网页增加动态功能。

    729 引用 • 1278 回帖
  • FreeMarker

    FreeMarker 是一款好用且功能强大的 Java 模版引擎。

    23 引用 • 20 回帖 • 459 关注
  • AWS
    11 引用 • 28 回帖 • 10 关注
  • Windows

    Microsoft Windows 是美国微软公司研发的一套操作系统,它问世于 1985 年,起初仅仅是 Microsoft-DOS 模拟环境,后续的系统版本由于微软不断的更新升级,不但易用,也慢慢的成为家家户户人们最喜爱的操作系统。

    226 引用 • 476 回帖
  • 锤子科技

    锤子科技(Smartisan)成立于 2012 年 5 月,是一家制造移动互联网终端设备的公司,公司的使命是用完美主义的工匠精神,打造用户体验一流的数码消费类产品(智能手机为主),改善人们的生活质量。

    4 引用 • 31 回帖 • 6 关注
  • 自由行
    2 关注
  • FFmpeg

    FFmpeg 是一套可以用来记录、转换数字音频、视频,并能将其转化为流的开源计算机程序。

    23 引用 • 32 回帖 • 2 关注
  • 爬虫

    网络爬虫(Spider、Crawler),是一种按照一定的规则,自动地抓取万维网信息的程序。

    106 引用 • 275 回帖