优化版 JAVA 最大熵模型(GIS 训练)

本贴最后更新于 2323 天前,其中的信息可能已经渤澥桑田

网上现有的最大熵模型,如:https://blog.csdn.net/nwpuwyk/article/details/37500371
该代码在训练环节性能较差,特征函数存储的结构也涉及较简单。
我在该版本基础上进行了改进,优化了特征函数的数据结构和训练代码。

     /**
   * 样本数据集
   */
  List<Instance> instanceList = new ArrayList<Instance>();
  /**
   * 特征列表,来自所有事件的统计结果
   */

// Map<String,Feature> featureMap=new HashMap<>();
/**
* 每个特征的出现次数
*/
//Map<String,Integer> featureCountMap=new HashMap<>();
/**
* 事件(类别)集
*/
List labels = new ArrayList();
/**
* 每个特征函数的权重
*/
// double[] weight;
Map<String,Weight> weightMap=new HashMap<>();
/**
* 一个事件最多一共有多少种特征
*/
double learningRate=10;
int C;
Map<String,List> testInstance;
/**
* 样本数据集 */ List instanceList = new ArrayList();
/**
* 特征列表,来自所有事件的统计结果 */// Map featureMap=new HashMap<>();
/**
* 每个特征的出现次数 */ //Map featureCountMap=new HashMap<>();
/**
* 事件(类别)集 */ List labels = new ArrayList();
/**
* 每个特征函数的权重 */ // double[] weight;
Map<String,Weight> weightMap=new HashMap<>();
/**
* 一个事件最多一共有多少种特征 */ double learningRate=10;
int C;

   * 训练模型 * @param maxIt 最大迭代次数
   */public void train(int maxIt,String savePath) throws IOException {
      Map,Double> empiricalE = new HashMap<>(); // 经验期望
    Map,Double> modelE = new HashMap<>(); // 模型期望

    for (Map.Entry,Weight> e:weightMap.entrySet())
      {
          double ratio=(double) e.getValue().getCnt() / instanceList.size();
    empiricalE.put(e.getKey(),ratio);
    }
      Map,Double> lastWeight=new HashMap<>();
   for (int i = 0; i < maxIt; ++i)
      {
          System.out.println("iter:"+i);
    computeModeE(modelE);//计算模型期望
    System.out.println("model finish.updating...");
   for (Map.Entry,Weight> e:weightMap.entrySet())
          {
             //lastWeight[w] = weight[w];
    lastWeight.put(e.getKey(),e.getValue().getWeight());
    String f=e.getKey();
   double delta=learningRate / C * Math.log(empiricalE.get(f)/ modelE.get(f));
    weightMap.get(f).addWeight(delta);
    }
          System.out.println("saving iter:"+i);
    learningRate*=0.99;
    learningRate=learningRate<10?10:learningRate;
    saveParam(savePath+"ent_insopt.par"+i);
   if (checkConvergence(lastWeight, weightMap)) break;

    }

  }

  /**
   * 预测类别 * @param fieldList
    * @return
    */
  public Pair, Double>[] predict(Map,Integer> fieldList)
  {
      double[] prob = calProb(fieldList);
    Pair, Double>[] pairResult = new Pair[prob.length];
   for (int i = 0; i < prob.length; ++i)
      {
          pairResult[i] = new Pair, Double>(labels.get(i), prob[i]);
    }

      return pairResult;
  }

  /**
   * 检查是否收敛 * @param w1
    * @param w2
    * @return 是否收敛
   */public boolean checkConvergence(Map,Double> w1, Map,Weight> w2)
  {
      System.out.println("w1 size:"+w1.size());
   boolean flag=true;
   for (Map.Entry,Double> e1:w1.entrySet())
      {
          //System.out.println("thread:"+Math.abs(e1.getValue() - w2.get(e1.getKey())) );
    if (Math.abs(e1.getValue() - w2.get(e1.getKey()).getWeight()) >= 1e-4)    // 收敛阀值0.01可自行调整
    flag=false;
    }
      return flag;
  }

  /**
   * 计算模型期望,即在当前的特征函数的权重下,计算特征函数的模型期望值。 * @param modelE 储存空间,应当事先分配好内存(之所以不return一个modelE是为了避免重复分配内存)
   */public void computeModeE(Map,Double> modelE)
  {
      modelE.clear();
   double rate=1.0 / instanceList.size();
   for (int i = 0; i < instanceList.size(); ++i)
      {

          Map,Integer> fieldMap = instanceList.get(i).fieldList;//no labels
   //计算当前样本X对应所有类别的概率  double[] pro = calProb(fieldMap);
   for (Map.Entry,Integer> e:fieldMap.entrySet())
          {
              String insFeature=e.getKey();/**
     * 训练模型
     * @param maxIt 最大迭代次数
     */
    public void train(int maxIt,String savePath) throws IOException {
        Map<String,Double> empiricalE = new HashMap<>();   // 经验期望
        Map<String,Double> modelE = new HashMap<>();       // 模型期望

        for (Map.Entry<String,Weight> e:weightMap.entrySet())
        {
            double ratio=(double) e.getValue().getCnt() / instanceList.size();
            empiricalE.put(e.getKey(),ratio);
        }
        Map<String,Double> lastWeight=new HashMap<>();
        for (int i = 0; i < maxIt; ++i)
        {
            System.out.println("iter:"+i);
            computeModeE(modelE);//计算模型期望
            System.out.println("model finish.updating...");
            for (Map.Entry<String,Weight> e:weightMap.entrySet())
            {
               //lastWeight[w] = weight[w];
                lastWeight.put(e.getKey(),e.getValue().getWeight());
                String f=e.getKey();
                double delta=learningRate / C * Math.log(empiricalE.get(f)/ modelE.get(f));
                weightMap.get(f).addWeight(delta);
            }
            System.out.println("saving iter:"+i);
            learningRate*=0.99;
            learningRate=learningRate<10?10:learningRate;
            saveParam(savePath+"ent_insopt.par"+i);
            if (checkConvergence(lastWeight, weightMap)) break;

        }

    }

    /**
     * 预测类别
     * @param fieldList
     * @return
     */
    public Pair<String, Double>[] predict(Map<String,Integer> fieldList)
    {
        double[] prob = calProb(fieldList);
        Pair<String, Double>[] pairResult = new Pair[prob.length];
        for (int i = 0; i < prob.length; ++i)
        {
            pairResult[i] = new Pair<String, Double>(labels.get(i), prob[i]);
        }

        return pairResult;
    }

    /**
     * 检查是否收敛
     * @param w1
     * @param w2
     * @return 是否收敛
     */
    public boolean checkConvergence(Map<String,Double> w1, Map<String,Weight> w2)
    {
        System.out.println("w1 size:"+w1.size());
        boolean flag=true;
        for (Map.Entry<String,Double> e1:w1.entrySet())
        {
            //System.out.println("thread:"+Math.abs(e1.getValue() - w2.get(e1.getKey())) );
            if (Math.abs(e1.getValue() - w2.get(e1.getKey()).getWeight()) >= 1e-4)    // 收敛阀值0.01可自行调整
                flag=false;
        }
        return flag;
    }

    /**
     * 计算模型期望,即在当前的特征函数的权重下,计算特征函数的模型期望值。
     * @param modelE 储存空间,应当事先分配好内存(之所以不return一个modelE是为了避免重复分配内存)
     */
    public void computeModeE(Map<String,Double> modelE)
    {
        modelE.clear();
        double rate=1.0 / instanceList.size();
        for (int i = 0; i < instanceList.size(); ++i)
        {

            Map<String,Integer> fieldMap = instanceList.get(i).fieldList;//no labels
             //计算当前样本X对应所有类别的概率
            double[] pro = calProb(fieldMap);
            for (Map.Entry<String,Integer> e:fieldMap.entrySet())
            {
                String insFeature=e.getKey();
                int cnt=e.getValue();
                for (int k = 0; k < labels.size(); k++)
                {
                    String feature=labels.get(k)+":"+insFeature;
                      if (weightMap.containsKey(feature)) {
                        double  delta=pro[k] * rate*cnt;
                        modelE.put(feature, modelE.containsKey(feature) ? modelE.get(feature) + delta : delta);
                    }
                }
            }
        }
    }
//    public class Mode implements Runnable
//    {
//        ConcurrentLinkedQueue<Integer> insQueue=new ConcurrentLinkedQueue<>();
//        boolean flag=true;
//        List<Instance> i∂
//        public void addIns(int i)
//        {
//
//        }
//
//        @Override
//        public void run() {
//            while(flag)
//            {
//                int ins=insQueue.poll();
//            }
//        }
//    }
    /**
     * 计算p(y|x),此时的x指的是instance里的field
     * @param fieldList 实例的特征列表
     * @return 该实例属于每个类别的概率
     */
    public double[] calProb(Map<String,Integer> fieldList)
    {
        double[] p = new double[labels.size()];
        double sum = 0;  // 正则化因子,保证概率和为1
        for (int i = 0; i < labels.size(); ++i)
        {
            double weightSum = 0;
            String label=labels.get(i);
            for (String field : fieldList.keySet())
            {
                String feature=label+":"+field;
                 if (weightMap.containsKey(feature)) {
                    weightSum += weightMap.get(feature).getWeight()*fieldList.get(field);
                }
            }
            if(weightSum>15)
            {
                weightSum=15;
            }
            p[i] = Math.exp(weightSum);

            sum += p[i];
        }
        //System.out.println();
        for (int i = 0; i < p.length; ++i)
        {
            p[i] /= sum;
//            if(Double.isNaN(p[i]))
//            {
//                System.out.println(p[i]);
//            }
        }
        return p;
    }

    /**
     * 一个观测实例,包含事件和时间发生的环境
     */
    class Instance implements Serializable
    {
        /**
         * 事件(类别),如Outdoor
         */
        String label;
        /**
         * 事件发生的环境集合,如[Sunny, Happy]
         */
        Map<String,Integer> fieldList = new HashMap<>();

        public Instance(String label, Map<String,Integer>fieldList)
        {
            this.label = label;
            this.fieldList = fieldList;
        }
    }

    /**
     * 特征(二值函数)
     */
    class Weight
    {
        double weight=0;
        int cnt=0;
        public void addWeight(double w)
        {
            weight+=(w);
        }
        public double getWeight() {
            return weight;
        }
        public void addCnt(int c)
        {
            cnt+=c;
        }
        public void setWeight(double weight) {
            this.weight = weight;
        }

        public int getCnt() {
            return cnt;
        }

        public void setCnt(int cnt) {
            this.cnt = cnt;
        }
    }
   int cnt=e.getValue();
   for (int k = 0; k < labels.size(); k++)
              {
                  String feature=labels.get(k)+":"+insFeature;
   if (weightMap.containsKey(feature)) {
                      double delta=pro[k] * rate*cnt;
    modelE.put(feature, modelE.containsKey(feature) ? modelE.get(feature) + delta : delta);
    }
              }
          }
      }
  }
  • 最大熵模型
    1 引用
  • Java

    Java 是一种可以撰写跨平台应用软件的面向对象的程序设计语言,是由 Sun Microsystems 公司于 1995 年 5 月推出的。Java 技术具有卓越的通用性、高效性、平台移植性和安全性。

    3190 引用 • 8214 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • BND

    BND(Baidu Netdisk Downloader)是一款图形界面的百度网盘不限速下载器,支持 Windows、Linux 和 Mac,详细介绍请看这里

    107 引用 • 1281 回帖 • 37 关注
  • uTools

    uTools 是一个极简、插件化、跨平台的现代桌面软件。通过自由选配丰富的插件,打造你得心应手的工具集合。

    6 引用 • 14 回帖
  • Sandbox

    如果帖子标签含有 Sandbox ,则该帖子会被视为“测试帖”,主要用于测试社区功能,排查 bug 等,该标签下内容不定期进行清理。

    410 引用 • 1246 回帖 • 587 关注
  • 导航

    各种网址链接、内容导航。

    42 引用 • 175 回帖
  • 生活

    生活是指人类生存过程中的各项活动的总和,范畴较广,一般指为幸福的意义而存在。生活实际上是对人生的一种诠释。生活包括人类在社会中与自己息息相关的日常活动和心理影射。

    230 引用 • 1454 回帖 • 1 关注
  • AngularJS

    AngularJS 诞生于 2009 年,由 Misko Hevery 等人创建,后为 Google 所收购。是一款优秀的前端 JS 框架,已经被用于 Google 的多款产品当中。AngularJS 有着诸多特性,最为核心的是:MVC、模块化、自动化双向数据绑定、语义化标签、依赖注入等。2.0 版本后已经改名为 Angular。

    12 引用 • 50 回帖 • 484 关注
  • 笔记

    好记性不如烂笔头。

    308 引用 • 793 回帖
  • CSS

    CSS(Cascading Style Sheet)“层叠样式表”是用于控制网页样式并允许将样式信息与网页内容分离的一种标记性语言。

    196 引用 • 540 回帖 • 1 关注
  • 30Seconds

    📙 前端知识精选集,包含 HTML、CSS、JavaScript、React、Node、安全等方面,每天仅需 30 秒。

    • 精选常见面试题,帮助您准备下一次面试
    • 精选常见交互,帮助您拥有简洁酷炫的站点
    • 精选有用的 React 片段,帮助你获取最佳实践
    • 精选常见代码集,帮助您提高打码效率
    • 整理前端界的最新资讯,邀您一同探索新世界
    488 引用 • 384 回帖
  • Python

    Python 是一种面向对象、直译式电脑编程语言,具有近二十年的发展历史,成熟且稳定。它包含了一组完善而且容易理解的标准库,能够轻松完成很多常见的任务。它的语法简捷和清晰,尽量使用无异义的英语单词,与其它大多数程序设计语言使用大括号不一样,它使用缩进来定义语句块。

    546 引用 • 672 回帖
  • CAP

    CAP 指的是在一个分布式系统中, Consistency(一致性)、 Availability(可用性)、Partition tolerance(分区容错性),三者不可兼得。

    11 引用 • 5 回帖 • 614 关注
  • H2

    H2 是一个开源的嵌入式数据库引擎,采用 Java 语言编写,不受平台的限制,同时 H2 提供了一个十分方便的 web 控制台用于操作和管理数据库内容。H2 还提供兼容模式,可以兼容一些主流的数据库,因此采用 H2 作为开发期的数据库非常方便。

    11 引用 • 54 回帖 • 656 关注
  • JavaScript

    JavaScript 一种动态类型、弱类型、基于原型的直译式脚本语言,内置支持类型。它的解释器被称为 JavaScript 引擎,为浏览器的一部分,广泛用于客户端的脚本语言,最早是在 HTML 网页上使用,用来给 HTML 网页增加动态功能。

    728 引用 • 1275 回帖
  • FFmpeg

    FFmpeg 是一套可以用来记录、转换数字音频、视频,并能将其转化为流的开源计算机程序。

    23 引用 • 32 回帖 • 1 关注
  • IBM

    IBM(国际商业机器公司)或万国商业机器公司,简称 IBM(International Business Machines Corporation),总公司在纽约州阿蒙克市。1911 年托马斯·沃森创立于美国,是全球最大的信息技术和业务解决方案公司,拥有全球雇员 30 多万人,业务遍及 160 多个国家和地区。

    17 引用 • 53 回帖 • 140 关注
  • 招聘

    哪里都缺人,哪里都不缺人。

    190 引用 • 1057 回帖
  • HTML

    HTML5 是 HTML 下一个的主要修订版本,现在仍处于发展阶段。广义论及 HTML5 时,实际指的是包括 HTML、CSS 和 JavaScript 在内的一套技术组合。

    107 引用 • 295 回帖
  • 链书

    链书(Chainbook)是 B3log 开源社区提供的区块链纸质书交易平台,通过 B3T 实现共享激励与价值链。可将你的闲置书籍上架到链书,我们共同构建这个全新的交易平台,让闲置书籍继续发挥它的价值。

    链书社

    链书目前已经下线,也许以后还有计划重制上线。

    14 引用 • 257 回帖
  • Linux

    Linux 是一套免费使用和自由传播的类 Unix 操作系统,是一个基于 POSIX 和 Unix 的多用户、多任务、支持多线程和多 CPU 的操作系统。它能运行主要的 Unix 工具软件、应用程序和网络协议,并支持 32 位和 64 位硬件。Linux 继承了 Unix 以网络为核心的设计思想,是一个性能稳定的多用户网络操作系统。

    946 引用 • 943 回帖 • 2 关注
  • Sphinx

    Sphinx 是一个基于 SQL 的全文检索引擎,可以结合 MySQL、PostgreSQL 做全文搜索,它可以提供比数据库本身更专业的搜索功能,使得应用程序更容易实现专业化的全文检索。

    1 引用 • 221 关注
  • SQLServer

    SQL Server 是由 [微软] 开发和推广的关系数据库管理系统(DBMS),它最初是由 微软、Sybase 和 Ashton-Tate 三家公司共同开发的,并于 1988 年推出了第一个 OS/2 版本。

    21 引用 • 31 回帖 • 2 关注
  • Eclipse

    Eclipse 是一个开放源代码的、基于 Java 的可扩展开发平台。就其本身而言,它只是一个框架和一组服务,用于通过插件组件构建开发环境。

    75 引用 • 258 回帖 • 623 关注
  • FreeMarker

    FreeMarker 是一款好用且功能强大的 Java 模版引擎。

    23 引用 • 20 回帖 • 465 关注
  • OpenResty

    OpenResty 是一个基于 NGINX 与 Lua 的高性能 Web 平台,其内部集成了大量精良的 Lua 库、第三方模块以及大多数的依赖项。用于方便地搭建能够处理超高并发、扩展性极高的动态 Web 应用、Web 服务和动态网关。

    17 引用 • 41 关注
  • flomo

    flomo 是新一代 「卡片笔记」 ,专注在碎片化时代,促进你的记录,帮你积累更多知识资产。

    5 引用 • 107 回帖 • 1 关注
  • Ubuntu

    Ubuntu(友帮拓、优般图、乌班图)是一个以桌面应用为主的 Linux 操作系统,其名称来自非洲南部祖鲁语或豪萨语的“ubuntu”一词,意思是“人性”、“我的存在是因为大家的存在”,是非洲传统的一种价值观,类似华人社会的“仁爱”思想。Ubuntu 的目标在于为一般用户提供一个最新的、同时又相当稳定的主要由自由软件构建而成的操作系统。

    126 引用 • 169 回帖
  • Dubbo

    Dubbo 是一个分布式服务框架,致力于提供高性能和透明化的 RPC 远程服务调用方案,是 [阿里巴巴] SOA 服务化治理方案的核心框架,每天为 2,000+ 个服务提供 3,000,000,000+ 次访问量支持,并被广泛应用于阿里巴巴集团的各成员站点。

    60 引用 • 82 回帖 • 604 关注