机器学习 (二): 决策树算法

本贴最后更新于 2290 天前,其中的信息可能已经沧海桑田

概述

决策树算法是比较常规的多分类算法,常见的算法实现有 ID3,C45,CART 等算法,这里我们会采用 ID3 算法。

原理介绍

给定的数据集如下:

年龄是否大于 30 身高是否大于 170 收入情况(-1,0,1 分别表示低中高) 是否交往
1 1 1 0
1 1 0 0
1 1 -1 0
1 0 1 0
0 0 1 0
0 1 -1 0
0 1 0 0
0 1 1 1
问题来了,如何如构造一个决策树,来自动帮助我们做辅助决策.

代码实例

在实际构造程序之前我们先来各个击破遇到的问题:

问题一:选哪一个特征做当前的决策树的根节点

在有 N 多特征的时候,我们一般一定会选择区分度高的特征作为根节点,那么怎么去定义区分度呢。
这时候我们引入概率和信息熵的概率,一个简单的例子是,一个信息如果发生的概率为 1,那么这个信息对你来说它的价值是 0,就好比我告诉你明天地球还会公转一样。但是如果一个事情发生的概率为 0,但是我告诉你发生了,这个对你来说绝对是非常有用的信息。
发生的概率这里,我们会去统计每一个事件发生的次数/总次数。比如上面的数据集交往的概率为 1/8,不交往的概率为 7/8,那么信息是否有用信息熵来表示公式为:此处输入图片的描述.那么计算对各种情况的熵求和此处输入图片的描述。一个随机变量的熵越大,其不确定性就越大,(不管是先验熵,后验熵,还是条件熵都是这样的)正确的估计其值的可能性就越小,越是不确定的随机变量越是需要更大的信息量来确定其值。 那么我们可以先计算一下上面数据集的信息熵,代码如下。

public static Object[][] createDataSet() { Object[][] dataSet = {{1, 1, 1, 0}, {1, 1, 0, 0}, {1, 1, -1, 0}, {1, 0, 1, 0}, {0, 0, 1, 0}, {0, 1, -1, 0}, {0, 1, 0, 0}, {0, 1, 1, 1}}; return dataSet; } /** * @param dataSet 采用二维数组表示,最后一位是label * @return */ public static double calcEntropy(Object[][] dataSet) { double result = 0; int columnSize = dataSet[0].length; int rows = dataSet.length; Map<Object, AtomicInteger> maps = Maps.newHashMap(); for (int i = 0; i < rows; i++) { Object label = dataSet[i][columnSize - 1]; if (!maps.containsKey(label)) { maps.put(label, new AtomicInteger(0)); } maps.get(label).incrementAndGet(); } for (Map.Entry<Object, AtomicInteger> entry : maps.entrySet()) { double prop = 1.0 * entry.getValue().get() / rows; result += -1 * prop * (Math.log(prop) / Math.log(2)); } return result; }

上述数据集的信息熵依据公式计算结果为:0.5435644431995964
继续如何选择特性的问题,这个我们以年龄是否大于 30 为一个例子来说明如何选择,

年龄是否大于 30 交往 不交往
1 0/4 4/4
0 1/4 3/4

计算年龄特征的信息熵为 -0log0-1log1 + (-1/4log1/4- 3/4log3/4) = 1.5
而上面我们计算出了基础的信息熵为 0.5435644431995964,相减得到-0.9564355568004036
我们可以采用类似的方法计算其他特征,相减后数值最小的就是目前最适合的特征。

/** * 对数据集划分,按照指定的特性列的特征值做划分依据 * * @param dataSet * @param axis * @param value * @return */ public static Object[][] splitDataSet(Object[][] dataSet, int axis, Object value) { int columnSize = dataSet[0].length; int rows = dataSet.length; List<Object[]> lists = Lists.newArrayList(); for (int i = 0; i < rows; i++) { if (!dataSet[i][axis].equals(value)) { continue; } Object[] row = new Object[columnSize-1]; int index = 0; for (int j = 0; j < columnSize; j++) { if(j==axis){ continue; } row[index] = dataSet[i][index]; index++; } lists.add(row); } Object[][] retDataSet = new Object[lists.size()][]; int i = 0; for (Object[] row : lists) { retDataSet[i++] = row; } return retDataSet; } public static int chooseBestFeature(Object[][] dataSet) { int bestFeature = -1; double baseEntropy = calcEntropy(dataSet); int rows = dataSet.length; int featureSize = dataSet[0].length - 1; double infoGain = -1; double bestInfoGain = -1; for (int i = 0; i < featureSize; i++) { //处理每一列feature double newEntropy = 0.0; Map<Object, AtomicInteger> maps = Maps.newHashMap(); for (int j = 0; j < rows; j++) {//计算当前列的唯一数值 if (!maps.containsKey(dataSet[j][i])) { maps.put(dataSet[j][i], new AtomicInteger(0)); } maps.get(dataSet[j][i]).incrementAndGet(); } for (Object featVa : maps.keySet()) { Object[][] subDataSet = splitDataSet(dataSet, i, featVa); double prop = maps.get(featVa).get() * 1.0 / rows; newEntropy += prop * calcEntropy(subDataSet); } infoGain = baseEntropy - newEntropy; if (infoGain > bestInfoGain) { bestFeature = i; bestInfoGain = infoGain; } } return bestFeature; }

问题二:没有可选的特征

当我们不断的分裂数据集,分裂到没有特征的时候,怎么做结果呢?这里我们使用额简单的投票机制,出现次数最多的就是结果。

//投票 public static String vote(Object[][] dataSet) { Map<Object,AtomicInteger> maps = Maps.newHashMap(); int rows = dataSet.length; for(int i=0;i<rows;i++){ if(!maps.containsKey(dataSet[i][0])){ maps.put(dataSet[i][0],new AtomicInteger(0)); } maps.get(dataSet[i][0]).incrementAndGet(); } Object maxValue =null; int max = -1; //使用出现次数最多的那个标签 for(Map.Entry<Object,AtomicInteger> entry:maps.entrySet()){ if(entry.getValue().get()>max){ max = entry.getValue().get(); maxValue = entry.getKey(); } } return maxValue.toString(); }

最终程序(可以直接 run)

import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import lombok.Data; /** * Created by hadoop on 17/10/31. * 决策树算法 */ public class Trees { public static Object[][] createDataSet() { Object[][] dataSet = {{1, 1, 1, 0}, {1, 1, 0, 0}, {1, 1, -1, 0}, {1, 0, 1, 0}, {0, 0, 1, 0}, {0, 1, -1, 0}, {0, 1, 0, 0}, {0, 1, 1, 1}}; return dataSet; } /** * @param dataSet 采用二维数组表示,最后一位是label * @return */ public static double calcEntropy(Object[][] dataSet) { double result = 0; int columnSize = dataSet[0].length; int rows = dataSet.length; Map<Object, AtomicInteger> maps = Maps.newHashMap(); for (int i = 0; i < rows; i++) { Object label = dataSet[i][columnSize - 1]; if (!maps.containsKey(label)) { maps.put(label, new AtomicInteger(0)); } maps.get(label).incrementAndGet(); } for (Map.Entry<Object, AtomicInteger> entry : maps.entrySet()) { double prop = 1.0 * entry.getValue().get() / rows; result += -1 * prop * (Math.log(prop) / Math.log(2)); } return result; } /** * 对数据集划分,按照指定的特性列的特征值做划分依据 * * @param dataSet * @param axis * @param value * @return */ public static Object[][] splitDataSet(Object[][] dataSet, int axis, Object value) { int columnSize = dataSet[0].length; int rows = dataSet.length; List<Object[]> lists = Lists.newArrayList(); for (int i = 0; i < rows; i++) { if (!dataSet[i][axis].equals(value)) { continue; } Object[] row = new Object[columnSize-1]; int index = 0; for (int j = 0; j < columnSize; j++) { if(j==axis){ continue; } row[index] = dataSet[i][index]; index++; } lists.add(row); } Object[][] retDataSet = new Object[lists.size()][]; int i = 0; for (Object[] row : lists) { retDataSet[i++] = row; } return retDataSet; } public static int chooseBestFeature(Object[][] dataSet) { int bestFeature = -1; double baseEntropy = calcEntropy(dataSet); int rows = dataSet.length; int featureSize = dataSet[0].length - 1; double infoGain = -1; double bestInfoGain = -1; for (int i = 0; i < featureSize; i++) { //处理每一列feature double newEntropy = 0.0; Map<Object, AtomicInteger> maps = Maps.newHashMap(); for (int j = 0; j < rows; j++) {//计算当前列的唯一数值 if (!maps.containsKey(dataSet[j][i])) { maps.put(dataSet[j][i], new AtomicInteger(0)); } maps.get(dataSet[j][i]).incrementAndGet(); } for (Object featVa : maps.keySet()) { Object[][] subDataSet = splitDataSet(dataSet, i, featVa); double prop = maps.get(featVa).get() * 1.0 / rows; newEntropy += prop * calcEntropy(subDataSet); } infoGain = baseEntropy - newEntropy; if (infoGain > bestInfoGain) { bestFeature = i; bestInfoGain = infoGain; } } return bestFeature; } public static String vote(Object[][] dataSet) { Map<Object,AtomicInteger> maps = Maps.newHashMap(); int rows = dataSet.length; for(int i=0;i<rows;i++){ if(!maps.containsKey(dataSet[i][0])){ maps.put(dataSet[i][0],new AtomicInteger(0)); } maps.get(dataSet[i][0]).incrementAndGet(); } Object maxValue =null; int max = -1; for(Map.Entry<Object,AtomicInteger> entry:maps.entrySet()){ if(entry.getValue().get()>max){ max = entry.getValue().get(); maxValue = entry.getKey(); } } return maxValue.toString(); } public static Object createDecisionTree(Object[][] dataSet, List<String> labels) { int rows = dataSet.length; int labelIndex = dataSet[0].length - 1; Set<Object> sets = Sets.newHashSet(); for (int i = 0; i < rows; i++) { sets.add(dataSet[i][labelIndex]); } if (sets.size() == rows) { return dataSet[0][labelIndex]; } if (dataSet[0].length == 1) { return vote(dataSet); } int bestFeature = chooseBestFeature(dataSet); List<String> copyLabels = Lists.newArrayList(labels); String bestFeatureLabel = copyLabels.get(bestFeature); copyLabels.remove(bestFeature); Set<Object> bestFeatureValues = Sets.newHashSet(); for (int i = 0; i < rows; i++) { bestFeatureValues.add(dataSet[i][bestFeature]); } DecisionTree decisionTree = new DecisionTree(); decisionTree.setAttributeName(bestFeatureLabel); for (Object value : bestFeatureValues) { List<String> subFeatureLabels = Lists.newArrayList(copyLabels); decisionTree.children.put(value.toString(), createDecisionTree(splitDataSet(dataSet, bestFeature, value), subFeatureLabels)); } return decisionTree; } public static void main(String[] args) { System.out.println(calcEntropy(createDataSet())); Object result = createDecisionTree(createDataSet(),Lists.newArrayList("年龄是否大于30","身高是否大于170","收入情况","yes/no")); System.out.println(result); } @Data static class DecisionTree{ private String attributeName; public Map<String, Object> children = Maps.newHashMap(); } }
  • 决策树
    2 引用 • 1 回帖
  • 机器学习

    机器学习(Machine Learning)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。

    83 引用 • 37 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...