机器学习 (二): 决策树算法

本贴最后更新于 1945 天前,其中的信息可能已经沧海桑田

概述

决策树算法是比较常规的多分类算法,常见的算法实现有 ID3,C45,CART 等算法,这里我们会采用 ID3 算法。

原理介绍

给定的数据集如下:

年龄是否大于 30 身高是否大于 170 收入情况(-1,0,1 分别表示低中高) 是否交往
1 1 1 0
1 1 0 0
1 1 -1 0
1 0 1 0
0 0 1 0
0 1 -1 0
0 1 0 0
0 1 1 1
问题来了,如何如构造一个决策树,来自动帮助我们做辅助决策.

代码实例

在实际构造程序之前我们先来各个击破遇到的问题:

问题一:选哪一个特征做当前的决策树的根节点

在有 N 多特征的时候,我们一般一定会选择区分度高的特征作为根节点,那么怎么去定义区分度呢。
这时候我们引入概率和信息熵的概率,一个简单的例子是,一个信息如果发生的概率为 1,那么这个信息对你来说它的价值是 0,就好比我告诉你明天地球还会公转一样。但是如果一个事情发生的概率为 0,但是我告诉你发生了,这个对你来说绝对是非常有用的信息。
发生的概率这里,我们会去统计每一个事件发生的次数/总次数。比如上面的数据集交往的概率为 1/8,不交往的概率为 7/8,那么信息是否有用信息熵来表示公式为:此处输入图片的描述.那么计算对各种情况的熵求和此处输入图片的描述。一个随机变量的熵越大,其不确定性就越大,(不管是先验熵,后验熵,还是条件熵都是这样的)正确的估计其值的可能性就越小,越是不确定的随机变量越是需要更大的信息量来确定其值。 那么我们可以先计算一下上面数据集的信息熵,代码如下。


   public static Object[][] createDataSet() {

        Object[][] dataSet = {{1, 1, 1, 0}, {1, 1, 0, 0}, {1, 1, -1, 0}, {1, 0, 1, 0}, {0, 0, 1, 0}, {0, 1, -1, 0},
            {0, 1, 0, 0}, {0, 1, 1, 1}};
        return dataSet;
    }
    
    /**
     * @param dataSet 采用二维数组表示,最后一位是label
     * @return
     */
    public static double calcEntropy(Object[][] dataSet) {
        double result = 0;
        int columnSize = dataSet[0].length;
        int rows = dataSet.length;
        Map<Object, AtomicInteger> maps = Maps.newHashMap();
        for (int i = 0; i < rows; i++) {
            Object label = dataSet[i][columnSize - 1];
            if (!maps.containsKey(label)) {
                maps.put(label, new AtomicInteger(0));
            }
            maps.get(label).incrementAndGet();
        }

        for (Map.Entry<Object, AtomicInteger> entry : maps.entrySet()) {
            double prop = 1.0 * entry.getValue().get() / rows;
            result += -1 * prop * (Math.log(prop) / Math.log(2));

        }
        return result;
    }

上述数据集的信息熵依据公式计算结果为:0.5435644431995964
继续如何选择特性的问题,这个我们以年龄是否大于 30 为一个例子来说明如何选择,

年龄是否大于 30 交往 不交往
1 0/4 4/4
0 1/4 3/4

计算年龄特征的信息熵为 -0log0-1log1 + (-1/4log1/4- 3/4log3/4) = 1.5
而上面我们计算出了基础的信息熵为 0.5435644431995964,相减得到-0.9564355568004036
我们可以采用类似的方法计算其他特征,相减后数值最小的就是目前最适合的特征。


    /**
     * 对数据集划分,按照指定的特性列的特征值做划分依据
     *
     * @param dataSet
     * @param axis
     * @param value
     * @return
     */
    public static Object[][] splitDataSet(Object[][] dataSet, int axis, Object value) {
        int columnSize = dataSet[0].length;
        int rows = dataSet.length;

        List<Object[]> lists = Lists.newArrayList();
        for (int i = 0; i < rows; i++) {

            if (!dataSet[i][axis].equals(value)) {
                continue;
            }
            Object[] row = new Object[columnSize-1];
            int index = 0;
            for (int j = 0; j < columnSize; j++) {
                if(j==axis){
                    continue;
                }
                row[index] = dataSet[i][index];
                index++;
            }
            lists.add(row);
        }

        Object[][] retDataSet = new Object[lists.size()][];
        int i = 0;
        for (Object[] row : lists) {
            retDataSet[i++] = row;
        }

        return retDataSet;
    }

    public static int chooseBestFeature(Object[][] dataSet) {
        int bestFeature = -1;
        double baseEntropy = calcEntropy(dataSet);
        int rows = dataSet.length;
        int featureSize = dataSet[0].length - 1;
        double infoGain = -1;
        double bestInfoGain = -1;
        for (int i = 0; i < featureSize; i++) {
            //处理每一列feature
            double newEntropy = 0.0;
            Map<Object, AtomicInteger> maps = Maps.newHashMap();
            for (int j = 0; j < rows; j++) {//计算当前列的唯一数值
                if (!maps.containsKey(dataSet[j][i])) {
                    maps.put(dataSet[j][i], new AtomicInteger(0));
                }
                maps.get(dataSet[j][i]).incrementAndGet();
            }
            for (Object featVa : maps.keySet()) {
                Object[][] subDataSet = splitDataSet(dataSet, i, featVa);
                double prop = maps.get(featVa).get() * 1.0 / rows;
                newEntropy += prop * calcEntropy(subDataSet);
            }
            infoGain = baseEntropy - newEntropy;
            if (infoGain > bestInfoGain) {
                bestFeature = i;
                bestInfoGain = infoGain;
            }
        }

        return bestFeature;
    }

问题二:没有可选的特征

当我们不断的分裂数据集,分裂到没有特征的时候,怎么做结果呢?这里我们使用额简单的投票机制,出现次数最多的就是结果。

    //投票
    public static  String vote(Object[][] dataSet) {
        Map<Object,AtomicInteger> maps = Maps.newHashMap();

        int rows = dataSet.length;

        for(int i=0;i<rows;i++){
            if(!maps.containsKey(dataSet[i][0])){
                maps.put(dataSet[i][0],new AtomicInteger(0));
            }
            maps.get(dataSet[i][0]).incrementAndGet();
        }

        Object maxValue  =null;
        int max = -1;
        //使用出现次数最多的那个标签
        for(Map.Entry<Object,AtomicInteger> entry:maps.entrySet()){
            if(entry.getValue().get()>max){
                max = entry.getValue().get();
                maxValue = entry.getKey();
            }
        }
        return maxValue.toString();
    }

最终程序(可以直接 run)

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import lombok.Data;

/**
 * Created by hadoop on 17/10/31.
 * 决策树算法
 */
public class Trees {

    public static Object[][] createDataSet() {

        Object[][] dataSet = {{1, 1, 1, 0}, {1, 1, 0, 0}, {1, 1, -1, 0}, {1, 0, 1, 0}, {0, 0, 1, 0}, {0, 1, -1, 0},
            {0, 1, 0, 0}, {0, 1, 1, 1}};
        return dataSet;
    }

    /**
     * @param dataSet 采用二维数组表示,最后一位是label
     * @return
     */
    public static double calcEntropy(Object[][] dataSet) {
        double result = 0;
        int columnSize = dataSet[0].length;
        int rows = dataSet.length;
        Map<Object, AtomicInteger> maps = Maps.newHashMap();
        for (int i = 0; i < rows; i++) {
            Object label = dataSet[i][columnSize - 1];
            if (!maps.containsKey(label)) {
                maps.put(label, new AtomicInteger(0));
            }
            maps.get(label).incrementAndGet();
        }

        for (Map.Entry<Object, AtomicInteger> entry : maps.entrySet()) {
            double prop = 1.0 * entry.getValue().get() / rows;
            result += -1 * prop * (Math.log(prop) / Math.log(2));

        }
        return result;
    }

    /**
     * 对数据集划分,按照指定的特性列的特征值做划分依据
     *
     * @param dataSet
     * @param axis
     * @param value
     * @return
     */
    public static Object[][] splitDataSet(Object[][] dataSet, int axis, Object value) {
        int columnSize = dataSet[0].length;
        int rows = dataSet.length;

        List<Object[]> lists = Lists.newArrayList();
        for (int i = 0; i < rows; i++) {

            if (!dataSet[i][axis].equals(value)) {
                continue;
            }
            Object[] row = new Object[columnSize-1];
            int index = 0;
            for (int j = 0; j < columnSize; j++) {
                if(j==axis){
                    continue;
                }
                row[index] = dataSet[i][index];
                index++;
            }
            lists.add(row);
        }

        Object[][] retDataSet = new Object[lists.size()][];
        int i = 0;
        for (Object[] row : lists) {
            retDataSet[i++] = row;
        }

        return retDataSet;
    }

    public static int chooseBestFeature(Object[][] dataSet) {
        int bestFeature = -1;
        double baseEntropy = calcEntropy(dataSet);
        int rows = dataSet.length;
        int featureSize = dataSet[0].length - 1;
        double infoGain = -1;
        double bestInfoGain = -1;
        for (int i = 0; i < featureSize; i++) {
            //处理每一列feature
            double newEntropy = 0.0;
            Map<Object, AtomicInteger> maps = Maps.newHashMap();
            for (int j = 0; j < rows; j++) {//计算当前列的唯一数值
                if (!maps.containsKey(dataSet[j][i])) {
                    maps.put(dataSet[j][i], new AtomicInteger(0));
                }
                maps.get(dataSet[j][i]).incrementAndGet();
            }
            for (Object featVa : maps.keySet()) {
                Object[][] subDataSet = splitDataSet(dataSet, i, featVa);
                double prop = maps.get(featVa).get() * 1.0 / rows;
                newEntropy += prop * calcEntropy(subDataSet);
            }
            infoGain = baseEntropy - newEntropy;
            if (infoGain > bestInfoGain) {
                bestFeature = i;
                bestInfoGain = infoGain;
            }
        }

        return bestFeature;
    }

    public static  String vote(Object[][] dataSet) {
        Map<Object,AtomicInteger> maps = Maps.newHashMap();

        int rows = dataSet.length;

        for(int i=0;i<rows;i++){
            if(!maps.containsKey(dataSet[i][0])){
                maps.put(dataSet[i][0],new AtomicInteger(0));
            }
            maps.get(dataSet[i][0]).incrementAndGet();
        }

        Object maxValue  =null;
        int max = -1;

        for(Map.Entry<Object,AtomicInteger> entry:maps.entrySet()){
            if(entry.getValue().get()>max){
                max = entry.getValue().get();
                maxValue = entry.getKey();
            }
        }
        return maxValue.toString();
    }

    public static Object createDecisionTree(Object[][] dataSet, List<String> labels) {

        int rows = dataSet.length;
        int labelIndex = dataSet[0].length - 1;

        Set<Object> sets = Sets.newHashSet();
        for (int i = 0; i < rows; i++) {
            sets.add(dataSet[i][labelIndex]);
        }

        if (sets.size() == rows) {
            return dataSet[0][labelIndex];
        }

        if (dataSet[0].length == 1) {
            return vote(dataSet);
        }

        int bestFeature = chooseBestFeature(dataSet);

        List<String> copyLabels = Lists.newArrayList(labels);

        String bestFeatureLabel = copyLabels.get(bestFeature);
        copyLabels.remove(bestFeature);
        Set<Object> bestFeatureValues = Sets.newHashSet();

        for (int i = 0; i < rows; i++) {
            bestFeatureValues.add(dataSet[i][bestFeature]);
        }

        DecisionTree decisionTree = new DecisionTree();
        decisionTree.setAttributeName(bestFeatureLabel);

        for (Object value : bestFeatureValues) {
            List<String> subFeatureLabels = Lists.newArrayList(copyLabels);
            decisionTree.children.put(value.toString(),
                createDecisionTree(splitDataSet(dataSet, bestFeature, value), subFeatureLabels));
        }

        return decisionTree;

    }

    public static void main(String[] args) {
        System.out.println(calcEntropy(createDataSet()));
        Object result = createDecisionTree(createDataSet(),Lists.newArrayList("年龄是否大于30","身高是否大于170","收入情况","yes/no"));
        System.out.println(result);
    }


    @Data
    static class DecisionTree{
        private String attributeName;
        public Map<String, Object> children = Maps.newHashMap();
    }

}

  • 决策树
    2 引用 • 1 回帖
  • 机器学习

    机器学习(Machine Learning)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。

    76 引用 • 37 回帖 • 1 关注

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...