网上现有的最大熵模型,如:https://blog.csdn.net/nwpuwyk/article/details/37500371
该代码在训练环节性能较差,特征函数存储的数据结构设计也较简单。
我在该版本基础上进行了改进,优化了特征函数的数据结构和训练代码。
/**
 * Training samples: each is one observed (label, features) event.
 */
List<Instance> instanceList = new ArrayList<Instance>();
/**
 * All labels (classes) observed in the training data.
 */
List<String> labels = new ArrayList<String>();
/**
 * Weight and empirical count of every feature function, keyed by "label:feature".
 */
Map<String,Weight> weightMap=new HashMap<>();
/**
 * Multiplicative-update step size; decayed each iteration but floored at 10 in train().
 */
double learningRate=10;
/**
 * GIS constant C: the maximum number of features any single event carries.
 */
int C;
// NOTE(review): element type of the inner List is not visible here — confirm before tightening
Map<String,List> testInstance;
* 训练模型 * @param maxIt 最大迭代次数
*/public void train(int maxIt,String savePath) throws IOException {
Map,Double> empiricalE = new HashMap<>(); // 经验期望
Map,Double> modelE = new HashMap<>(); // 模型期望
for (Map.Entry,Weight> e:weightMap.entrySet())
{
double ratio=(double) e.getValue().getCnt() / instanceList.size();
empiricalE.put(e.getKey(),ratio);
}
Map,Double> lastWeight=new HashMap<>();
for (int i = 0; i < maxIt; ++i)
{
System.out.println("iter:"+i);
computeModeE(modelE);//计算模型期望
System.out.println("model finish.updating...");
for (Map.Entry,Weight> e:weightMap.entrySet())
{
//lastWeight[w] = weight[w];
lastWeight.put(e.getKey(),e.getValue().getWeight());
String f=e.getKey();
double delta=learningRate / C * Math.log(empiricalE.get(f)/ modelE.get(f));
weightMap.get(f).addWeight(delta);
}
System.out.println("saving iter:"+i);
learningRate*=0.99;
learningRate=learningRate<10?10:learningRate;
saveParam(savePath+"ent_insopt.par"+i);
if (checkConvergence(lastWeight, weightMap)) break;
}
}
/**
* 预测类别 * @param fieldList
* @return
*/
public Pair, Double>[] predict(Map,Integer> fieldList)
{
double[] prob = calProb(fieldList);
Pair, Double>[] pairResult = new Pair[prob.length];
for (int i = 0; i < prob.length; ++i)
{
pairResult[i] = new Pair, Double>(labels.get(i), prob[i]);
}
return pairResult;
}
/**
* 检查是否收敛 * @param w1
* @param w2
* @return 是否收敛
*/public boolean checkConvergence(Map,Double> w1, Map,Weight> w2)
{
System.out.println("w1 size:"+w1.size());
boolean flag=true;
for (Map.Entry,Double> e1:w1.entrySet())
{
//System.out.println("thread:"+Math.abs(e1.getValue() - w2.get(e1.getKey())) );
if (Math.abs(e1.getValue() - w2.get(e1.getKey()).getWeight()) >= 1e-4) // 收敛阀值0.01可自行调整
flag=false;
}
return flag;
}
/**
* 计算模型期望,即在当前的特征函数的权重下,计算特征函数的模型期望值。 * @param modelE 储存空间,应当事先分配好内存(之所以不return一个modelE是为了避免重复分配内存)
*/public void computeModeE(Map,Double> modelE)
{
modelE.clear();
double rate=1.0 / instanceList.size();
for (int i = 0; i < instanceList.size(); ++i)
{
Map,Integer> fieldMap = instanceList.get(i).fieldList;//no labels
//计算当前样本X对应所有类别的概率 double[] pro = calProb(fieldMap);
for (Map.Entry,Integer> e:fieldMap.entrySet())
{
String insFeature=e.getKey();/**
 * Trains the model using GIS-style multiplicative updates.
 * @param maxIt maximum number of iterations
 * @param savePath path prefix where the per-iteration parameters are saved
 */
public void train(int maxIt,String savePath) throws IOException {
Map<String,Double> empiricalE = new HashMap<>(); // empirical expectation of each feature
Map<String,Double> modelE = new HashMap<>(); // model expectation of each feature
for (Map.Entry<String,Weight> e:weightMap.entrySet())
{
double ratio=(double) e.getValue().getCnt() / instanceList.size();
empiricalE.put(e.getKey(),ratio);
}
Map<String,Double> lastWeight=new HashMap<>();
for (int i = 0; i < maxIt; ++i)
{
System.out.println("iter:"+i);
computeModeE(modelE);// recompute model expectations under the current weights
System.out.println("model finish.updating...");
for (Map.Entry<String,Weight> e:weightMap.entrySet())
{
//lastWeight[w] = weight[w];
lastWeight.put(e.getKey(),e.getValue().getWeight());
String f=e.getKey();
double delta=learningRate / C * Math.log(empiricalE.get(f)/ modelE.get(f));
weightMap.get(f).addWeight(delta);
}
System.out.println("saving iter:"+i);
// NOTE(review): learningRate starts at 10 and is floored at 10 below, so the 0.99 decay is currently a no-op — confirm intent
learningRate*=0.99;
learningRate=learningRate<10?10:learningRate;
saveParam(savePath+"ent_insopt.par"+i);
if (checkConvergence(lastWeight, weightMap)) break;
}
}
/**
 * Predicts the class distribution for an instance.
 *
 * @param fieldList feature -> count map describing the instance
 * @return an array of (label, probability) pairs, one per known label
 */
public Pair<String, Double>[] predict(Map<String,Integer> fieldList)
{
double[] distribution = calProb(fieldList);
Pair<String, Double>[] result = new Pair[distribution.length];
int idx = 0;
for (double p : distribution)
{
result[idx] = new Pair<String, Double>(labels.get(idx), p);
++idx;
}
return result;
}
/**
 * Checks whether training has converged, i.e. no feature weight moved by
 * 1e-4 or more during the last update.
 *
 * @param w1 snapshot of the weights before the update
 * @param w2 current weight map after the update
 * @return true when every weight changed by less than the threshold
 */
public boolean checkConvergence(Map<String,Double> w1, Map<String,Weight> w2)
{
System.out.println("w1 size:"+w1.size());
for (Map.Entry<String,Double> e1:w1.entrySet())
{
// convergence threshold 1e-4 (tune as needed); return early on the first diverging weight
if (Math.abs(e1.getValue() - w2.get(e1.getKey()).getWeight()) >= 1e-4)
return false;
}
return true;
}
/**
 * Computes the model expectation of every feature function under the
 * current weights.
 *
 * @param modelE pre-allocated output map; cleared and refilled here (reused
 *               across iterations to avoid reallocating)
 */
public void computeModeE(Map<String,Double> modelE)
{
modelE.clear();
double invSize = 1.0 / instanceList.size();
for (int idx = 0; idx < instanceList.size(); ++idx)
{
// features of this sample (labels excluded)
Map<String,Integer> fields = instanceList.get(idx).fieldList;
// probability of every label for this sample
double[] proba = calProb(fields);
for (Map.Entry<String,Integer> entry : fields.entrySet())
{
String rawFeature = entry.getKey();
int count = entry.getValue();
for (int k = 0; k < labels.size(); k++)
{
String key = labels.get(k) + ":" + rawFeature;
if (weightMap.containsKey(key)) {
double delta = proba[k] * invSize * count;
Double prev = modelE.get(key);
modelE.put(key, prev == null ? delta : prev + delta);
}
}
}
}
}
// public class Mode implements Runnable
// {
// ConcurrentLinkedQueue<Integer> insQueue=new ConcurrentLinkedQueue<>();
// boolean flag=true;
// List<Instance> i∂
// public void addIns(int i)
// {
//
// }
//
// @Override
// public void run() {
// while(flag)
// {
// int ins=insQueue.poll();
// }
// }
// }
/**
 * Computes p(y|x) where x is the instance's feature map (the fields of an
 * instance).
 *
 * @param fieldList feature -> count map of the instance
 * @return the probability of the instance belonging to each label, indexed
 *         consistently with {@code labels}
 */
public double[] calProb(Map<String,Integer> fieldList)
{
double[] p = new double[labels.size()];
double sum = 0; // normalization factor so the probabilities sum to 1
for (int i = 0; i < labels.size(); ++i)
{
double weightSum = 0;
String label=labels.get(i);
// iterate entries once instead of keySet + get (avoids a second lookup per feature)
for (Map.Entry<String,Integer> field : fieldList.entrySet())
{
String feature=label+":"+field.getKey();
if (weightMap.containsKey(feature)) {
weightSum += weightMap.get(feature).getWeight()*field.getValue();
}
}
// cap the exponent to keep Math.exp from overflowing
if(weightSum>15)
{
weightSum=15;
}
p[i] = Math.exp(weightSum);
sum += p[i];
}
for (int i = 0; i < p.length; ++i)
{
p[i] /= sum;
}
return p;
}
/**
 * One observed training instance: an event (label) plus the context
 * (features with counts) in which it occurred.
 */
class Instance implements Serializable
{
// explicit version id: the class is Serializable and instances may be persisted
private static final long serialVersionUID = 1L;
/**
 * The event (class label), e.g. "Outdoor".
 */
String label;
/**
 * The context the event occurred in, as feature -> count, e.g. {Sunny=1, Happy=1}.
 */
Map<String,Integer> fieldList;
public Instance(String label, Map<String,Integer>fieldList)
{
this.label = label;
this.fieldList = fieldList;
}
}
/**
 * The weight and empirical count of a single feature function.
 */
class Weight
{
// current weight of the feature function
double weight=0;
// number of times the feature was observed in the training data
int cnt=0;
/** Adds {@code w} to the current weight. */
public void addWeight(double w)
{
this.weight += w;
}
/** Adds {@code c} to the empirical count. */
public void addCnt(int c)
{
this.cnt += c;
}
public double getWeight() {
return this.weight;
}
public void setWeight(double weight) {
this.weight = weight;
}
public int getCnt() {
return this.cnt;
}
public void setCnt(int cnt) {
this.cnt = cnt;
}
}
int cnt=e.getValue();
for (int k = 0; k < labels.size(); k++)
{
String feature=labels.get(k)+":"+insFeature;
if (weightMap.containsKey(feature)) {
double delta=pro[k] * rate*cnt;
modelE.put(feature, modelE.containsKey(feature) ? modelE.get(feature) + delta : delta);
}
}
}
}
}
欢迎来到这里!
我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。
注册 关于