数据集的格式
Simulations/Linear_CA/data/decimated_dt1e-2_T100_r0_randnInit.pt
保存这段数据的代码:
torch.save([train_input, train_target, cv_input, cv_target, test_input, test_target,train_init, cv_init, test_init], fileName)
- train_input:train_batch_size,self.H.size()[0],T_MAX——记录观测值
- train_target:train_batch_size,self.F.size()[0],T_MAX——记录状态
- cv_input:cv_batch_size,self.H.size()[0],T_MAX
- cv_target:cv_batch_size,self.F.size()[0],T_MAX
- test_input:cv_batch_size,self.H.size()[0],T_MAX
- test_target:cv_batch_size,self.F.size()[0],T_MAX
- train_init:train_batch_size,self.F.size()[0],1——初始状态
- cv_init:cv_batch_size,self.F.size()[0],1
- test_init:cv_batch_size,self.F.size()[0],1
数据的再转换
构造真实状态对应观测的数据集,然后一些计算在 Kalmannet 里进行。
def step_KGain_est(self, y): # both in size [batch_size, n] obs_diff = torch.squeeze(y,2) - torch.squeeze(self.y_previous,2) obs_innov_diff = torch.squeeze(y,2) - torch.squeeze(self.m1y,2) # both in size [batch_size, m] fw_evol_diff = torch.squeeze(self.m1x_posterior,2) - torch.squeeze(self.m1x_posterior_previous,2) fw_update_diff = torch.squeeze(self.m1x_posterior,2) - torch.squeeze(self.m1x_prior_previous,2) obs_diff = func.normalize(obs_diff, p=2, dim=1, eps=1e-12, out=None) obs_innov_diff = func.normalize(obs_innov_diff, p=2, dim=1, eps=1e-12, out=None) fw_evol_diff = func.normalize(fw_evol_diff, p=2, dim=1, eps=1e-12, out=None) fw_update_diff = func.normalize(fw_update_diff, p=2, dim=1, eps=1e-12, out=None) # Kalman Gain Network Step KG = self.KGain_step(obs_diff, obs_innov_diff, fw_evol_diff, fw_update_diff) # Reshape Kalman Gain to a Matrix self.KGain = torch.reshape(KG, (self.batch_size, self.m, self.n))
欢迎来到这里!
我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。
注册 关于