概述
决策树算法是比较常规的多分类算法,常见的算法实现有 ID3,C45,CART 等算法,这里我们会采用 ID3 算法。
原理介绍
给定的数据集如下:
年龄是否大于 30 | 身高是否大于 170 | 收入情况(-1,0,1 分别表示低中高) | 是否交往 | |
---|---|---|---|---|
1 | 1 | 1 | 0 | |
1 | 1 | 0 | 0 | |
1 | 1 | -1 | 0 | |
1 | 0 | 1 | 0 | |
0 | 0 | 1 | 0 | |
0 | 1 | -1 | 0 | |
0 | 1 | 0 | 0 | |
0 | 1 | 1 | 1 | |
问题来了,如何如构造一个决策树,来自动帮助我们做辅助决策. |
代码实例
在实际构造程序之前我们先来各个击破遇到的问题:
问题一:选哪一个特征做当前的决策树的根节点
在有 N 多特征的时候,我们一般一定会选择区分度高的特征作为根节点,那么怎么去定义区分度呢。
这时候我们引入概率和信息熵的概率,一个简单的例子是,一个信息如果发生的概率为 1,那么这个信息对你来说它的价值是 0,就好比我告诉你明天地球还会公转一样。但是如果一个事情发生的概率为 0,但是我告诉你发生了,这个对你来说绝对是非常有用的信息。
发生的概率这里,我们会去统计每一个事件发生的次数/总次数。比如上面的数据集交往的概率为 1/8,不交往的概率为 7/8,那么信息是否有用信息熵来表示公式为:.那么计算对各种情况的熵求和
。一个随机变量的熵越大,其不确定性就越大,(不管是先验熵,后验熵,还是条件熵都是这样的)正确的估计其值的可能性就越小,越是不确定的随机变量越是需要更大的信息量来确定其值。 那么我们可以先计算一下上面数据集的信息熵,代码如下。
public static Object[][] createDataSet() {
Object[][] dataSet = {{1, 1, 1, 0}, {1, 1, 0, 0}, {1, 1, -1, 0}, {1, 0, 1, 0}, {0, 0, 1, 0}, {0, 1, -1, 0},
{0, 1, 0, 0}, {0, 1, 1, 1}};
return dataSet;
}
/**
* @param dataSet 采用二维数组表示,最后一位是label
* @return
*/
public static double calcEntropy(Object[][] dataSet) {
double result = 0;
int columnSize = dataSet[0].length;
int rows = dataSet.length;
Map<Object, AtomicInteger> maps = Maps.newHashMap();
for (int i = 0; i < rows; i++) {
Object label = dataSet[i][columnSize - 1];
if (!maps.containsKey(label)) {
maps.put(label, new AtomicInteger(0));
}
maps.get(label).incrementAndGet();
}
for (Map.Entry<Object, AtomicInteger> entry : maps.entrySet()) {
double prop = 1.0 * entry.getValue().get() / rows;
result += -1 * prop * (Math.log(prop) / Math.log(2));
}
return result;
}
上述数据集的信息熵依据公式计算结果为:0.5435644431995964
继续如何选择特性的问题,这个我们以年龄是否大于 30 为一个例子来说明如何选择,
年龄是否大于 30 | 交往 | 不交往 |
---|---|---|
1 | 0/4 | 4/4 |
0 | 1/4 | 3/4 |
计算年龄特征的信息熵为 -0log0-1log1 + (-1/4log1/4- 3/4log3/4) = 1.5
而上面我们计算出了基础的信息熵为 0.5435644431995964,相减得到-0.9564355568004036
我们可以采用类似的方法计算其他特征,相减后数值最小的就是目前最适合的特征。
/**
* 对数据集划分,按照指定的特性列的特征值做划分依据
*
* @param dataSet
* @param axis
* @param value
* @return
*/
public static Object[][] splitDataSet(Object[][] dataSet, int axis, Object value) {
int columnSize = dataSet[0].length;
int rows = dataSet.length;
List<Object[]> lists = Lists.newArrayList();
for (int i = 0; i < rows; i++) {
if (!dataSet[i][axis].equals(value)) {
continue;
}
Object[] row = new Object[columnSize-1];
int index = 0;
for (int j = 0; j < columnSize; j++) {
if(j==axis){
continue;
}
row[index] = dataSet[i][index];
index++;
}
lists.add(row);
}
Object[][] retDataSet = new Object[lists.size()][];
int i = 0;
for (Object[] row : lists) {
retDataSet[i++] = row;
}
return retDataSet;
}
public static int chooseBestFeature(Object[][] dataSet) {
int bestFeature = -1;
double baseEntropy = calcEntropy(dataSet);
int rows = dataSet.length;
int featureSize = dataSet[0].length - 1;
double infoGain = -1;
double bestInfoGain = -1;
for (int i = 0; i < featureSize; i++) {
//处理每一列feature
double newEntropy = 0.0;
Map<Object, AtomicInteger> maps = Maps.newHashMap();
for (int j = 0; j < rows; j++) {//计算当前列的唯一数值
if (!maps.containsKey(dataSet[j][i])) {
maps.put(dataSet[j][i], new AtomicInteger(0));
}
maps.get(dataSet[j][i]).incrementAndGet();
}
for (Object featVa : maps.keySet()) {
Object[][] subDataSet = splitDataSet(dataSet, i, featVa);
double prop = maps.get(featVa).get() * 1.0 / rows;
newEntropy += prop * calcEntropy(subDataSet);
}
infoGain = baseEntropy - newEntropy;
if (infoGain > bestInfoGain) {
bestFeature = i;
bestInfoGain = infoGain;
}
}
return bestFeature;
}
问题二:没有可选的特征
当我们不断的分裂数据集,分裂到没有特征的时候,怎么做结果呢?这里我们使用额简单的投票机制,出现次数最多的就是结果。
//投票
public static String vote(Object[][] dataSet) {
Map<Object,AtomicInteger> maps = Maps.newHashMap();
int rows = dataSet.length;
for(int i=0;i<rows;i++){
if(!maps.containsKey(dataSet[i][0])){
maps.put(dataSet[i][0],new AtomicInteger(0));
}
maps.get(dataSet[i][0]).incrementAndGet();
}
Object maxValue =null;
int max = -1;
//使用出现次数最多的那个标签
for(Map.Entry<Object,AtomicInteger> entry:maps.entrySet()){
if(entry.getValue().get()>max){
max = entry.getValue().get();
maxValue = entry.getKey();
}
}
return maxValue.toString();
}
最终程序(可以直接 run)
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import lombok.Data;
/**
* Created by hadoop on 17/10/31.
* 决策树算法
*/
public class Trees {
public static Object[][] createDataSet() {
Object[][] dataSet = {{1, 1, 1, 0}, {1, 1, 0, 0}, {1, 1, -1, 0}, {1, 0, 1, 0}, {0, 0, 1, 0}, {0, 1, -1, 0},
{0, 1, 0, 0}, {0, 1, 1, 1}};
return dataSet;
}
/**
* @param dataSet 采用二维数组表示,最后一位是label
* @return
*/
public static double calcEntropy(Object[][] dataSet) {
double result = 0;
int columnSize = dataSet[0].length;
int rows = dataSet.length;
Map<Object, AtomicInteger> maps = Maps.newHashMap();
for (int i = 0; i < rows; i++) {
Object label = dataSet[i][columnSize - 1];
if (!maps.containsKey(label)) {
maps.put(label, new AtomicInteger(0));
}
maps.get(label).incrementAndGet();
}
for (Map.Entry<Object, AtomicInteger> entry : maps.entrySet()) {
double prop = 1.0 * entry.getValue().get() / rows;
result += -1 * prop * (Math.log(prop) / Math.log(2));
}
return result;
}
/**
* 对数据集划分,按照指定的特性列的特征值做划分依据
*
* @param dataSet
* @param axis
* @param value
* @return
*/
public static Object[][] splitDataSet(Object[][] dataSet, int axis, Object value) {
int columnSize = dataSet[0].length;
int rows = dataSet.length;
List<Object[]> lists = Lists.newArrayList();
for (int i = 0; i < rows; i++) {
if (!dataSet[i][axis].equals(value)) {
continue;
}
Object[] row = new Object[columnSize-1];
int index = 0;
for (int j = 0; j < columnSize; j++) {
if(j==axis){
continue;
}
row[index] = dataSet[i][index];
index++;
}
lists.add(row);
}
Object[][] retDataSet = new Object[lists.size()][];
int i = 0;
for (Object[] row : lists) {
retDataSet[i++] = row;
}
return retDataSet;
}
public static int chooseBestFeature(Object[][] dataSet) {
int bestFeature = -1;
double baseEntropy = calcEntropy(dataSet);
int rows = dataSet.length;
int featureSize = dataSet[0].length - 1;
double infoGain = -1;
double bestInfoGain = -1;
for (int i = 0; i < featureSize; i++) {
//处理每一列feature
double newEntropy = 0.0;
Map<Object, AtomicInteger> maps = Maps.newHashMap();
for (int j = 0; j < rows; j++) {//计算当前列的唯一数值
if (!maps.containsKey(dataSet[j][i])) {
maps.put(dataSet[j][i], new AtomicInteger(0));
}
maps.get(dataSet[j][i]).incrementAndGet();
}
for (Object featVa : maps.keySet()) {
Object[][] subDataSet = splitDataSet(dataSet, i, featVa);
double prop = maps.get(featVa).get() * 1.0 / rows;
newEntropy += prop * calcEntropy(subDataSet);
}
infoGain = baseEntropy - newEntropy;
if (infoGain > bestInfoGain) {
bestFeature = i;
bestInfoGain = infoGain;
}
}
return bestFeature;
}
public static String vote(Object[][] dataSet) {
Map<Object,AtomicInteger> maps = Maps.newHashMap();
int rows = dataSet.length;
for(int i=0;i<rows;i++){
if(!maps.containsKey(dataSet[i][0])){
maps.put(dataSet[i][0],new AtomicInteger(0));
}
maps.get(dataSet[i][0]).incrementAndGet();
}
Object maxValue =null;
int max = -1;
for(Map.Entry<Object,AtomicInteger> entry:maps.entrySet()){
if(entry.getValue().get()>max){
max = entry.getValue().get();
maxValue = entry.getKey();
}
}
return maxValue.toString();
}
public static Object createDecisionTree(Object[][] dataSet, List<String> labels) {
int rows = dataSet.length;
int labelIndex = dataSet[0].length - 1;
Set<Object> sets = Sets.newHashSet();
for (int i = 0; i < rows; i++) {
sets.add(dataSet[i][labelIndex]);
}
if (sets.size() == rows) {
return dataSet[0][labelIndex];
}
if (dataSet[0].length == 1) {
return vote(dataSet);
}
int bestFeature = chooseBestFeature(dataSet);
List<String> copyLabels = Lists.newArrayList(labels);
String bestFeatureLabel = copyLabels.get(bestFeature);
copyLabels.remove(bestFeature);
Set<Object> bestFeatureValues = Sets.newHashSet();
for (int i = 0; i < rows; i++) {
bestFeatureValues.add(dataSet[i][bestFeature]);
}
DecisionTree decisionTree = new DecisionTree();
decisionTree.setAttributeName(bestFeatureLabel);
for (Object value : bestFeatureValues) {
List<String> subFeatureLabels = Lists.newArrayList(copyLabels);
decisionTree.children.put(value.toString(),
createDecisionTree(splitDataSet(dataSet, bestFeature, value), subFeatureLabels));
}
return decisionTree;
}
public static void main(String[] args) {
System.out.println(calcEntropy(createDataSet()));
Object result = createDecisionTree(createDataSet(),Lists.newArrayList("年龄是否大于30","身高是否大于170","收入情况","yes/no"));
System.out.println(result);
}
@Data
static class DecisionTree{
private String attributeName;
public Map<String, Object> children = Maps.newHashMap();
}
}
欢迎来到这里!
我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。
注册 关于