优化版 JAVA 最大熵模型(GIS 训练)

本贴最后更新于 2084 天前,其中的信息可能已经渤澥桑田

网上现有的最大熵模型,如:https://blog.csdn.net/nwpuwyk/article/details/37500371
该代码在训练环节性能较差,特征函数存储的结构也涉及较简单。
我在该版本基础上进行了改进,优化了特征函数的数据结构和训练代码。

     /**
   * 样本数据集
   */
  List<Instance> instanceList = new ArrayList<Instance>();
  /**
   * 特征列表,来自所有事件的统计结果
   */

// Map<String,Feature> featureMap=new HashMap<>();
/**
* 每个特征的出现次数
*/
//Map<String,Integer> featureCountMap=new HashMap<>();
/**
* 事件(类别)集
*/
List labels = new ArrayList();
/**
* 每个特征函数的权重
*/
// double[] weight;
Map<String,Weight> weightMap=new HashMap<>();
/**
* 一个事件最多一共有多少种特征
*/
double learningRate=10;
int C;
Map<String,List> testInstance;
/**
* 样本数据集 */ List instanceList = new ArrayList();
/**
* 特征列表,来自所有事件的统计结果 */// Map featureMap=new HashMap<>();
/**
* 每个特征的出现次数 */ //Map featureCountMap=new HashMap<>();
/**
* 事件(类别)集 */ List labels = new ArrayList();
/**
* 每个特征函数的权重 */ // double[] weight;
Map<String,Weight> weightMap=new HashMap<>();
/**
* 一个事件最多一共有多少种特征 */ double learningRate=10;
int C;

   * 训练模型 * @param maxIt 最大迭代次数
   */public void train(int maxIt,String savePath) throws IOException {
      Map,Double> empiricalE = new HashMap<>(); // 经验期望
    Map,Double> modelE = new HashMap<>(); // 模型期望

    for (Map.Entry,Weight> e:weightMap.entrySet())
      {
          double ratio=(double) e.getValue().getCnt() / instanceList.size();
    empiricalE.put(e.getKey(),ratio);
    }
      Map,Double> lastWeight=new HashMap<>();
   for (int i = 0; i < maxIt; ++i)
      {
          System.out.println("iter:"+i);
    computeModeE(modelE);//计算模型期望
    System.out.println("model finish.updating...");
   for (Map.Entry,Weight> e:weightMap.entrySet())
          {
             //lastWeight[w] = weight[w];
    lastWeight.put(e.getKey(),e.getValue().getWeight());
    String f=e.getKey();
   double delta=learningRate / C * Math.log(empiricalE.get(f)/ modelE.get(f));
    weightMap.get(f).addWeight(delta);
    }
          System.out.println("saving iter:"+i);
    learningRate*=0.99;
    learningRate=learningRate<10?10:learningRate;
    saveParam(savePath+"ent_insopt.par"+i);
   if (checkConvergence(lastWeight, weightMap)) break;

    }

  }

  /**
   * 预测类别 * @param fieldList
    * @return
    */
  public Pair, Double>[] predict(Map,Integer> fieldList)
  {
      double[] prob = calProb(fieldList);
    Pair, Double>[] pairResult = new Pair[prob.length];
   for (int i = 0; i < prob.length; ++i)
      {
          pairResult[i] = new Pair, Double>(labels.get(i), prob[i]);
    }

      return pairResult;
  }

  /**
   * 检查是否收敛 * @param w1
    * @param w2
    * @return 是否收敛
   */public boolean checkConvergence(Map,Double> w1, Map,Weight> w2)
  {
      System.out.println("w1 size:"+w1.size());
   boolean flag=true;
   for (Map.Entry,Double> e1:w1.entrySet())
      {
          //System.out.println("thread:"+Math.abs(e1.getValue() - w2.get(e1.getKey())) );
    if (Math.abs(e1.getValue() - w2.get(e1.getKey()).getWeight()) >= 1e-4)    // 收敛阀值0.01可自行调整
    flag=false;
    }
      return flag;
  }

  /**
   * 计算模型期望,即在当前的特征函数的权重下,计算特征函数的模型期望值。 * @param modelE 储存空间,应当事先分配好内存(之所以不return一个modelE是为了避免重复分配内存)
   */public void computeModeE(Map,Double> modelE)
  {
      modelE.clear();
   double rate=1.0 / instanceList.size();
   for (int i = 0; i < instanceList.size(); ++i)
      {

          Map,Integer> fieldMap = instanceList.get(i).fieldList;//no labels
   //计算当前样本X对应所有类别的概率  double[] pro = calProb(fieldMap);
   for (Map.Entry,Integer> e:fieldMap.entrySet())
          {
              String insFeature=e.getKey();/**
     * 训练模型
     * @param maxIt 最大迭代次数
     */
    public void train(int maxIt,String savePath) throws IOException {
        Map<String,Double> empiricalE = new HashMap<>();   // 经验期望
        Map<String,Double> modelE = new HashMap<>();       // 模型期望

        for (Map.Entry<String,Weight> e:weightMap.entrySet())
        {
            double ratio=(double) e.getValue().getCnt() / instanceList.size();
            empiricalE.put(e.getKey(),ratio);
        }
        Map<String,Double> lastWeight=new HashMap<>();
        for (int i = 0; i < maxIt; ++i)
        {
            System.out.println("iter:"+i);
            computeModeE(modelE);//计算模型期望
            System.out.println("model finish.updating...");
            for (Map.Entry<String,Weight> e:weightMap.entrySet())
            {
               //lastWeight[w] = weight[w];
                lastWeight.put(e.getKey(),e.getValue().getWeight());
                String f=e.getKey();
                double delta=learningRate / C * Math.log(empiricalE.get(f)/ modelE.get(f));
                weightMap.get(f).addWeight(delta);
            }
            System.out.println("saving iter:"+i);
            learningRate*=0.99;
            learningRate=learningRate<10?10:learningRate;
            saveParam(savePath+"ent_insopt.par"+i);
            if (checkConvergence(lastWeight, weightMap)) break;

        }

    }

    /**
     * 预测类别
     * @param fieldList
     * @return
     */
    public Pair<String, Double>[] predict(Map<String,Integer> fieldList)
    {
        double[] prob = calProb(fieldList);
        Pair<String, Double>[] pairResult = new Pair[prob.length];
        for (int i = 0; i < prob.length; ++i)
        {
            pairResult[i] = new Pair<String, Double>(labels.get(i), prob[i]);
        }

        return pairResult;
    }

    /**
     * 检查是否收敛
     * @param w1
     * @param w2
     * @return 是否收敛
     */
    public boolean checkConvergence(Map<String,Double> w1, Map<String,Weight> w2)
    {
        System.out.println("w1 size:"+w1.size());
        boolean flag=true;
        for (Map.Entry<String,Double> e1:w1.entrySet())
        {
            //System.out.println("thread:"+Math.abs(e1.getValue() - w2.get(e1.getKey())) );
            if (Math.abs(e1.getValue() - w2.get(e1.getKey()).getWeight()) >= 1e-4)    // 收敛阀值0.01可自行调整
                flag=false;
        }
        return flag;
    }

    /**
     * 计算模型期望,即在当前的特征函数的权重下,计算特征函数的模型期望值。
     * @param modelE 储存空间,应当事先分配好内存(之所以不return一个modelE是为了避免重复分配内存)
     */
    public void computeModeE(Map<String,Double> modelE)
    {
        modelE.clear();
        double rate=1.0 / instanceList.size();
        for (int i = 0; i < instanceList.size(); ++i)
        {

            Map<String,Integer> fieldMap = instanceList.get(i).fieldList;//no labels
             //计算当前样本X对应所有类别的概率
            double[] pro = calProb(fieldMap);
            for (Map.Entry<String,Integer> e:fieldMap.entrySet())
            {
                String insFeature=e.getKey();
                int cnt=e.getValue();
                for (int k = 0; k < labels.size(); k++)
                {
                    String feature=labels.get(k)+":"+insFeature;
                      if (weightMap.containsKey(feature)) {
                        double  delta=pro[k] * rate*cnt;
                        modelE.put(feature, modelE.containsKey(feature) ? modelE.get(feature) + delta : delta);
                    }
                }
            }
        }
    }
//    public class Mode implements Runnable
//    {
//        ConcurrentLinkedQueue<Integer> insQueue=new ConcurrentLinkedQueue<>();
//        boolean flag=true;
//        List<Instance> i∂
//        public void addIns(int i)
//        {
//
//        }
//
//        @Override
//        public void run() {
//            while(flag)
//            {
//                int ins=insQueue.poll();
//            }
//        }
//    }
    /**
     * 计算p(y|x),此时的x指的是instance里的field
     * @param fieldList 实例的特征列表
     * @return 该实例属于每个类别的概率
     */
    public double[] calProb(Map<String,Integer> fieldList)
    {
        double[] p = new double[labels.size()];
        double sum = 0;  // 正则化因子,保证概率和为1
        for (int i = 0; i < labels.size(); ++i)
        {
            double weightSum = 0;
            String label=labels.get(i);
            for (String field : fieldList.keySet())
            {
                String feature=label+":"+field;
                 if (weightMap.containsKey(feature)) {
                    weightSum += weightMap.get(feature).getWeight()*fieldList.get(field);
                }
            }
            if(weightSum>15)
            {
                weightSum=15;
            }
            p[i] = Math.exp(weightSum);

            sum += p[i];
        }
        //System.out.println();
        for (int i = 0; i < p.length; ++i)
        {
            p[i] /= sum;
//            if(Double.isNaN(p[i]))
//            {
//                System.out.println(p[i]);
//            }
        }
        return p;
    }

    /**
     * 一个观测实例,包含事件和时间发生的环境
     */
    class Instance implements Serializable
    {
        /**
         * 事件(类别),如Outdoor
         */
        String label;
        /**
         * 事件发生的环境集合,如[Sunny, Happy]
         */
        Map<String,Integer> fieldList = new HashMap<>();

        public Instance(String label, Map<String,Integer>fieldList)
        {
            this.label = label;
            this.fieldList = fieldList;
        }
    }

    /**
     * 特征(二值函数)
     */
    class Weight
    {
        double weight=0;
        int cnt=0;
        public void addWeight(double w)
        {
            weight+=(w);
        }
        public double getWeight() {
            return weight;
        }
        public void addCnt(int c)
        {
            cnt+=c;
        }
        public void setWeight(double weight) {
            this.weight = weight;
        }

        public int getCnt() {
            return cnt;
        }

        public void setCnt(int cnt) {
            this.cnt = cnt;
        }
    }
   int cnt=e.getValue();
   for (int k = 0; k < labels.size(); k++)
              {
                  String feature=labels.get(k)+":"+insFeature;
   if (weightMap.containsKey(feature)) {
                      double delta=pro[k] * rate*cnt;
    modelE.put(feature, modelE.containsKey(feature) ? modelE.get(feature) + delta : delta);
    }
              }
          }
      }
  }
  • 最大熵模型
    1 引用
  • Java

    Java 是一种可以撰写跨平台应用软件的面向对象的程序设计语言,是由 Sun Microsystems 公司于 1995 年 5 月推出的。Java 技术具有卓越的通用性、高效性、平台移植性和安全性。

    3169 引用 • 8207 回帖 • 1 关注

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • SSL

    SSL(Secure Sockets Layer 安全套接层),及其继任者传输层安全(Transport Layer Security,TLS)是为网络通信提供安全及数据完整性的一种安全协议。TLS 与 SSL 在传输层对网络连接进行加密。

    69 引用 • 190 回帖 • 497 关注
  • 资讯

    资讯是用户因为及时地获得它并利用它而能够在相对短的时间内给自己带来价值的信息,资讯有时效性和地域性。

    53 引用 • 85 回帖 • 1 关注
  • Postman

    Postman 是一款简单好用的 HTTP API 调试工具。

    4 引用 • 3 回帖
  • Android

    Android 是一种以 Linux 为基础的开放源码操作系统,主要使用于便携设备。2005 年由 Google 收购注资,并拉拢多家制造商组成开放手机联盟开发改良,逐渐扩展到到平板电脑及其他领域上。

    333 引用 • 323 回帖 • 65 关注
  • danl
    64 关注
  • 安全

    安全永远都不是一个小问题。

    189 引用 • 813 回帖
  • QQ

    1999 年 2 月腾讯正式推出“腾讯 QQ”,在线用户由 1999 年的 2 人(马化腾和张志东)到现在已经发展到上亿用户了,在线人数超过一亿,是目前使用最广泛的聊天软件之一。

    45 引用 • 557 回帖 • 218 关注
  • Laravel

    Laravel 是一套简洁、优雅的 PHP Web 开发框架。它采用 MVC 设计,是一款崇尚开发效率的全栈框架。

    19 引用 • 23 回帖 • 689 关注
  • Ngui

    Ngui 是一个 GUI 的排版显示引擎和跨平台的 GUI 应用程序开发框架,基于
    Node.js / OpenGL。目标是在此基础上开发 GUI 应用程序可拥有开发 WEB 应用般简单与速度同时兼顾 Native 应用程序的性能与体验。

    7 引用 • 9 回帖 • 345 关注
  • Chrome

    Chrome 又称 Google 浏览器,是一个由谷歌公司开发的网页浏览器。该浏览器是基于其他开源软件所编写,包括 WebKit,目标是提升稳定性、速度和安全性,并创造出简单且有效率的使用者界面。

    60 引用 • 287 回帖 • 1 关注
  • Lute

    Lute 是一款结构化的 Markdown 引擎,支持 Go 和 JavaScript。

    25 引用 • 191 回帖 • 23 关注
  • 尊园地产

    昆明尊园房地产经纪有限公司,即:Kunming Zunyuan Property Agency Company Limited(简称“尊园地产”)于 2007 年 6 月开始筹备,2007 年 8 月 18 日正式成立,注册资本 200 万元,公司性质为股份经纪有限公司,主营业务为:代租、代售、代办产权过户、办理银行按揭、担保、抵押、评估等。

    1 引用 • 22 回帖 • 686 关注
  • BAE

    百度应用引擎(Baidu App Engine)提供了 PHP、Java、Python 的执行环境,以及云存储、消息服务、云数据库等全面的云服务。它可以让开发者实现自动地部署和管理应用,并且提供动态扩容和负载均衡的运行环境,让开发者不用考虑高成本的运维工作,只需专注于业务逻辑,大大降低了开发者学习和迁移的成本。

    19 引用 • 75 回帖 • 619 关注
  • golang

    Go 语言是 Google 推出的一种全新的编程语言,可以在不损失应用程序性能的情况下降低代码的复杂性。谷歌首席软件工程师罗布派克(Rob Pike)说:我们之所以开发 Go,是因为过去 10 多年间软件开发的难度令人沮丧。Go 是谷歌 2009 发布的第二款编程语言。

    492 引用 • 1383 回帖 • 368 关注
  • Webswing

    Webswing 是一个能将任何 Swing 应用通过纯 HTML5 运行在浏览器中的 Web 服务器,详细介绍请看 将 Java Swing 应用变成 Web 应用

    1 引用 • 15 回帖 • 635 关注
  • JWT

    JWT(JSON Web Token)是一种用于双方之间传递信息的简洁的、安全的表述性声明规范。JWT 作为一个开放的标准(RFC 7519),定义了一种简洁的,自包含的方法用于通信双方之间以 JSON 的形式安全的传递信息。

    20 引用 • 15 回帖 • 18 关注
  • SMTP

    SMTP(Simple Mail Transfer Protocol)即简单邮件传输协议,它是一组用于由源地址到目的地址传送邮件的规则,由它来控制信件的中转方式。SMTP 协议属于 TCP/IP 协议簇,它帮助每台计算机在发送或中转信件时找到下一个目的地。

    4 引用 • 18 回帖 • 592 关注
  • CAP

    CAP 指的是在一个分布式系统中, Consistency(一致性)、 Availability(可用性)、Partition tolerance(分区容错性),三者不可兼得。

    11 引用 • 5 回帖 • 566 关注
  • RYMCU

    RYMCU 致力于打造一个即严谨又活泼、专业又不失有趣,为数百万人服务的开源嵌入式知识学习交流平台。

    4 引用 • 6 回帖 • 39 关注
  • Scala

    Scala 是一门多范式的编程语言,集成面向对象编程和函数式编程的各种特性。

    13 引用 • 11 回帖 • 109 关注
  • Sandbox

    如果帖子标签含有 Sandbox ,则该帖子会被视为“测试帖”,主要用于测试社区功能,排查 bug 等,该标签下内容不定期进行清理。

    370 引用 • 1215 回帖 • 582 关注
  • OAuth

    OAuth 协议为用户资源的授权提供了一个安全的、开放而又简易的标准。与以往的授权方式不同之处是 oAuth 的授权不会使第三方触及到用户的帐号信息(如用户名与密码),即第三方无需使用用户的用户名与密码就可以申请获得该用户资源的授权,因此 oAuth 是安全的。oAuth 是 Open Authorization 的简写。

    36 引用 • 103 回帖 • 8 关注
  • Rust

    Rust 是一门赋予每个人构建可靠且高效软件能力的语言。Rust 由 Mozilla 开发,最早发布于 2014 年 9 月。

    57 引用 • 22 回帖 • 3 关注
  • Wide

    Wide 是一款基于 Web 的 Go 语言 IDE。通过浏览器就可以进行 Go 开发,并有代码自动完成、查看表达式、编译反馈、Lint、实时结果输出等功能。

    欢迎访问我们运维的实例: https://wide.b3log.org

    30 引用 • 218 回帖 • 602 关注
  • 脑图

    脑图又叫思维导图,是表达发散性思维的有效图形思维工具 ,它简单却又很有效,是一种实用性的思维工具。

    21 引用 • 58 回帖
  • Hibernate

    Hibernate 是一个开放源代码的对象关系映射框架,它对 JDBC 进行了非常轻量级的对象封装,使得 Java 程序员可以随心所欲的使用对象编程思维来操纵数据库。

    39 引用 • 103 回帖 • 683 关注
  • Kotlin

    Kotlin 是一种在 Java 虚拟机上运行的静态类型编程语言,由 JetBrains 设计开发并开源。Kotlin 可以编译成 Java 字节码,也可以编译成 JavaScript,方便在没有 JVM 的设备上运行。在 Google I/O 2017 中,Google 宣布 Kotlin 成为 Android 官方开发语言。

    19 引用 • 33 回帖 • 28 关注