优化版 JAVA 最大熵模型(GIS 训练)

本贴最后更新于 2147 天前,其中的信息可能已经渤澥桑田

网上现有的最大熵模型,如:https://blog.csdn.net/nwpuwyk/article/details/37500371
该代码在训练环节性能较差,特征函数存储的结构也涉及较简单。
我在该版本基础上进行了改进,优化了特征函数的数据结构和训练代码。

     /**
   * 样本数据集
   */
  List<Instance> instanceList = new ArrayList<Instance>();
  /**
   * 特征列表,来自所有事件的统计结果
   */

// Map<String,Feature> featureMap=new HashMap<>();
/**
* 每个特征的出现次数
*/
//Map<String,Integer> featureCountMap=new HashMap<>();
/**
* 事件(类别)集
*/
List labels = new ArrayList();
/**
* 每个特征函数的权重
*/
// double[] weight;
Map<String,Weight> weightMap=new HashMap<>();
/**
* 一个事件最多一共有多少种特征
*/
double learningRate=10;
int C;
Map<String,List> testInstance;
/**
* 样本数据集 */ List instanceList = new ArrayList();
/**
* 特征列表,来自所有事件的统计结果 */// Map featureMap=new HashMap<>();
/**
* 每个特征的出现次数 */ //Map featureCountMap=new HashMap<>();
/**
* 事件(类别)集 */ List labels = new ArrayList();
/**
* 每个特征函数的权重 */ // double[] weight;
Map<String,Weight> weightMap=new HashMap<>();
/**
* 一个事件最多一共有多少种特征 */ double learningRate=10;
int C;

   * 训练模型 * @param maxIt 最大迭代次数
   */public void train(int maxIt,String savePath) throws IOException {
      Map,Double> empiricalE = new HashMap<>(); // 经验期望
    Map,Double> modelE = new HashMap<>(); // 模型期望

    for (Map.Entry,Weight> e:weightMap.entrySet())
      {
          double ratio=(double) e.getValue().getCnt() / instanceList.size();
    empiricalE.put(e.getKey(),ratio);
    }
      Map,Double> lastWeight=new HashMap<>();
   for (int i = 0; i < maxIt; ++i)
      {
          System.out.println("iter:"+i);
    computeModeE(modelE);//计算模型期望
    System.out.println("model finish.updating...");
   for (Map.Entry,Weight> e:weightMap.entrySet())
          {
             //lastWeight[w] = weight[w];
    lastWeight.put(e.getKey(),e.getValue().getWeight());
    String f=e.getKey();
   double delta=learningRate / C * Math.log(empiricalE.get(f)/ modelE.get(f));
    weightMap.get(f).addWeight(delta);
    }
          System.out.println("saving iter:"+i);
    learningRate*=0.99;
    learningRate=learningRate<10?10:learningRate;
    saveParam(savePath+"ent_insopt.par"+i);
   if (checkConvergence(lastWeight, weightMap)) break;

    }

  }

  /**
   * 预测类别 * @param fieldList
    * @return
    */
  public Pair, Double>[] predict(Map,Integer> fieldList)
  {
      double[] prob = calProb(fieldList);
    Pair, Double>[] pairResult = new Pair[prob.length];
   for (int i = 0; i < prob.length; ++i)
      {
          pairResult[i] = new Pair, Double>(labels.get(i), prob[i]);
    }

      return pairResult;
  }

  /**
   * 检查是否收敛 * @param w1
    * @param w2
    * @return 是否收敛
   */public boolean checkConvergence(Map,Double> w1, Map,Weight> w2)
  {
      System.out.println("w1 size:"+w1.size());
   boolean flag=true;
   for (Map.Entry,Double> e1:w1.entrySet())
      {
          //System.out.println("thread:"+Math.abs(e1.getValue() - w2.get(e1.getKey())) );
    if (Math.abs(e1.getValue() - w2.get(e1.getKey()).getWeight()) >= 1e-4)    // 收敛阀值0.01可自行调整
    flag=false;
    }
      return flag;
  }

  /**
   * 计算模型期望,即在当前的特征函数的权重下,计算特征函数的模型期望值。 * @param modelE 储存空间,应当事先分配好内存(之所以不return一个modelE是为了避免重复分配内存)
   */public void computeModeE(Map,Double> modelE)
  {
      modelE.clear();
   double rate=1.0 / instanceList.size();
   for (int i = 0; i < instanceList.size(); ++i)
      {

          Map,Integer> fieldMap = instanceList.get(i).fieldList;//no labels
   //计算当前样本X对应所有类别的概率  double[] pro = calProb(fieldMap);
   for (Map.Entry,Integer> e:fieldMap.entrySet())
          {
              String insFeature=e.getKey();/**
     * 训练模型
     * @param maxIt 最大迭代次数
     */
    public void train(int maxIt,String savePath) throws IOException {
        Map<String,Double> empiricalE = new HashMap<>();   // 经验期望
        Map<String,Double> modelE = new HashMap<>();       // 模型期望

        for (Map.Entry<String,Weight> e:weightMap.entrySet())
        {
            double ratio=(double) e.getValue().getCnt() / instanceList.size();
            empiricalE.put(e.getKey(),ratio);
        }
        Map<String,Double> lastWeight=new HashMap<>();
        for (int i = 0; i < maxIt; ++i)
        {
            System.out.println("iter:"+i);
            computeModeE(modelE);//计算模型期望
            System.out.println("model finish.updating...");
            for (Map.Entry<String,Weight> e:weightMap.entrySet())
            {
               //lastWeight[w] = weight[w];
                lastWeight.put(e.getKey(),e.getValue().getWeight());
                String f=e.getKey();
                double delta=learningRate / C * Math.log(empiricalE.get(f)/ modelE.get(f));
                weightMap.get(f).addWeight(delta);
            }
            System.out.println("saving iter:"+i);
            learningRate*=0.99;
            learningRate=learningRate<10?10:learningRate;
            saveParam(savePath+"ent_insopt.par"+i);
            if (checkConvergence(lastWeight, weightMap)) break;

        }

    }

    /**
     * 预测类别
     * @param fieldList
     * @return
     */
    public Pair<String, Double>[] predict(Map<String,Integer> fieldList)
    {
        double[] prob = calProb(fieldList);
        Pair<String, Double>[] pairResult = new Pair[prob.length];
        for (int i = 0; i < prob.length; ++i)
        {
            pairResult[i] = new Pair<String, Double>(labels.get(i), prob[i]);
        }

        return pairResult;
    }

    /**
     * 检查是否收敛
     * @param w1
     * @param w2
     * @return 是否收敛
     */
    public boolean checkConvergence(Map<String,Double> w1, Map<String,Weight> w2)
    {
        System.out.println("w1 size:"+w1.size());
        boolean flag=true;
        for (Map.Entry<String,Double> e1:w1.entrySet())
        {
            //System.out.println("thread:"+Math.abs(e1.getValue() - w2.get(e1.getKey())) );
            if (Math.abs(e1.getValue() - w2.get(e1.getKey()).getWeight()) >= 1e-4)    // 收敛阀值0.01可自行调整
                flag=false;
        }
        return flag;
    }

    /**
     * 计算模型期望,即在当前的特征函数的权重下,计算特征函数的模型期望值。
     * @param modelE 储存空间,应当事先分配好内存(之所以不return一个modelE是为了避免重复分配内存)
     */
    public void computeModeE(Map<String,Double> modelE)
    {
        modelE.clear();
        double rate=1.0 / instanceList.size();
        for (int i = 0; i < instanceList.size(); ++i)
        {

            Map<String,Integer> fieldMap = instanceList.get(i).fieldList;//no labels
             //计算当前样本X对应所有类别的概率
            double[] pro = calProb(fieldMap);
            for (Map.Entry<String,Integer> e:fieldMap.entrySet())
            {
                String insFeature=e.getKey();
                int cnt=e.getValue();
                for (int k = 0; k < labels.size(); k++)
                {
                    String feature=labels.get(k)+":"+insFeature;
                      if (weightMap.containsKey(feature)) {
                        double  delta=pro[k] * rate*cnt;
                        modelE.put(feature, modelE.containsKey(feature) ? modelE.get(feature) + delta : delta);
                    }
                }
            }
        }
    }
//    public class Mode implements Runnable
//    {
//        ConcurrentLinkedQueue<Integer> insQueue=new ConcurrentLinkedQueue<>();
//        boolean flag=true;
//        List<Instance> i∂
//        public void addIns(int i)
//        {
//
//        }
//
//        @Override
//        public void run() {
//            while(flag)
//            {
//                int ins=insQueue.poll();
//            }
//        }
//    }
    /**
     * 计算p(y|x),此时的x指的是instance里的field
     * @param fieldList 实例的特征列表
     * @return 该实例属于每个类别的概率
     */
    public double[] calProb(Map<String,Integer> fieldList)
    {
        double[] p = new double[labels.size()];
        double sum = 0;  // 正则化因子,保证概率和为1
        for (int i = 0; i < labels.size(); ++i)
        {
            double weightSum = 0;
            String label=labels.get(i);
            for (String field : fieldList.keySet())
            {
                String feature=label+":"+field;
                 if (weightMap.containsKey(feature)) {
                    weightSum += weightMap.get(feature).getWeight()*fieldList.get(field);
                }
            }
            if(weightSum>15)
            {
                weightSum=15;
            }
            p[i] = Math.exp(weightSum);

            sum += p[i];
        }
        //System.out.println();
        for (int i = 0; i < p.length; ++i)
        {
            p[i] /= sum;
//            if(Double.isNaN(p[i]))
//            {
//                System.out.println(p[i]);
//            }
        }
        return p;
    }

    /**
     * 一个观测实例,包含事件和时间发生的环境
     */
    class Instance implements Serializable
    {
        /**
         * 事件(类别),如Outdoor
         */
        String label;
        /**
         * 事件发生的环境集合,如[Sunny, Happy]
         */
        Map<String,Integer> fieldList = new HashMap<>();

        public Instance(String label, Map<String,Integer>fieldList)
        {
            this.label = label;
            this.fieldList = fieldList;
        }
    }

    /**
     * 特征(二值函数)
     */
    class Weight
    {
        double weight=0;
        int cnt=0;
        public void addWeight(double w)
        {
            weight+=(w);
        }
        public double getWeight() {
            return weight;
        }
        public void addCnt(int c)
        {
            cnt+=c;
        }
        public void setWeight(double weight) {
            this.weight = weight;
        }

        public int getCnt() {
            return cnt;
        }

        public void setCnt(int cnt) {
            this.cnt = cnt;
        }
    }
   int cnt=e.getValue();
   for (int k = 0; k < labels.size(); k++)
              {
                  String feature=labels.get(k)+":"+insFeature;
   if (weightMap.containsKey(feature)) {
                      double delta=pro[k] * rate*cnt;
    modelE.put(feature, modelE.containsKey(feature) ? modelE.get(feature) + delta : delta);
    }
              }
          }
      }
  }
  • 最大熵模型
    1 引用
  • Java

    Java 是一种可以撰写跨平台应用软件的面向对象的程序设计语言,是由 Sun Microsystems 公司于 1995 年 5 月推出的。Java 技术具有卓越的通用性、高效性、平台移植性和安全性。

    3169 引用 • 8208 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • 招聘

    哪里都缺人,哪里都不缺人。

    189 引用 • 1056 回帖 • 1 关注
  • 正则表达式

    正则表达式(Regular Expression)使用单个字符串来描述、匹配一系列遵循某个句法规则的字符串。

    31 引用 • 94 回帖
  • 笔记

    好记性不如烂笔头。

    306 引用 • 782 回帖
  • Jenkins

    Jenkins 是一套开源的持续集成工具。它提供了非常丰富的插件,让构建、部署、自动化集成项目变得简单易用。

    51 引用 • 37 回帖 • 2 关注
  • 博客

    记录并分享人生的经历。

    272 引用 • 2386 回帖
  • Angular

    AngularAngularJS 的新版本。

    26 引用 • 66 回帖 • 531 关注
  • 链滴

    链滴是一个记录生活的地方。

    记录生活,连接点滴

    141 引用 • 3721 回帖 • 1 关注
  • ReactiveX

    ReactiveX 是一个专注于异步编程与控制可观察数据(或者事件)流的 API。它组合了观察者模式,迭代器模式和函数式编程的优秀思想。

    1 引用 • 2 回帖 • 141 关注
  • Unity

    Unity 是由 Unity Technologies 开发的一个让开发者可以轻松创建诸如 2D、3D 多平台的综合型游戏开发工具,是一个全面整合的专业游戏引擎。

    25 引用 • 7 回帖 • 233 关注
  • Telegram

    Telegram 是一个非盈利性、基于云端的即时消息服务。它提供了支持各大操作系统平台的开源的客户端,也提供了很多强大的 APIs 给开发者创建自己的客户端和机器人。

    5 引用 • 35 回帖 • 1 关注
  • JWT

    JWT(JSON Web Token)是一种用于双方之间传递信息的简洁的、安全的表述性声明规范。JWT 作为一个开放的标准(RFC 7519),定义了一种简洁的,自包含的方法用于通信双方之间以 JSON 的形式安全的传递信息。

    20 引用 • 15 回帖 • 19 关注
  • 外包

    有空闲时间是接外包好呢还是学习好呢?

    26 引用 • 232 回帖
  • HBase

    HBase 是一个分布式的、面向列的开源数据库,该技术来源于 Fay Chang 所撰写的 Google 论文 “Bigtable:一个结构化数据的分布式存储系统”。就像 Bigtable 利用了 Google 文件系统所提供的分布式数据存储一样,HBase 在 Hadoop 之上提供了类似于 Bigtable 的能力。

    17 引用 • 6 回帖 • 58 关注
  • IPFS

    IPFS(InterPlanetary File System,星际文件系统)是永久的、去中心化保存和共享文件的方法,这是一种内容可寻址、版本化、点对点超媒体的分布式协议。请浏览 IPFS 入门笔记了解更多细节。

    20 引用 • 245 回帖 • 234 关注
  • webpack

    webpack 是一个用于前端开发的模块加载器和打包工具,它能把各种资源,例如 JS、CSS(less/sass)、图片等都作为模块来使用和处理。

    41 引用 • 130 回帖 • 288 关注
  • Mobi.css

    Mobi.css is a lightweight, flexible CSS framework that focus on mobile.

    1 引用 • 6 回帖 • 708 关注
  • BND

    BND(Baidu Netdisk Downloader)是一款图形界面的百度网盘不限速下载器,支持 Windows、Linux 和 Mac,详细介绍请看这里

    107 引用 • 1281 回帖 • 31 关注
  • Chrome

    Chrome 又称 Google 浏览器,是一个由谷歌公司开发的网页浏览器。该浏览器是基于其他开源软件所编写,包括 WebKit,目标是提升稳定性、速度和安全性,并创造出简单且有效率的使用者界面。

    60 引用 • 287 回帖
  • OkHttp

    OkHttp 是一款 HTTP & HTTP/2 客户端库,专为 Android 和 Java 应用打造。

    16 引用 • 6 回帖 • 53 关注
  • 酷鸟浏览器

    安全 · 稳定 · 快速
    为跨境从业人员提供专业的跨境浏览器

    3 引用 • 59 回帖 • 16 关注
  • MySQL

    MySQL 是一个关系型数据库管理系统,由瑞典 MySQL AB 公司开发,目前属于 Oracle 公司。MySQL 是最流行的关系型数据库管理系统之一。

    675 引用 • 535 回帖
  • Kotlin

    Kotlin 是一种在 Java 虚拟机上运行的静态类型编程语言,由 JetBrains 设计开发并开源。Kotlin 可以编译成 Java 字节码,也可以编译成 JavaScript,方便在没有 JVM 的设备上运行。在 Google I/O 2017 中,Google 宣布 Kotlin 成为 Android 官方开发语言。

    19 引用 • 33 回帖 • 43 关注
  • Swagger

    Swagger 是一款非常流行的 API 开发工具,它遵循 OpenAPI Specification(这是一种通用的、和编程语言无关的 API 描述规范)。Swagger 贯穿整个 API 生命周期,如 API 的设计、编写文档、测试和部署。

    26 引用 • 35 回帖 • 11 关注
  • API

    应用程序编程接口(Application Programming Interface)是一些预先定义的函数,目的是提供应用程序与开发人员基于某软件或硬件得以访问一组例程的能力,而又无需访问源码,或理解内部工作机制的细节。

    76 引用 • 429 回帖
  • 书籍

    宋真宗赵恒曾经说过:“书中自有黄金屋,书中自有颜如玉。”

    76 引用 • 390 回帖
  • 星云链

    星云链是一个开源公链,业内简单的将其称为区块链上的谷歌。其实它不仅仅是区块链搜索引擎,一个公链的所有功能,它基本都有,比如你可以用它来开发部署你的去中心化的 APP,你可以在上面编写智能合约,发送交易等等。3 分钟快速接入星云链 (NAS) 测试网

    3 引用 • 16 回帖
  • Firefox

    Mozilla Firefox 中文俗称“火狐”(正式缩写为 Fx 或 fx,非正式缩写为 FF),是一个开源的网页浏览器,使用 Gecko 排版引擎,支持多种操作系统,如 Windows、OSX 及 Linux 等。

    7 引用 • 30 回帖 • 446 关注