网上现有的最大熵模型,如:https://blog.csdn.net/nwpuwyk/article/details/37500371
该代码在训练环节性能较差,特征函数存储的数据结构设计也较为简单。
我在该版本基础上进行了改进,优化了特征函数的数据结构和训练代码。
/** * 样本数据集 */ List<Instance> instanceList = new ArrayList<Instance>(); /** * 特征列表,来自所有事件的统计结果 */
// Map<String,Feature> featureMap=new HashMap<>();
/**
* 每个特征的出现次数
*/
//Map<String,Integer> featureCountMap=new HashMap<>();
/**
* 事件(类别)集
*/
List labels = new ArrayList();
/**
* 每个特征函数的权重
*/
// double[] weight;
Map<String,Weight> weightMap=new HashMap<>();
/**
* 一个事件最多一共有多少种特征
*/
double learningRate=10;
int C;
Map<String,List> testInstance;
/**
* 样本数据集 */ List instanceList = new ArrayList();
/**
* 特征列表,来自所有事件的统计结果 */// Map featureMap=new HashMap<>();
/**
* 每个特征的出现次数 */ //Map featureCountMap=new HashMap<>();
/**
* 事件(类别)集 */ List labels = new ArrayList();
/**
* 每个特征函数的权重 */ // double[] weight;
Map<String,Weight> weightMap=new HashMap<>();
/**
* 一个事件最多一共有多少种特征 */ double learningRate=10;
int C;
* 训练模型 * @param maxIt 最大迭代次数 */public void train(int maxIt,String savePath) throws IOException { Map,Double> empiricalE = new HashMap<>(); // 经验期望 Map,Double> modelE = new HashMap<>(); // 模型期望 for (Map.Entry,Weight> e:weightMap.entrySet()) { double ratio=(double) e.getValue().getCnt() / instanceList.size(); empiricalE.put(e.getKey(),ratio); } Map,Double> lastWeight=new HashMap<>(); for (int i = 0; i < maxIt; ++i) { System.out.println("iter:"+i); computeModeE(modelE);//计算模型期望 System.out.println("model finish.updating..."); for (Map.Entry,Weight> e:weightMap.entrySet()) { //lastWeight[w] = weight[w]; lastWeight.put(e.getKey(),e.getValue().getWeight()); String f=e.getKey(); double delta=learningRate / C * Math.log(empiricalE.get(f)/ modelE.get(f)); weightMap.get(f).addWeight(delta); } System.out.println("saving iter:"+i); learningRate*=0.99; learningRate=learningRate<10?10:learningRate; saveParam(savePath+"ent_insopt.par"+i); if (checkConvergence(lastWeight, weightMap)) break; } } /** * 预测类别 * @param fieldList * @return */ public Pair, Double>[] predict(Map,Integer> fieldList) { double[] prob = calProb(fieldList); Pair, Double>[] pairResult = new Pair[prob.length]; for (int i = 0; i < prob.length; ++i) { pairResult[i] = new Pair, Double>(labels.get(i), prob[i]); } return pairResult; } /** * 检查是否收敛 * @param w1 * @param w2 * @return 是否收敛 */public boolean checkConvergence(Map,Double> w1, Map,Weight> w2) { System.out.println("w1 size:"+w1.size()); boolean flag=true; for (Map.Entry,Double> e1:w1.entrySet()) { //System.out.println("thread:"+Math.abs(e1.getValue() - w2.get(e1.getKey())) ); if (Math.abs(e1.getValue() - w2.get(e1.getKey()).getWeight()) >= 1e-4) // 收敛阀值0.01可自行调整 flag=false; } return flag; } /** * 计算模型期望,即在当前的特征函数的权重下,计算特征函数的模型期望值。 * @param modelE 储存空间,应当事先分配好内存(之所以不return一个modelE是为了避免重复分配内存) */public void computeModeE(Map,Double> modelE) { modelE.clear(); double rate=1.0 / instanceList.size(); for (int i = 0; i < instanceList.size(); ++i) { Map,Integer> 
fieldMap = instanceList.get(i).fieldList;//no labels //计算当前样本X对应所有类别的概率 double[] pro = calProb(fieldMap); for (Map.Entry,Integer> e:fieldMap.entrySet()) { String insFeature=e.getKey();/** * 训练模型 * @param maxIt 最大迭代次数 */ public void train(int maxIt,String savePath) throws IOException { Map<String,Double> empiricalE = new HashMap<>(); // 经验期望 Map<String,Double> modelE = new HashMap<>(); // 模型期望 for (Map.Entry<String,Weight> e:weightMap.entrySet()) { double ratio=(double) e.getValue().getCnt() / instanceList.size(); empiricalE.put(e.getKey(),ratio); } Map<String,Double> lastWeight=new HashMap<>(); for (int i = 0; i < maxIt; ++i) { System.out.println("iter:"+i); computeModeE(modelE);//计算模型期望 System.out.println("model finish.updating..."); for (Map.Entry<String,Weight> e:weightMap.entrySet()) { //lastWeight[w] = weight[w]; lastWeight.put(e.getKey(),e.getValue().getWeight()); String f=e.getKey(); double delta=learningRate / C * Math.log(empiricalE.get(f)/ modelE.get(f)); weightMap.get(f).addWeight(delta); } System.out.println("saving iter:"+i); learningRate*=0.99; learningRate=learningRate<10?10:learningRate; saveParam(savePath+"ent_insopt.par"+i); if (checkConvergence(lastWeight, weightMap)) break; } } /** * 预测类别 * @param fieldList * @return */ public Pair<String, Double>[] predict(Map<String,Integer> fieldList) { double[] prob = calProb(fieldList); Pair<String, Double>[] pairResult = new Pair[prob.length]; for (int i = 0; i < prob.length; ++i) { pairResult[i] = new Pair<String, Double>(labels.get(i), prob[i]); } return pairResult; } /** * 检查是否收敛 * @param w1 * @param w2 * @return 是否收敛 */ public boolean checkConvergence(Map<String,Double> w1, Map<String,Weight> w2) { System.out.println("w1 size:"+w1.size()); boolean flag=true; for (Map.Entry<String,Double> e1:w1.entrySet()) { //System.out.println("thread:"+Math.abs(e1.getValue() - w2.get(e1.getKey())) ); if (Math.abs(e1.getValue() - w2.get(e1.getKey()).getWeight()) >= 1e-4) // 收敛阀值0.01可自行调整 flag=false; } return flag; } 
/** * 计算模型期望,即在当前的特征函数的权重下,计算特征函数的模型期望值。 * @param modelE 储存空间,应当事先分配好内存(之所以不return一个modelE是为了避免重复分配内存) */ public void computeModeE(Map<String,Double> modelE) { modelE.clear(); double rate=1.0 / instanceList.size(); for (int i = 0; i < instanceList.size(); ++i) { Map<String,Integer> fieldMap = instanceList.get(i).fieldList;//no labels //计算当前样本X对应所有类别的概率 double[] pro = calProb(fieldMap); for (Map.Entry<String,Integer> e:fieldMap.entrySet()) { String insFeature=e.getKey(); int cnt=e.getValue(); for (int k = 0; k < labels.size(); k++) { String feature=labels.get(k)+":"+insFeature; if (weightMap.containsKey(feature)) { double delta=pro[k] * rate*cnt; modelE.put(feature, modelE.containsKey(feature) ? modelE.get(feature) + delta : delta); } } } } } // public class Mode implements Runnable // { // ConcurrentLinkedQueue<Integer> insQueue=new ConcurrentLinkedQueue<>(); // boolean flag=true; // List<Instance> i∂ // public void addIns(int i) // { // // } // // @Override // public void run() { // while(flag) // { // int ins=insQueue.poll(); // } // } // } /** * 计算p(y|x),此时的x指的是instance里的field * @param fieldList 实例的特征列表 * @return 该实例属于每个类别的概率 */ public double[] calProb(Map<String,Integer> fieldList) { double[] p = new double[labels.size()]; double sum = 0; // 正则化因子,保证概率和为1 for (int i = 0; i < labels.size(); ++i) { double weightSum = 0; String label=labels.get(i); for (String field : fieldList.keySet()) { String feature=label+":"+field; if (weightMap.containsKey(feature)) { weightSum += weightMap.get(feature).getWeight()*fieldList.get(field); } } if(weightSum>15) { weightSum=15; } p[i] = Math.exp(weightSum); sum += p[i]; } //System.out.println(); for (int i = 0; i < p.length; ++i) { p[i] /= sum; // if(Double.isNaN(p[i])) // { // System.out.println(p[i]); // } } return p; } /** * 一个观测实例,包含事件和时间发生的环境 */ class Instance implements Serializable { /** * 事件(类别),如Outdoor */ String label; /** * 事件发生的环境集合,如[Sunny, Happy] */ Map<String,Integer> fieldList = new HashMap<>(); public 
Instance(String label, Map<String,Integer>fieldList) { this.label = label; this.fieldList = fieldList; } } /** * 特征(二值函数) */ class Weight { double weight=0; int cnt=0; public void addWeight(double w) { weight+=(w); } public double getWeight() { return weight; } public void addCnt(int c) { cnt+=c; } public void setWeight(double weight) { this.weight = weight; } public int getCnt() { return cnt; } public void setCnt(int cnt) { this.cnt = cnt; } } int cnt=e.getValue(); for (int k = 0; k < labels.size(); k++) { String feature=labels.get(k)+":"+insFeature; if (weightMap.containsKey(feature)) { double delta=pro[k] * rate*cnt; modelE.put(feature, modelE.containsKey(feature) ? modelE.get(feature) + delta : delta); } } } } }
欢迎来到这里!
我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。
注册 关于