优化版 JAVA 最大熵模型(GIS 训练)

本贴最后更新于 2280 天前,其中的信息可能已经渤澥桑田

网上现有的最大熵模型,如:https://blog.csdn.net/nwpuwyk/article/details/37500371
该代码在训练环节性能较差,特征函数存储的结构也涉及较简单。
我在该版本基础上进行了改进,优化了特征函数的数据结构和训练代码。

     /**
   * 样本数据集
   */
  List<Instance> instanceList = new ArrayList<Instance>();
  /**
   * 特征列表,来自所有事件的统计结果
   */

// Map<String,Feature> featureMap=new HashMap<>();
/**
* 每个特征的出现次数
*/
//Map<String,Integer> featureCountMap=new HashMap<>();
/**
* 事件(类别)集
*/
List labels = new ArrayList();
/**
* 每个特征函数的权重
*/
// double[] weight;
Map<String,Weight> weightMap=new HashMap<>();
/**
* 一个事件最多一共有多少种特征
*/
double learningRate=10;
int C;
Map<String,List> testInstance;
/**
* 样本数据集 */ List instanceList = new ArrayList();
/**
* 特征列表,来自所有事件的统计结果 */// Map featureMap=new HashMap<>();
/**
* 每个特征的出现次数 */ //Map featureCountMap=new HashMap<>();
/**
* 事件(类别)集 */ List labels = new ArrayList();
/**
* 每个特征函数的权重 */ // double[] weight;
Map<String,Weight> weightMap=new HashMap<>();
/**
* 一个事件最多一共有多少种特征 */ double learningRate=10;
int C;

   * 训练模型 * @param maxIt 最大迭代次数
   */public void train(int maxIt,String savePath) throws IOException {
      Map,Double> empiricalE = new HashMap<>(); // 经验期望
    Map,Double> modelE = new HashMap<>(); // 模型期望

    for (Map.Entry,Weight> e:weightMap.entrySet())
      {
          double ratio=(double) e.getValue().getCnt() / instanceList.size();
    empiricalE.put(e.getKey(),ratio);
    }
      Map,Double> lastWeight=new HashMap<>();
   for (int i = 0; i < maxIt; ++i)
      {
          System.out.println("iter:"+i);
    computeModeE(modelE);//计算模型期望
    System.out.println("model finish.updating...");
   for (Map.Entry,Weight> e:weightMap.entrySet())
          {
             //lastWeight[w] = weight[w];
    lastWeight.put(e.getKey(),e.getValue().getWeight());
    String f=e.getKey();
   double delta=learningRate / C * Math.log(empiricalE.get(f)/ modelE.get(f));
    weightMap.get(f).addWeight(delta);
    }
          System.out.println("saving iter:"+i);
    learningRate*=0.99;
    learningRate=learningRate<10?10:learningRate;
    saveParam(savePath+"ent_insopt.par"+i);
   if (checkConvergence(lastWeight, weightMap)) break;

    }

  }

  /**
   * 预测类别 * @param fieldList
    * @return
    */
  public Pair, Double>[] predict(Map,Integer> fieldList)
  {
      double[] prob = calProb(fieldList);
    Pair, Double>[] pairResult = new Pair[prob.length];
   for (int i = 0; i < prob.length; ++i)
      {
          pairResult[i] = new Pair, Double>(labels.get(i), prob[i]);
    }

      return pairResult;
  }

  /**
   * 检查是否收敛 * @param w1
    * @param w2
    * @return 是否收敛
   */public boolean checkConvergence(Map,Double> w1, Map,Weight> w2)
  {
      System.out.println("w1 size:"+w1.size());
   boolean flag=true;
   for (Map.Entry,Double> e1:w1.entrySet())
      {
          //System.out.println("thread:"+Math.abs(e1.getValue() - w2.get(e1.getKey())) );
    if (Math.abs(e1.getValue() - w2.get(e1.getKey()).getWeight()) >= 1e-4)    // 收敛阀值0.01可自行调整
    flag=false;
    }
      return flag;
  }

  /**
   * 计算模型期望,即在当前的特征函数的权重下,计算特征函数的模型期望值。 * @param modelE 储存空间,应当事先分配好内存(之所以不return一个modelE是为了避免重复分配内存)
   */public void computeModeE(Map,Double> modelE)
  {
      modelE.clear();
   double rate=1.0 / instanceList.size();
   for (int i = 0; i < instanceList.size(); ++i)
      {

          Map,Integer> fieldMap = instanceList.get(i).fieldList;//no labels
   //计算当前样本X对应所有类别的概率  double[] pro = calProb(fieldMap);
   for (Map.Entry,Integer> e:fieldMap.entrySet())
          {
              String insFeature=e.getKey();/**
     * 训练模型
     * @param maxIt 最大迭代次数
     */
    public void train(int maxIt,String savePath) throws IOException {
        Map<String,Double> empiricalE = new HashMap<>();   // 经验期望
        Map<String,Double> modelE = new HashMap<>();       // 模型期望

        for (Map.Entry<String,Weight> e:weightMap.entrySet())
        {
            double ratio=(double) e.getValue().getCnt() / instanceList.size();
            empiricalE.put(e.getKey(),ratio);
        }
        Map<String,Double> lastWeight=new HashMap<>();
        for (int i = 0; i < maxIt; ++i)
        {
            System.out.println("iter:"+i);
            computeModeE(modelE);//计算模型期望
            System.out.println("model finish.updating...");
            for (Map.Entry<String,Weight> e:weightMap.entrySet())
            {
               //lastWeight[w] = weight[w];
                lastWeight.put(e.getKey(),e.getValue().getWeight());
                String f=e.getKey();
                double delta=learningRate / C * Math.log(empiricalE.get(f)/ modelE.get(f));
                weightMap.get(f).addWeight(delta);
            }
            System.out.println("saving iter:"+i);
            learningRate*=0.99;
            learningRate=learningRate<10?10:learningRate;
            saveParam(savePath+"ent_insopt.par"+i);
            if (checkConvergence(lastWeight, weightMap)) break;

        }

    }

    /**
     * 预测类别
     * @param fieldList
     * @return
     */
    public Pair<String, Double>[] predict(Map<String,Integer> fieldList)
    {
        double[] prob = calProb(fieldList);
        Pair<String, Double>[] pairResult = new Pair[prob.length];
        for (int i = 0; i < prob.length; ++i)
        {
            pairResult[i] = new Pair<String, Double>(labels.get(i), prob[i]);
        }

        return pairResult;
    }

    /**
     * 检查是否收敛
     * @param w1
     * @param w2
     * @return 是否收敛
     */
    public boolean checkConvergence(Map<String,Double> w1, Map<String,Weight> w2)
    {
        System.out.println("w1 size:"+w1.size());
        boolean flag=true;
        for (Map.Entry<String,Double> e1:w1.entrySet())
        {
            //System.out.println("thread:"+Math.abs(e1.getValue() - w2.get(e1.getKey())) );
            if (Math.abs(e1.getValue() - w2.get(e1.getKey()).getWeight()) >= 1e-4)    // 收敛阀值0.01可自行调整
                flag=false;
        }
        return flag;
    }

    /**
     * 计算模型期望,即在当前的特征函数的权重下,计算特征函数的模型期望值。
     * @param modelE 储存空间,应当事先分配好内存(之所以不return一个modelE是为了避免重复分配内存)
     */
    public void computeModeE(Map<String,Double> modelE)
    {
        modelE.clear();
        double rate=1.0 / instanceList.size();
        for (int i = 0; i < instanceList.size(); ++i)
        {

            Map<String,Integer> fieldMap = instanceList.get(i).fieldList;//no labels
             //计算当前样本X对应所有类别的概率
            double[] pro = calProb(fieldMap);
            for (Map.Entry<String,Integer> e:fieldMap.entrySet())
            {
                String insFeature=e.getKey();
                int cnt=e.getValue();
                for (int k = 0; k < labels.size(); k++)
                {
                    String feature=labels.get(k)+":"+insFeature;
                      if (weightMap.containsKey(feature)) {
                        double  delta=pro[k] * rate*cnt;
                        modelE.put(feature, modelE.containsKey(feature) ? modelE.get(feature) + delta : delta);
                    }
                }
            }
        }
    }
//    public class Mode implements Runnable
//    {
//        ConcurrentLinkedQueue<Integer> insQueue=new ConcurrentLinkedQueue<>();
//        boolean flag=true;
//        List<Instance> i∂
//        public void addIns(int i)
//        {
//
//        }
//
//        @Override
//        public void run() {
//            while(flag)
//            {
//                int ins=insQueue.poll();
//            }
//        }
//    }
    /**
     * 计算p(y|x),此时的x指的是instance里的field
     * @param fieldList 实例的特征列表
     * @return 该实例属于每个类别的概率
     */
    public double[] calProb(Map<String,Integer> fieldList)
    {
        double[] p = new double[labels.size()];
        double sum = 0;  // 正则化因子,保证概率和为1
        for (int i = 0; i < labels.size(); ++i)
        {
            double weightSum = 0;
            String label=labels.get(i);
            for (String field : fieldList.keySet())
            {
                String feature=label+":"+field;
                 if (weightMap.containsKey(feature)) {
                    weightSum += weightMap.get(feature).getWeight()*fieldList.get(field);
                }
            }
            if(weightSum>15)
            {
                weightSum=15;
            }
            p[i] = Math.exp(weightSum);

            sum += p[i];
        }
        //System.out.println();
        for (int i = 0; i < p.length; ++i)
        {
            p[i] /= sum;
//            if(Double.isNaN(p[i]))
//            {
//                System.out.println(p[i]);
//            }
        }
        return p;
    }

    /**
     * 一个观测实例,包含事件和时间发生的环境
     */
    class Instance implements Serializable
    {
        /**
         * 事件(类别),如Outdoor
         */
        String label;
        /**
         * 事件发生的环境集合,如[Sunny, Happy]
         */
        Map<String,Integer> fieldList = new HashMap<>();

        public Instance(String label, Map<String,Integer>fieldList)
        {
            this.label = label;
            this.fieldList = fieldList;
        }
    }

    /**
     * 特征(二值函数)
     */
    class Weight
    {
        double weight=0;
        int cnt=0;
        public void addWeight(double w)
        {
            weight+=(w);
        }
        public double getWeight() {
            return weight;
        }
        public void addCnt(int c)
        {
            cnt+=c;
        }
        public void setWeight(double weight) {
            this.weight = weight;
        }

        public int getCnt() {
            return cnt;
        }

        public void setCnt(int cnt) {
            this.cnt = cnt;
        }
    }
   int cnt=e.getValue();
   for (int k = 0; k < labels.size(); k++)
              {
                  String feature=labels.get(k)+":"+insFeature;
   if (weightMap.containsKey(feature)) {
                      double delta=pro[k] * rate*cnt;
    modelE.put(feature, modelE.containsKey(feature) ? modelE.get(feature) + delta : delta);
    }
              }
          }
      }
  }
  • 最大熵模型
    1 引用
  • Java

    Java 是一种可以撰写跨平台应用软件的面向对象的程序设计语言,是由 Sun Microsystems 公司于 1995 年 5 月推出的。Java 技术具有卓越的通用性、高效性、平台移植性和安全性。

    3187 引用 • 8213 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • GAE

    Google App Engine(GAE)是 Google 管理的数据中心中用于 WEB 应用程序的开发和托管的平台。2008 年 4 月 发布第一个测试版本。目前支持 Python、Java 和 Go 开发部署。全球已有数十万的开发者在其上开发了众多的应用。

    14 引用 • 42 回帖 • 761 关注
  • 分享

    有什么新发现就分享给大家吧!

    248 引用 • 1792 回帖 • 1 关注
  • GitHub

    GitHub 于 2008 年上线,目前,除了 Git 代码仓库托管及基本的 Web 管理界面以外,还提供了订阅、讨论组、文本渲染、在线文件编辑器、协作图谱(报表)、代码片段分享(Gist)等功能。正因为这些功能所提供的便利,又经过长期的积累,GitHub 的用户活跃度很高,在开源世界里享有深远的声望,并形成了社交化编程文化(Social Coding)。

    209 引用 • 2031 回帖 • 1 关注
  • Latke

    Latke 是一款以 JSON 为主的 Java Web 框架。

    70 引用 • 533 回帖 • 780 关注
  • Android

    Android 是一种以 Linux 为基础的开放源码操作系统,主要使用于便携设备。2005 年由 Google 收购注资,并拉拢多家制造商组成开放手机联盟开发改良,逐渐扩展到到平板电脑及其他领域上。

    334 引用 • 323 回帖 • 2 关注
  • Firefox

    Mozilla Firefox 中文俗称“火狐”(正式缩写为 Fx 或 fx,非正式缩写为 FF),是一个开源的网页浏览器,使用 Gecko 排版引擎,支持多种操作系统,如 Windows、OSX 及 Linux 等。

    8 引用 • 30 回帖 • 409 关注
  • Maven

    Maven 是基于项目对象模型(POM)、通过一小段描述信息来管理项目的构建、报告和文档的软件项目管理工具。

    186 引用 • 318 回帖 • 306 关注
  • Postman

    Postman 是一款简单好用的 HTTP API 调试工具。

    4 引用 • 3 回帖 • 2 关注
  • LeetCode

    LeetCode(力扣)是一个全球极客挚爱的高质量技术成长平台,想要学习和提升专业能力从这里开始,充足技术干货等你来啃,轻松拿下 Dream Offer!

    209 引用 • 72 回帖
  • 创业

    你比 99% 的人都优秀么?

    84 引用 • 1399 回帖
  • HBase

    HBase 是一个分布式的、面向列的开源数据库,该技术来源于 Fay Chang 所撰写的 Google 论文 “Bigtable:一个结构化数据的分布式存储系统”。就像 Bigtable 利用了 Google 文件系统所提供的分布式数据存储一样,HBase 在 Hadoop 之上提供了类似于 Bigtable 的能力。

    17 引用 • 6 回帖 • 72 关注
  • Ubuntu

    Ubuntu(友帮拓、优般图、乌班图)是一个以桌面应用为主的 Linux 操作系统,其名称来自非洲南部祖鲁语或豪萨语的“ubuntu”一词,意思是“人性”、“我的存在是因为大家的存在”,是非洲传统的一种价值观,类似华人社会的“仁爱”思想。Ubuntu 的目标在于为一般用户提供一个最新的、同时又相当稳定的主要由自由软件构建而成的操作系统。

    124 引用 • 169 回帖
  • 电影

    这是一个不能说的秘密。

    120 引用 • 599 回帖
  • Typecho

    Typecho 是一款博客程序,它在 GPLv2 许可证下发行,基于 PHP 构建,可以运行在各种平台上,支持多种数据库(MySQL、PostgreSQL、SQLite)。

    12 引用 • 65 回帖 • 446 关注
  • Solo

    Solo 是一款小而美的开源博客系统,专为程序员设计。Solo 有着非常活跃的社区,可将文章作为帖子推送到社区,来自社区的回帖将作为博客评论进行联动(具体细节请浏览 B3log 构思 - 分布式社区网络)。

    这是一种全新的网络社区体验,让热爱记录和分享的你不再感到孤单!

    1434 引用 • 10054 回帖 • 491 关注
  • 思源笔记

    思源笔记是一款隐私优先的个人知识管理系统,支持完全离线使用,同时也支持端到端加密同步。

    融合块、大纲和双向链接,重构你的思维。

    22188 引用 • 88668 回帖 • 7 关注
  • Sublime

    Sublime Text 是一款可以用来写代码、写文章的文本编辑器。支持代码高亮、自动完成,还支持通过插件进行扩展。

    10 引用 • 5 回帖
  • 阿里巴巴

    阿里巴巴网络技术有限公司(简称:阿里巴巴集团)是以曾担任英语教师的马云为首的 18 人,于 1999 年在中国杭州创立,他们相信互联网能够创造公平的竞争环境,让小企业通过创新与科技扩展业务,并在参与国内或全球市场竞争时处于更有利的位置。

    43 引用 • 221 回帖 • 120 关注
  • Netty

    Netty 是一个基于 NIO 的客户端-服务器编程框架,使用 Netty 可以让你快速、简单地开发出一个可维护、高性能的网络应用,例如实现了某种协议的客户、服务端应用。

    49 引用 • 33 回帖 • 24 关注
  • Vditor

    Vditor 是一款浏览器端的 Markdown 编辑器,支持所见即所得、即时渲染(类似 Typora)和分屏预览模式。它使用 TypeScript 实现,支持原生 JavaScript、Vue、React 和 Angular。

    349 引用 • 1803 回帖 • 1 关注
  • 架构

    我们平时所说的“架构”主要是指软件架构,这是有关软件整体结构与组件的抽象描述,用于指导软件系统各个方面的设计。另外还有“业务架构”、“网络架构”、“硬件架构”等细分领域。

    142 引用 • 442 回帖 • 1 关注
  • 运维

    互联网运维工作,以服务为中心,以稳定、安全、高效为三个基本点,确保公司的互联网业务能够 7×24 小时为用户提供高质量的服务。

    149 引用 • 257 回帖
  • Caddy

    Caddy 是一款默认自动启用 HTTPS 的 HTTP/2 Web 服务器。

    12 引用 • 54 回帖 • 164 关注
  • danl
    130 关注
  • Notion

    Notion - The all-in-one workspace for your notes, tasks, wikis, and databases.

    6 引用 • 38 回帖
  • B3log

    B3log 是一个开源组织,名字来源于“Bulletin Board Blog”缩写,目标是将独立博客与论坛结合,形成一种新的网络社区体验,详细请看 B3log 构思。目前 B3log 已经开源了多款产品:SymSoloVditor思源笔记

    1063 引用 • 3453 回帖 • 202 关注
  • Tomcat

    Tomcat 最早是由 Sun Microsystems 开发的一个 Servlet 容器,在 1999 年被捐献给 ASF(Apache Software Foundation),隶属于 Jakarta 项目,现在已经独立为一个顶级项目。Tomcat 主要实现了 JavaEE 中的 Servlet、JSP 规范,同时也提供 HTTP 服务,是市场上非常流行的 Java Web 容器。

    162 引用 • 529 回帖