优化版 JAVA 最大熵模型(GIS 训练)

本贴最后更新于 2429 天前,其中的信息可能已经渤澥桑田

网上现有的最大熵模型,如:https://blog.csdn.net/nwpuwyk/article/details/37500371
该代码在训练环节性能较差,特征函数存储的结构也涉及较简单。
我在该版本基础上进行了改进,优化了特征函数的数据结构和训练代码。

/** * 样本数据集 */ List<Instance> instanceList = new ArrayList<Instance>(); /** * 特征列表,来自所有事件的统计结果 */

// Map<String,Feature> featureMap=new HashMap<>();
/**
* 每个特征的出现次数
*/
//Map<String,Integer> featureCountMap=new HashMap<>();
/**
* 事件(类别)集
*/
List labels = new ArrayList();
/**
* 每个特征函数的权重
*/
// double[] weight;
Map<String,Weight> weightMap=new HashMap<>();
/**
* 一个事件最多一共有多少种特征
*/
double learningRate=10;
int C;
Map<String,List> testInstance;
/**
* 样本数据集 */ List instanceList = new ArrayList();
/**
* 特征列表,来自所有事件的统计结果 */// Map featureMap=new HashMap<>();
/**
* 每个特征的出现次数 */ //Map featureCountMap=new HashMap<>();
/**
* 事件(类别)集 */ List labels = new ArrayList();
/**
* 每个特征函数的权重 */ // double[] weight;
Map<String,Weight> weightMap=new HashMap<>();
/**
* 一个事件最多一共有多少种特征 */ double learningRate=10;
int C;

* 训练模型 * @param maxIt 最大迭代次数 */public void train(int maxIt,String savePath) throws IOException { Map,Double> empiricalE = new HashMap<>(); // 经验期望 Map,Double> modelE = new HashMap<>(); // 模型期望 for (Map.Entry,Weight> e:weightMap.entrySet()) { double ratio=(double) e.getValue().getCnt() / instanceList.size(); empiricalE.put(e.getKey(),ratio); } Map,Double> lastWeight=new HashMap<>(); for (int i = 0; i < maxIt; ++i) { System.out.println("iter:"+i); computeModeE(modelE);//计算模型期望 System.out.println("model finish.updating..."); for (Map.Entry,Weight> e:weightMap.entrySet()) { //lastWeight[w] = weight[w]; lastWeight.put(e.getKey(),e.getValue().getWeight()); String f=e.getKey(); double delta=learningRate / C * Math.log(empiricalE.get(f)/ modelE.get(f)); weightMap.get(f).addWeight(delta); } System.out.println("saving iter:"+i); learningRate*=0.99; learningRate=learningRate<10?10:learningRate; saveParam(savePath+"ent_insopt.par"+i); if (checkConvergence(lastWeight, weightMap)) break; } } /** * 预测类别 * @param fieldList * @return */ public Pair, Double>[] predict(Map,Integer> fieldList) { double[] prob = calProb(fieldList); Pair, Double>[] pairResult = new Pair[prob.length]; for (int i = 0; i < prob.length; ++i) { pairResult[i] = new Pair, Double>(labels.get(i), prob[i]); } return pairResult; } /** * 检查是否收敛 * @param w1 * @param w2 * @return 是否收敛 */public boolean checkConvergence(Map,Double> w1, Map,Weight> w2) { System.out.println("w1 size:"+w1.size()); boolean flag=true; for (Map.Entry,Double> e1:w1.entrySet()) { //System.out.println("thread:"+Math.abs(e1.getValue() - w2.get(e1.getKey())) ); if (Math.abs(e1.getValue() - w2.get(e1.getKey()).getWeight()) >= 1e-4) // 收敛阀值0.01可自行调整 flag=false; } return flag; } /** * 计算模型期望,即在当前的特征函数的权重下,计算特征函数的模型期望值。 * @param modelE 储存空间,应当事先分配好内存(之所以不return一个modelE是为了避免重复分配内存) */public void computeModeE(Map,Double> modelE) { modelE.clear(); double rate=1.0 / instanceList.size(); for (int i = 0; i < instanceList.size(); ++i) { Map,Integer> fieldMap = instanceList.get(i).fieldList;//no labels //计算当前样本X对应所有类别的概率 double[] pro = calProb(fieldMap); for (Map.Entry,Integer> e:fieldMap.entrySet()) { String insFeature=e.getKey();/** * 训练模型 * @param maxIt 最大迭代次数 */ public void train(int maxIt,String savePath) throws IOException { Map<String,Double> empiricalE = new HashMap<>(); // 经验期望 Map<String,Double> modelE = new HashMap<>(); // 模型期望 for (Map.Entry<String,Weight> e:weightMap.entrySet()) { double ratio=(double) e.getValue().getCnt() / instanceList.size(); empiricalE.put(e.getKey(),ratio); } Map<String,Double> lastWeight=new HashMap<>(); for (int i = 0; i < maxIt; ++i) { System.out.println("iter:"+i); computeModeE(modelE);//计算模型期望 System.out.println("model finish.updating..."); for (Map.Entry<String,Weight> e:weightMap.entrySet()) { //lastWeight[w] = weight[w]; lastWeight.put(e.getKey(),e.getValue().getWeight()); String f=e.getKey(); double delta=learningRate / C * Math.log(empiricalE.get(f)/ modelE.get(f)); weightMap.get(f).addWeight(delta); } System.out.println("saving iter:"+i); learningRate*=0.99; learningRate=learningRate<10?10:learningRate; saveParam(savePath+"ent_insopt.par"+i); if (checkConvergence(lastWeight, weightMap)) break; } } /** * 预测类别 * @param fieldList * @return */ public Pair<String, Double>[] predict(Map<String,Integer> fieldList) { double[] prob = calProb(fieldList); Pair<String, Double>[] pairResult = new Pair[prob.length]; for (int i = 0; i < prob.length; ++i) { pairResult[i] = new Pair<String, Double>(labels.get(i), prob[i]); } return pairResult; } /** * 检查是否收敛 * @param w1 * @param w2 * @return 是否收敛 */ public boolean checkConvergence(Map<String,Double> w1, Map<String,Weight> w2) { System.out.println("w1 size:"+w1.size()); boolean flag=true; for (Map.Entry<String,Double> e1:w1.entrySet()) { //System.out.println("thread:"+Math.abs(e1.getValue() - w2.get(e1.getKey())) ); if (Math.abs(e1.getValue() - w2.get(e1.getKey()).getWeight()) >= 1e-4) // 收敛阀值0.01可自行调整 flag=false; } return flag; } /** * 计算模型期望,即在当前的特征函数的权重下,计算特征函数的模型期望值。 * @param modelE 储存空间,应当事先分配好内存(之所以不return一个modelE是为了避免重复分配内存) */ public void computeModeE(Map<String,Double> modelE) { modelE.clear(); double rate=1.0 / instanceList.size(); for (int i = 0; i < instanceList.size(); ++i) { Map<String,Integer> fieldMap = instanceList.get(i).fieldList;//no labels //计算当前样本X对应所有类别的概率 double[] pro = calProb(fieldMap); for (Map.Entry<String,Integer> e:fieldMap.entrySet()) { String insFeature=e.getKey(); int cnt=e.getValue(); for (int k = 0; k < labels.size(); k++) { String feature=labels.get(k)+":"+insFeature; if (weightMap.containsKey(feature)) { double delta=pro[k] * rate*cnt; modelE.put(feature, modelE.containsKey(feature) ? modelE.get(feature) + delta : delta); } } } } } // public class Mode implements Runnable // { // ConcurrentLinkedQueue<Integer> insQueue=new ConcurrentLinkedQueue<>(); // boolean flag=true; // List<Instance> i∂ // public void addIns(int i) // { // // } // // @Override // public void run() { // while(flag) // { // int ins=insQueue.poll(); // } // } // } /** * 计算p(y|x),此时的x指的是instance里的field * @param fieldList 实例的特征列表 * @return 该实例属于每个类别的概率 */ public double[] calProb(Map<String,Integer> fieldList) { double[] p = new double[labels.size()]; double sum = 0; // 正则化因子,保证概率和为1 for (int i = 0; i < labels.size(); ++i) { double weightSum = 0; String label=labels.get(i); for (String field : fieldList.keySet()) { String feature=label+":"+field; if (weightMap.containsKey(feature)) { weightSum += weightMap.get(feature).getWeight()*fieldList.get(field); } } if(weightSum>15) { weightSum=15; } p[i] = Math.exp(weightSum); sum += p[i]; } //System.out.println(); for (int i = 0; i < p.length; ++i) { p[i] /= sum; // if(Double.isNaN(p[i])) // { // System.out.println(p[i]); // } } return p; } /** * 一个观测实例,包含事件和时间发生的环境 */ class Instance implements Serializable { /** * 事件(类别),如Outdoor */ String label; /** * 事件发生的环境集合,如[Sunny, Happy] */ Map<String,Integer> fieldList = new HashMap<>(); public Instance(String label, Map<String,Integer>fieldList) { this.label = label; this.fieldList = fieldList; } } /** * 特征(二值函数) */ class Weight { double weight=0; int cnt=0; public void addWeight(double w) { weight+=(w); } public double getWeight() { return weight; } public void addCnt(int c) { cnt+=c; } public void setWeight(double weight) { this.weight = weight; } public int getCnt() { return cnt; } public void setCnt(int cnt) { this.cnt = cnt; } } int cnt=e.getValue(); for (int k = 0; k < labels.size(); k++) { String feature=labels.get(k)+":"+insFeature; if (weightMap.containsKey(feature)) { double delta=pro[k] * rate*cnt; modelE.put(feature, modelE.containsKey(feature) ? modelE.get(feature) + delta : delta); } } } } }
  • 最大熵模型
    1 引用
  • Java

    Java 是一种可以撰写跨平台应用软件的面向对象的程序设计语言,是由 Sun Microsystems 公司于 1995 年 5 月推出的。Java 技术具有卓越的通用性、高效性、平台移植性和安全性。

    3196 引用 • 8215 回帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...

推荐标签 标签

  • jsoup

    jsoup 是一款 Java 的 HTML 解析器,可直接解析某个 URL 地址、HTML 文本内容。它提供了一套非常省力的 API,可通过 DOM,CSS 以及类似于 jQuery 的操作方法来取出和操作数据。

    6 引用 • 1 回帖 • 487 关注
  • OneDrive
    2 引用 • 1 关注
  • SEO

    发布对别人有帮助的原创内容是最好的 SEO 方式。

    35 引用 • 200 回帖 • 24 关注
  • AWS
    11 引用 • 28 回帖 • 11 关注
  • 深度学习

    深度学习(Deep Learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法。

    53 引用 • 40 回帖
  • TGIF

    Thank God It's Friday! 感谢老天,总算到星期五啦!

    289 引用 • 4492 回帖 • 652 关注
  • 链书

    链书(Chainbook)是 B3log 开源社区提供的区块链纸质书交易平台,通过 B3T 实现共享激励与价值链。可将你的闲置书籍上架到链书,我们共同构建这个全新的交易平台,让闲置书籍继续发挥它的价值。

    链书社

    链书目前已经下线,也许以后还有计划重制上线。

    14 引用 • 257 回帖
  • 心情

    心是产生任何想法的源泉,心本体会陷入到对自己本体不能理解的状态中,因为心能产生任何想法,不能分出对错,不能分出自己。

    59 引用 • 369 回帖
  • 浅吟主题

    Jeffrey Chen 制作的思源笔记主题,项目仓库:https://github.com/TCOTC/Whisper

    1 引用 • 28 回帖 • 1 关注
  • Mac

    Mac 是苹果公司自 1984 年起以“Macintosh”开始开发的个人消费型计算机,如:iMac、Mac mini、Macbook Air、Macbook Pro、Macbook、Mac Pro 等计算机。

    169 引用 • 595 回帖 • 2 关注
  • GraphQL

    GraphQL 是一个用于 API 的查询语言,是一个使用基于类型系统来执行查询的服务端运行时(类型系统由你的数据定义)。GraphQL 并没有和任何特定数据库或者存储引擎绑定,而是依靠你现有的代码和数据支撑。

    4 引用 • 3 回帖 • 4 关注
  • 房星科技

    房星网,我们不和没有钱的程序员谈理想,我们要让程序员又有理想又有钱。我们有雄厚的房地产行业线下资源,遍布昆明全城的 100 家门店、四千地产经纪人是我们坚实的后盾。

    6 引用 • 141 回帖 • 592 关注
  • Bug

    Bug 本意是指臭虫、缺陷、损坏、犯贫、窃听器、小虫等。现在人们把在程序中一些缺陷或问题统称为 bug(漏洞)。

    76 引用 • 1742 回帖
  • 博客

    记录并分享人生的经历。

    273 引用 • 2388 回帖
  • IBM

    IBM(国际商业机器公司)或万国商业机器公司,简称 IBM(International Business Machines Corporation),总公司在纽约州阿蒙克市。1911 年托马斯·沃森创立于美国,是全球最大的信息技术和业务解决方案公司,拥有全球雇员 30 多万人,业务遍及 160 多个国家和地区。

    17 引用 • 53 回帖 • 146 关注
  • OpenShift

    红帽提供的 PaaS 云,支持多种编程语言,为开发人员提供了更为灵活的框架、存储选择。

    14 引用 • 20 回帖 • 655 关注
  • 国际化

    i18n(其来源是英文单词 internationalization 的首末字符 i 和 n,18 为中间的字符数)是“国际化”的简称。对程序来说,国际化是指在不修改代码的情况下,能根据不同语言及地区显示相应的界面。

    8 引用 • 26 回帖 • 2 关注
  • 快应用

    快应用 是基于手机硬件平台的新型应用形态;标准是由主流手机厂商组成的快应用联盟联合制定;快应用标准的诞生将在研发接口、能力接入、开发者服务等层面建设标准平台;以平台化的生态模式对个人开发者和企业开发者全品类开放。

    15 引用 • 127 回帖
  • Angular

    AngularAngularJS 的新版本。

    26 引用 • 66 回帖 • 543 关注
  • JetBrains

    JetBrains 是一家捷克的软件开发公司,该公司位于捷克的布拉格,并在俄国的圣彼得堡及美国麻州波士顿都设有办公室,该公司最为人所熟知的产品是 Java 编程语言开发撰写时所用的集成开发环境:IntelliJ IDEA

    18 引用 • 54 回帖
  • V2EX

    V2EX 是创意工作者们的社区。这里目前汇聚了超过 400,000 名主要来自互联网行业、游戏行业和媒体行业的创意工作者。V2EX 希望能够成为创意工作者们的生活和事业的一部分。

    16 引用 • 236 回帖 • 267 关注
  • Notion

    Notion - The all-in-one workspace for your notes, tasks, wikis, and databases.

    10 引用 • 76 回帖
  • 互联网

    互联网(Internet),又称网际网络,或音译因特网、英特网。互联网始于 1969 年美国的阿帕网,是网络与网络之间所串连成的庞大网络,这些网络以一组通用的协议相连,形成逻辑上的单一巨大国际网络。

    99 引用 • 367 回帖
  • 强迫症

    强迫症(OCD)属于焦虑障碍的一种类型,是一组以强迫思维和强迫行为为主要临床表现的神经精神疾病,其特点为有意识的强迫和反强迫并存,一些毫无意义、甚至违背自己意愿的想法或冲动反反复复侵入患者的日常生活。

    15 引用 • 161 回帖
  • DNSPod

    DNSPod 建立于 2006 年 3 月份,是一款免费智能 DNS 产品。 DNSPod 可以为同时有电信、网通、教育网服务器的网站提供智能的解析,让电信用户访问电信的服务器,网通的用户访问网通的服务器,教育网的用户访问教育网的服务器,达到互联互通的效果。

    6 引用 • 26 回帖 • 533 关注
  • B3log

    B3log 是一个开源组织,名字来源于“Bulletin Board Blog”缩写,目标是将独立博客与论坛结合,形成一种新的网络社区体验,详细请看 B3log 构思。目前 B3log 已经开源了多款产品:SymSoloVditor思源笔记

    1063 引用 • 3455 回帖 • 165 关注
  • 导航

    各种网址链接、内容导航。

    43 引用 • 177 回帖