"""Uroboros: sequential posterior estimation for heterodyned continuous-wave signals.

A conditional RealNVP flow estimates the signal parameters conditional on a new
measurement segment plus compressed prior information (a small set of samples
from the previous stage). The samples produced after segment n become the prior
for segment n+1. Results are compared against nessai nested sampling.
"""
from glasflow import RealNVP
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Uniform
from itertools import chain, permutations
from datetime import datetime
import pickle
from scipy.interpolate import RectBivariateSpline
from scipy.integrate import quad
from scipy.special import logsumexp
import time
import shutil
# for JS divergence calculation
from scipy.spatial.distance import jensenshannon
from scipy.stats import gaussian_kde
from collections import namedtuple
import seaborn as sns
import os
from scipy.stats import uniform, norm, chi
from scipy.special import expit, log_wright_bessel
from nessai.flowsampler import FlowSampler
from nessai.model import Model
from nessai.utils import setup_logger
from nessai.livepoint import dict_to_live_points
import corner

######################################
# TO-DO LIST (in no particular order)
#
# 1. Widen frequency band - go to longer observation times - towards 1 day per segment
# 2. [DONE] Add declination
# 3. [DONE] Merge amplitude and alpha into 2 Euclidean parameters (wrap around sky)
# 4. Save JS values to file and plot comparisons
# 5. Add attention head to conditional data compression
# 6. Move to 16, 32, ... segments
# 7. [DONE] Switch to properly heterodyned data
# 8. Solve issue of annual frequency variation > band [ADD EXTRA SEGMENTS IN FREQUENCY!!]
# 9. [DONE] Do importance sampling for the final epoch on the reordered samples
# 10. [DONE] Compute Bayesian evidence
# 11. [DONE] Print Bayes factor with 2 decimal places
# 12. [DONE] Do 1/2 of test cases with no signal
# 13. Add corner plots for the final result
# 14. Make scatter plots into contour plots
# 15. Plot JS values as a function of segment
# 16. [DONE] Add Bayes factor using samples from prior
# 17. [DONE] Plot Bayes factors

seed = 3
torch.manual_seed(seed)
np.random.seed(seed)

sns.set_context("notebook")
sns.set_palette("colorblind")

# Use the GPU when present; fall back to the CPU so the script still runs without one.
device = "cuda" if torch.cuda.is_available() else "cpu"

date_id = datetime.today().isoformat()
plot_path = '/data/www.astro/chrism/uroboros/{}'.format(date_id)
ns_path = './'
try:
    os.mkdir('{}'.format(plot_path))
except OSError as err:
    print('unable to make output directory {}'.format(plot_path))
    raise SystemExit(1) from err
# keep a copy of the exact script used for this run next to its plots
shutil.copyfile('./uroboros.py', '{}/uroboros.txt'.format(plot_path))

iterations = 200000
batch_size = 1024
plot_step = 10000
n_test = 8                   # number of individual test data samples
n_prior = 128                # number of samples used to represent the conditional prior
n_max_list = [1, 2, 4, 8, 16]
n_max = n_max_list[-1]       # max number of measurements (segments) per sample
glob_n_max = n_max           # used to normalise the segment index fed to the flow
n_par = 4                    # number of parameters per sample (Ax, Ay, Az, f0)
n_meas = 128                 # size of a measurement (real parts followed by imaginary parts)
n_prior_cc_out = 32          # size of the compressed prior
n_meas_cc_out = 16           # size of the compressed measurement
n_sigma = 1.0                # global noise scale (cw_model sets its own self.n_sigma)
n_post = 3000                # number of posterior samples used for plotting
lr = 1e-3                    # learning rate
run_id = 'cw_het_s{}_nmax{}_ntest{}'.format(seed, n_max, n_test)
Nreorder = 10
multiflow = False            # if True, one flow / prior compressor per segment
multiflow_n_max = 1
if multiflow:
    multiflow_n_max = n_max

# The flow(s) estimating the parameters conditional on new data and compressed prior information.
flow = []
for i in range(multiflow_n_max):
    flow.append(RealNVP(
        n_inputs=n_par,  # number of params
        n_transforms=4,
        # compressed prior + compressed measurement + normalised segment index
        n_conditional_inputs=n_prior_cc_out + n_meas_cc_out + 1,
        n_neurons=32,
        n_blocks_per_transform=2,
        batch_norm_within_blocks=True,
        linear_transform='permutation',
        batch_norm_between_transforms=True,
    ).to(device))
print(f"Created flow and sent to {device}")


class MAB(nn.Module):
    """Multihead attention block (queries Q attend over keys/values K)."""

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        # split the feature dimension into heads and stack the heads along the batch axis
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        # NOTE(review): scores are scaled by sqrt(dim_V) (full width), not sqrt(dim_split);
        # kept unchanged so that trained behaviour is preserved.
        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / np.sqrt(self.dim_V), 2)
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O


class SAB(nn.Module):
    """Set attention block: self-attention of a set over itself."""

    def __init__(self, dim_in, dim_out, num_heads, ln=False):
        super(SAB, self).__init__()
        self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(X, X)


class ISAB(nn.Module):
    """Induced set attention block using `num_inds` learned inducing points."""

    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super(ISAB, self).__init__()
        self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X):
        H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
        return self.mab1(X, H)


class PMA(nn.Module):
    """Pooling by multihead attention onto `num_seeds` learned seed vectors."""

    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(self.S.repeat(X.size(0), 1, 1), X)


class SmallSetTransformer(nn.Module):
    """Permutation-invariant compressor of a set of prior samples.

    Input: (batch, set_size, n_par). Output: (batch, n_prior_cc_out) in (0, 1).
    """

    def __init__(self,):
        super().__init__()
        self.enc = nn.Sequential(
            SAB(dim_in=n_par, dim_out=32, num_heads=4),
            SAB(dim_in=32, dim_out=32, num_heads=4),
        )
        self.dec = nn.Sequential(
            PMA(dim=32, num_heads=4, num_seeds=1),
            nn.Linear(in_features=32, out_features=n_prior_cc_out),
            nn.Sigmoid(),  # I added this
        )

    def forward(self, x):
        x = self.enc(x)
        x = self.dec(x)
        return x.squeeze(-2)  # drop the single PMA seed axis


class PI_NeuralNetwork(nn.Module):
    """Permutation-invariant (mean-pooled) compressor of a set of prior samples."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.nn1 = nn.Sequential(
            nn.Linear(n_par, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, 64),
        )
        self.nn2 = nn.Sequential(
            nn.Linear(n_par + 64, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, n_prior_cc_out),
        )

    def forward(self, x0):
        # x0 has shape (batch, n_prior, n_par) (nn1 takes n_par features per sample)
        x = x0.flatten(0, 1)              # (batch*n_prior, n_par)
        x = self.nn1(x)                   # (batch*n_prior, 64)
        x = x.reshape(-1, n_prior, 64)    # (batch, n_prior, 64)
        x = torch.concat([x, x0], dim=2)  # (batch, n_prior, 64+n_par)
        x = torch.mean(x, dim=1)          # (batch, 64+n_par)
        x = self.nn2(x)                   # (batch, n_prior_cc_out)
        return x


cc_prior_model = []
for i in range(multiflow_n_max):
    cc_prior_model.append(SmallSetTransformer().to(device))

# the compression model that takes measurement samples and compresses them
cc_meas_model = nn.Sequential(
    nn.Linear(n_meas, 32),
    nn.ReLU(),
    nn.Linear(32, 32),
    nn.ReLU(),
    nn.Linear(32, n_meas_cc_out),
    nn.Sigmoid(),
)
cc_meas_model.to(device)

# Summary of a distribution of JS values: median and asymmetric error bars.
_JSSummary = namedtuple("summary", ["median", "plus", "minus"])


def calc_median_error(jsvalues, quantiles=(0.16, 0.84)):
    """Summarise `jsvalues` by its median and quantile-based error bars.

    Returns a namedtuple with fields ``median``, ``plus`` (upper quantile minus
    median) and ``minus`` (median minus lower quantile).
    """
    quants_to_compute = np.array([quantiles[0], 0.5, quantiles[1]])
    lower, median, upper = np.percentile(jsvalues, quants_to_compute * 100)
    return _JSSummary(median=median, plus=upper - median, minus=median - lower)


def calculate_js(samplesA, samplesB, ntests=10, xsteps=100):
    """Estimate the Jensen-Shannon divergence between two 1-D sample sets.

    Each of `ntests` repeats subsamples both sets to a common size, builds
    Gaussian KDEs on a shared grid of `xsteps` points, and takes the squared
    scipy JS distance. Returns the `calc_median_error` summary of the repeats.
    """
    js_array = np.zeros(ntests)
    nsamples = min(len(samplesA), len(samplesB))
    for j in range(ntests):
        A = np.random.choice(samplesA, size=nsamples, replace=False)
        B = np.random.choice(samplesB, size=nsamples, replace=False)
        xmin = np.min([np.min(A), np.min(B)])
        xmax = np.max([np.max(A), np.max(B)])
        x = np.linspace(xmin, xmax, xsteps)
        A_pdf = gaussian_kde(A)(x)
        B_pdf = gaussian_kde(B)(x)
        js_array[j] = np.nan_to_num(np.power(jensenshannon(A_pdf, B_pdf), 2))
    return calc_median_error(js_array)


class cw_model(Model):
    """Heterodyned continuous-wave signal model with a phase-marginalised likelihood.

    Parameters are Ax, Ay, Az (standard-normal Cartesian amplitude components
    encoding amplitude and sky position) and f0 (normalised frequency in [0, 1]).
    Data are split into segments of length T; each segment holds N real samples
    followed by N imaginary samples (n_meas = 2N).
    """

    def __init__(self, n_meas, n_max, d=None, i_ref=0):
        self.n_meas = n_meas                  # size of the measurement (real + imag)
        self.N = int(self.n_meas / 2)         # heterodyned, downsampled complex timeseries length
        self.dvec = None
        if d is not None:
            self.dvec = torch.from_numpy(d)   # the measured data, one entry per segment
        self.T = 1800.0                       # segment length (sec)
        self.i_ref = i_ref                    # starting segment index
        self.fs = self.N / self.T             # sampling rate of the heterodyned data (Hz)
        self.df = 1.0 / self.T                # frequency resolution (Hz)
        self.band = self.df * self.N          # bandwidth - same as the sampling frequency (Hz)
        self.dt = 1.0 / self.fs               # sampling time (sec)
        self.t = torch.arange(self.N) * self.dt
        self.fmax = 0.5 / self.dt             # Nyquist frequency (Hz)
        self.fhet = 100.0                     # heterodyne frequency (Hz)
        self.n_max = n_max                    # max number of segments
        self.n_sigma = 1.0                    # noise standard deviation
        self.SNR_factor = 0.25                # scaling that controls the average SNR
        self.n_par = 4                        # number of final parameters
        self.asini = 6e6 / 3e8                # Earth spin semi-major axis (light-sec)
        self.Omega = 2.0 * np.pi / 86400.0    # Earth spin angular frequency (rad/sec)
        self.names = ["Ax", "Ay", "Az", "f0"]
        self.eps = 1e-6                       # avoids divide-by-zero when there is no signal
        self.plot_bounds = [[-3, 3], [-3, 3], [-3, 3], [0, 1]]
        self.bounds = {"Ax": [-np.inf, np.inf], "Ay": [-np.inf, np.inf], "Az": [-np.inf, np.inf], "f0": [0, 1]}
        self.repar_bounds = [[0, 3], [0, 1], [0, 1], [-0.5, 0.5]]  # bounds for [A, f0, alpha, delta]
        if d is not None:
            self.logZ_noise = self.log_likelihood_null()
            print(self.logZ_noise)

    def sig(self, Ax, Ay, Az, f0, phase, t, i):
        """Return (real part, imaginary part, phase phi) of the heterodyned signal.

        Ax, Ay, Az - Cartesian amplitude components (amplitude and sky position)
        f0 - normalised frequency in [0, 1], mapped to fhet +/- 0.1*band
        phase - signal phase (cycles)
        t - sample times within the segment (sec)
        i - segment index
        """
        A = torch.sqrt(Ax**2 + Ay**2 + Az**2)  # amplitude (chi-distributed, 3 dof, under the prior)
        # RA normalised to [0, 1) (units of 2*pi rad)
        alpha = torch.remainder(torch.atan2(Ay, Ax), 2 * np.pi) / (2.0 * np.pi)
        # declination normalised to (-0.5, 0.5) (units of pi rad)
        delta = torch.atan(Az / torch.sqrt(Ax**2 + Ay**2 + self.eps)) / np.pi
        f = self.fhet + (f0 - 0.5) * 0.2 * self.band  # source frequency (Hz)
        t = t + i * self.T                            # absolute time for the chosen segment

        # heterodyned phase including the Earth-spin Doppler modulation
        phi = 2.0 * np.pi * ((f - self.fhet) * t
                             + f * self.asini * torch.cos(np.pi * delta) * torch.sin(2.0 * np.pi * alpha + self.Omega * t)
                             + phase)
        return self.SNR_factor * A * torch.sin(phi), self.SNR_factor * A * torch.cos(phi), phi

    def gen_pars(self, n_data, noise_only_frac=0.0):
        """Draw `n_data` parameter vectors [Ax, Ay, Az, f0, phase], shape (n_data, 5).

        A fraction `noise_only_frac` (on average) have Ax = Ay = Az = 0 (no signal).
        """
        # 1 where a signal is present, 0 for noise-only draws
        has_signal = (torch.rand(size=(n_data, 1)) > noise_only_frac).long().to(device)
        Ax = has_signal * (torch.randn(size=(n_data, 1))).to(device)
        Ay = has_signal * (torch.randn(size=(n_data, 1))).to(device)
        Az = has_signal * (torch.randn(size=(n_data, 1))).to(device)
        f0 = (torch.rand(size=(n_data, 1))).to(device)
        phase = (torch.rand(size=(n_data, 1))).to(device)
        Omega = (torch.concatenate((Ax, Ay, Az, f0, phase), axis=1)).to(device)
        return Omega

    def gen_sig(self, n_data, n_max, noise_only_frac=0.0):
        """Simulate `n_data` noisy measurements of `n_max` segments each.

        Returns (meas, Omega, n): meas is float32 on `device` with shape
        (n_data, n_meas, n_max) - real parts then imaginary parts plus noise;
        Omega are the true parameters; n is the (CPU) noise that was added.
        """
        Omega = self.gen_pars(n_data, noise_only_frac=noise_only_frac)
        Ax, Ay, Az, f0, phase = (Omega[:, k] for k in range(5))

        # complex noise (real and imaginary parts stacked along axis 1)
        n = self.n_sigma * torch.randn(size=(n_data, self.n_meas, n_max))

        def _expand(v):
            # broadcast a per-signal value over (n_data, N, n_max), flattened on the CPU
            return torch.tile(v.reshape(n_data, 1, 1), (1, self.N, n_max)).flatten().cpu()

        t = torch.tile(self.dt * torch.arange(self.N).reshape(1, self.N, 1), (n_data, 1, n_max)).flatten().cpu()
        seg = torch.tile(torch.arange(n_max).reshape(1, 1, n_max), (n_data, self.N, 1)).flatten().cpu()
        meas_real, meas_imag, _ = self.sig(_expand(Ax), _expand(Ay), _expand(Az), _expand(f0), _expand(phase), t, seg)
        meas_real = meas_real.reshape(n_data, self.N, n_max)
        meas_imag = meas_imag.reshape(n_data, self.N, n_max)

        # measured data: real part followed by the imaginary part, plus noise
        meas = torch.concatenate([meas_real, meas_imag], axis=1) + n
        meas = meas.to(device=device, dtype=torch.float32)
        return meas, Omega, n

    def new_point(self, N=1):
        """Draw `N` live points from the prior (used for nessai's initial sampling)."""
        d = {
            "Ax": norm().rvs(size=N),
            "Ay": norm().rvs(size=N),
            "Az": norm().rvs(size=N),
            "f0": uniform.rvs(loc=self.bounds["f0"][0], scale=self.bounds["f0"][1] - self.bounds["f0"][0], size=N),
        }
        return dict_to_live_points(d)

    def new_point_log_prob(self, x):
        """Log-probability of points drawn by `new_point` (the prior itself)."""
        return self.log_prior(x)

    def log_prior(self, x):
        """Log prior: standard normal on Ax, Ay, Az and uniform on f0 within its bounds."""
        # 0 inside the bounds, -inf outside
        log_p = np.log(self.in_bounds(x), dtype="float")
        log_p += norm().logpdf(x["Ax"])
        log_p += norm().logpdf(x["Ay"])
        log_p += norm().logpdf(x["Az"])
        log_p -= np.log(self.bounds["f0"][1] - self.bounds["f0"][0])
        return log_p

    def log_likelihood(self, x):
        """Phase-marginalised Gaussian log likelihood summed over all data segments.

        The unknown phase is marginalised analytically; the resulting Bessel term
        is evaluated as log_wright_bessel(1, 1, W) = log I0(2*sqrt(W)).
        """
        log_l = 0.0
        Ax = torch.tensor(np.array(x["Ax"]).reshape(-1))
        Ay = torch.tensor(np.array(x["Ay"]).reshape(-1))
        Az = torch.tensor(np.array(x["Az"]).reshape(-1))
        f0 = torch.tensor(np.array(x["f0"]).reshape(-1))
        A = torch.sqrt(Ax**2 + Ay**2 + Az**2)
        n_pts = A.shape[0]

        # segment-independent inputs to `sig`, one row of N samples per live point
        Ax_rep = Ax.repeat_interleave(self.N)
        Ay_rep = Ay.repeat_interleave(self.N)
        Az_rep = Az.repeat_interleave(self.N)
        f0_rep = f0.repeat_interleave(self.N)
        zero_phase = torch.zeros(n_pts * self.N)
        t_seg = (self.dt * torch.arange(self.N)).repeat(n_pts)

        for i, d in enumerate(self.dvec):  # loop over segments
            # only the phase is needed for the marginalised likelihood
            _, _, phi = self.sig(Ax_rep, Ay_rep, Az_rep, f0_rep, zero_phase, t_seg, i * torch.ones(n_pts * self.N))
            phi = phi.reshape(n_pts, self.N)

            Xreal = d[:self.N]
            Ximag = d[self.N:]
            sumXsq = torch.sum(Xreal**2 + Ximag**2, dim=-1)
            sin_phi = torch.sin(phi)
            cos_phi = torch.cos(phi)
            W_arg = 0.25 * ((A * self.SNR_factor / self.n_sigma**2)**2) * (
                torch.sum(Xreal * sin_phi + Ximag * cos_phi, dim=-1)**2
                + torch.sum(Xreal * cos_phi - Ximag * sin_phi, dim=-1)**2)
            W_arg = W_arg.cpu().numpy()

            log_l += (-2.0 * self.N * np.log(self.n_sigma) - self.N * np.log(2.0 * np.pi)
                      - 0.5 * sumXsq / self.n_sigma**2
                      - 0.5 * self.N * (A * self.SNR_factor / self.n_sigma)**2
                      + torch.from_numpy(log_wright_bessel(1, 1, W_arg)))
        return log_l.cpu().numpy()

    def log_likelihood_null(self):
        """Log likelihood of the data under the noise-only model."""
        log_l = 0.0
        for d in self.dvec:  # loop over segments
            log_l += np.sum(norm.logpdf(0, loc=d, scale=self.n_sigma))
        return log_l

    def change_vars(self, samples):
        """Convert samples in [Ax, Ay, Az, f0] to [A, f0, alpha, delta] (CPU tensor).

        alpha is normalised to [0, 1) and delta to [-0.5, 0.5], matching `sig`.
        """
        N = samples.shape[0]
        new_samples = torch.zeros((N, self.n_par))
        r_xy = torch.sqrt(samples[:, 0]**2 + samples[:, 1]**2)
        new_samples[:, 0] = torch.sqrt(samples[:, 0]**2 + samples[:, 1]**2 + samples[:, 2]**2)
        new_samples[:, 1] = samples[:, 3]
        new_samples[:, 2] = torch.remainder(torch.atan2(samples[:, 1], samples[:, 0]), 2 * np.pi) / (2.0 * np.pi)
        # atan2 equals atan(Az/r_xy) for r_xy > 0 but stays finite when Ax = Ay = 0
        new_samples[:, 3] = torch.atan2(samples[:, 2], r_xy) / np.pi
        return new_samples

    def scatter_plot(self, ax, samples, ns_samples, reord_samples, x, n_max_list, n_prior, n_post, n_par,
                     ns_logZ, prior_logZ, logZ, change_vars=False, a=None, b=None):
        """Scatter-plot parameter pair (a, b): flow vs nessai samples per segment count.

        One panel per entry of `n_max_list`, plus a final panel for the reordered
        samples annotated with log Bayes factors. Returns the (scaled by 1e3) JS
        divergences, shape (len(n_max_list)+1, 2).
        """
        # Backward compatibility: the plotted indices used to be read from module globals.
        if a is None:
            a = globals()["a"]
        if b is None:
            b = globals()["b"]

        n_seg = n_max_list[-1]
        js = np.zeros((len(n_max_list) + 1, 2))
        samples = torch.permute(samples, (0, 3, 1, 2))[0, :, :, :].reshape(n_seg, n_post, n_par)
        reord_samples = torch.permute(reord_samples, (0, 3, 1, 2))[0, :, :, :].reshape(n_seg, -1, n_par)
        if change_vars:
            samples = self.change_vars(samples[:, :, :n_par].flatten(0, 1)).reshape(n_seg, n_post, n_par).cpu().numpy()
            reord_samples = self.change_vars(reord_samples[:, :, :n_par].flatten(0, 1)).reshape(n_seg, -1, n_par).cpu().numpy()
            x = self.change_vars(x[:n_par].reshape(1, n_par)).flatten().cpu().numpy()
            ns_samples = self.change_vars(torch.from_numpy(ns_samples).flatten(0, 2)).reshape(len(n_max_list), -1, 2, n_par).cpu().numpy()
            bounds = self.repar_bounds
        else:
            samples = samples.cpu().numpy()
            reord_samples = reord_samples.cpu().numpy()
            x = x.flatten().cpu().numpy()
            bounds = self.plot_bounds

        for idx, k in enumerate(n_max_list):  # k = number of segments used
            post_s = samples[k - 1, :, :n_par].reshape(n_post, n_par)
            js[idx, 0] = 1e3 * calculate_js(post_s[:, a], ns_samples[idx, :, 0, a]).median
            js[idx, 1] = 1e3 * calculate_js(post_s[:, b], ns_samples[idx, :, 0, b]).median
            ax[idx].plot(post_s[:, a], post_s[:, b], '.b', alpha=0.5, markersize=1)
            ax[idx].plot(ns_samples[idx, :, 0, a], ns_samples[idx, :, 0, b], '.r', alpha=0.5, markersize=1)
            ax[idx].plot(ns_samples[idx, :, 1, a], ns_samples[idx, :, 1, b], '.g', alpha=0.5, markersize=1)
            ax[idx].annotate('JS = {:.0f},{:.0f}'.format(js[idx, 0], js[idx, 1]), (0.5, 0.9), xycoords='axes fraction')
            ax[idx].set_xlim(bounds[a])
            ax[idx].set_ylim(bounds[b])
            ax[idx].plot(x[a], x[b], 'xk', label='truth', markersize=10)
            ax[idx].annotate('{}'.format(k), (0.9, 0.05), xycoords='axes fraction')

        # log Bayes factors relative to the noise-only evidence ns_logZ[2]
        logB_ns = ns_logZ[0] - ns_logZ[2]
        logB_prior = prior_logZ - ns_logZ[2]
        logB_urob = logZ - ns_logZ[2]
        ax[-1].annotate('logB (UROB) = {:.2f}'.format(logB_urob), (1.1, 0.9), xycoords='axes fraction')
        ax[-1].annotate('logB (NESS) = {:.2f}'.format(logB_ns), (1.1, 0.75), xycoords='axes fraction')
        ax[-1].annotate('logB (PRIO) = {:.2f}'.format(logB_prior), (1.1, 0.6), xycoords='axes fraction')

        js[-1, 0] = 1e3 * calculate_js(reord_samples[-1, :, a], ns_samples[-1, :, 0, a]).median
        js[-1, 1] = 1e3 * calculate_js(reord_samples[-1, :, b], ns_samples[-1, :, 0, b]).median
        ax[-1].plot(reord_samples[-1, :, a], reord_samples[-1, :, b], '.k', alpha=0.5, markersize=1)
        ax[-1].plot(ns_samples[-1, :, 0, a], ns_samples[-1, :, 0, b], '.r', alpha=0.5, markersize=1)
        ax[-1].plot(x[a], x[b], 'xk', markersize=10)
        ax[-1].annotate('JS = {:.0f},{:.0f}'.format(js[-1, 0], js[-1, 1]), (0.5, 0.9), xycoords='axes fraction')
        ax[-1].set_xlim(bounds[a])
        ax[-1].set_ylim(bounds[b])
        ax[-1].annotate('{}'.format(n_max_list[-1]), (0.9, 0.05), xycoords='axes fraction')
        return js

    def plot_test_data(self, d_test_tensor, n_test_tensor, plot_path):
        """Plot each test measurement (noisy and noise-free, real and imaginary) to test_data.png."""
        d_test = d_test_tensor.cpu().numpy()
        n_test = n_test_tensor.cpu().numpy()
        # squeeze=False keeps `ax` 2-D so a single test sample also works
        fig, ax = plt.subplots(d_test.shape[0], 1, figsize=(32, 3 * d_test.shape[0]), dpi=100, squeeze=False)
        for i, (d, n) in enumerate(zip(d_test, n_test)):
            n_seg = d.shape[1]
            t = np.arange(n_seg * self.N) * self.dt  # contiguous segments sampled every dt
            dreal = d[:self.N, :].transpose(1, 0).flatten()
            dimag = d[self.N:, :].transpose(1, 0).flatten()
            nreal = n[:self.N, :].transpose(1, 0).flatten()
            nimag = n[self.N:, :].transpose(1, 0).flatten()
            ax[i, 0].plot(t, dreal)
            ax[i, 0].plot(t, dimag)
            ax[i, 0].plot(t, dreal - nreal)
            ax[i, 0].plot(t, dimag - nimag)
        fig.savefig('{}/test_data.png'.format(plot_path))
        plt.close(fig)

    def plot_Bayesfactor(self, ns_logZ, prior_logZ, logZ, plot_path):
        """Plot Uroboros and prior-sample log Bayes factors against nessai's."""
        fig, ax = plt.subplots(1, 1, figsize=(8, 8), dpi=100)
        for ns, prior, urob in zip(ns_logZ, prior_logZ, logZ):
            logB_ns = ns[0] - ns[2]
            logB_prior = prior - ns[2]
            logB_urob = urob - ns[2]
            ax.plot(logB_ns * np.ones(logB_prior.size), logB_prior - logB_ns, '.k', alpha=0.3)
            ax.plot(logB_ns * np.ones(logB_urob.size), logB_urob - logB_ns, '.r', alpha=0.3)
            ax.plot(logB_ns, logB_urob[-1] - logB_ns, 'xr', markersize=10)
        ax.set_xlabel('log B (Nessai)')
        ax.set_ylabel('log B (Uroboros) - log B (Nessai)')
        ax.grid(True)
        fig.savefig('{}/Bayesfactor.png'.format(plot_path))
        plt.close(fig)


def _segment_condition(c_meas_i, c_prior, i, n_data):
    """Flow conditional: compressed measurement, compressed prior, normalised segment index."""
    seg_idx = (i / float(glob_n_max)) * torch.ones(size=(n_data, 1)).to(device)
    return torch.cat((c_meas_i.detach(), c_prior, seg_idx), dim=1).to(device)


def make_data(n_data, n_prior=100, n_max=1, n_post=None, flow=None, cc_prior_model=None, cc_meas_model=None,
              meas=None, valtest=False, noise_only_frac=0.0):
    """Make training/validation data by iterating the flow over successive segments.

    n_data - number of samples to make
    n_prior - number of samples representing each prior before compression
    n_max - max number of measurements (segments) per sample
    n_post - if given, also draw this many posterior samples after every segment
    meas - optional pre-computed measurements (n_data, n_meas, n_max); simulated if None
    valtest - if True the measurement compressor runs in eval mode without gradients

    Returns (Omega, meas, c_meas, prior_label, post_label, n, test_cond).
    Omega and n are None when `meas` is supplied; post_label and test_cond are
    None when `n_post` is None.
    """
    Omega = None
    n = None
    if meas is None:
        meas, Omega, n = cw_model(n_meas, n_max).gen_sig(n_data, n_max, noise_only_frac=noise_only_frac)

    # the flow(s) and prior compressor(s) are never trained in the data generation step
    if flow is not None and cc_prior_model is not None:
        for f, c in zip(flow, cc_prior_model):
            f.eval()
            c.eval()

    # compress the measured data - this is the only place it is compressed, so during
    # training it runs in train mode with gradients enabled
    c_meas = torch.zeros(n_data, n_meas_cc_out, n_max).to(device)
    if not valtest:
        cc_meas_model.train()
        for i in range(n_max):
            c_meas[:, :, i] = cc_meas_model(meas[:, :, i].detach())
    else:
        cc_meas_model.eval()
        with torch.no_grad():
            for i in range(n_max):
                c_meas[:, :, i] = cc_meas_model(meas[:, :, i].detach())

    # a small set of prior samples for each data sample and each segment
    prior_label = torch.zeros(n_data, n_prior, n_par, n_max).to(device)

    # the first segment uses the original prior: Ax, Ay, Az ~ N(0,1), f0 ~ U(0,1)
    prior_label[:, :, 0, 0] = torch.randn(size=(n_data, n_prior)).to(device)
    prior_label[:, :, 1, 0] = torch.randn(size=(n_data, n_prior)).to(device)
    prior_label[:, :, 2, 0] = torch.randn(size=(n_data, n_prior)).to(device)
    prior_label[:, :, 3, 0] = torch.rand(size=(n_data, n_prior)).to(device)

    # push segment i through the flow to obtain the prior for segment i+1 (not needed after the last)
    for i in range(n_max - 1):
        ii = i if multiflow else 0
        with torch.no_grad():
            c_prior = cc_prior_model[ii](prior_label[:, :, :, i]).detach()
            # tiled so that row r of the flow output belongs to data sample r % n_data
            test_cond = _segment_condition(c_meas[:, :, i], c_prior, i, n_data).tile(n_prior, 1).to(device)
            prior_samples = flow[ii].sample(n_data * n_prior, conditional=test_cond).to(device)
        # (n_prior*n_data, n_par) -> (n_data, n_prior, n_par)
        prior_label[:, :, :, i + 1] = prior_samples.reshape(n_prior, n_data, n_par).transpose(0, 1).to(device)

    # optionally draw a larger posterior sample after every segment (for plotting)
    post_label = None
    test_cond = None
    if n_post is not None:
        post_label = torch.zeros(n_data, n_post, n_par, n_max).to(device)
        for i in range(n_max):
            ii = i if multiflow else 0
            with torch.no_grad():
                c_prior = cc_prior_model[ii](prior_label[:, :, :, i]).detach()
            test_cond = _segment_condition(c_meas[:, :, i], c_prior, i, n_data).tile(n_post, 1).to(device)
            with torch.no_grad():
                post_samples = flow[ii].sample(n_data * n_post, conditional=test_cond).to(device)
            post_label[:, :, :, i] = post_samples.reshape(n_post, n_data, n_par).transpose(0, 1).to(device)

    return Omega, meas, c_meas, prior_label, post_label, n, test_cond


# save or load the true params and measurements
os.makedirs('{}/{}'.format(ns_path, run_id), exist_ok=True)
parfile = '{}/{}/testdata_{}.dat'.format(ns_path, run_id, run_id)
print(parfile)
if not os.path.isfile(parfile):
    # generate test data
    print('making test data')
    data_test_tensor, d_test_tensor, _, _, _, n_test_tensor, _ = make_data(
        n_data=n_test, n_max=n_max, n_prior=n_prior, flow=flow, cc_prior_model=cc_prior_model,
        cc_meas_model=cc_meas_model, valtest=True, noise_only_frac=0.5)
    test_n_sig = torch.ones(n_test) * n_max
    print('made test data')

    # and save it to file
    with open(parfile, 'wb') as f:
        pickle.dump([data_test_tensor.cpu().numpy(), d_test_tensor.cpu().numpy(), n_test_tensor.cpu().numpy()], f)
    print('saved par file {}'.format(parfile))
else:
    # read back the locally written test data (pickle: only load files this script produced)
    with open(parfile, 'rb') as f:
        data_test_tensor, d_test_tensor, n_test_tensor = pickle.load(f)
    data_test_tensor = torch.tensor(data_test_tensor).to(device)
    d_test_tensor = torch.tensor(d_test_tensor).to(device)
    n_test_tensor = torch.tensor(n_test_tensor).to(device)
    test_n_sig = torch.ones(n_test) * n_max
    print('read in par file {}'.format(parfile))
print('read in test params')

# plot test data
print('plotting test data')
signal_model = cw_model(n_meas, n_max)
signal_model.plot_test_data(d_test_tensor, n_test_tensor, plot_path)
print('plotted test data')

# The FlowSampler object is used to manage the sampling. Keyword arguments
# are passed to the nested sampling.
j = 0 ns_samples = [] logZ = [] print(data_test_tensor.shape,d_test_tensor.shape) for x,d in zip(data_test_tensor.cpu().numpy(),d_test_tensor.cpu().numpy()): for i in n_max_list: for k in range(2): if os.path.isfile("{}/{}/ns_{}_{}_{}.dat".format(ns_path,run_id,i,j,k))==False: output = "{}/{}/ns_{}_{}_{}/".format(ns_path,run_id,i,j,k) print('starting ns and outputting to {}'.format(output)) logger = setup_logger(output=output) if k==0: fs = FlowSampler(cw_model(n_meas,n_max,d=d.transpose(1,0)[:i,:]), output=output, resume=False, reset_flow=8, volume_fraction=0.98, seed=1234) logZ_noise = cw_model(n_meas,n_max,d=d.transpose(1,0)[:i,:]).logZ_noise else: fs = FlowSampler(cw_model(n_meas,n_max,d=d.transpose(1,0)[i-1,:].reshape(1,-1),i_ref=i-1), output=output, reset_flow=8, volume_fraction=0.98, resume=False, seed=1234) logZ_noise = cw_model(n_meas,n_max,d=d.transpose(1,0)[i-1,:]).logZ_noise fs.run() temp_ns_samples = [] for s in fs.posterior_samples: temp_ns_samples.append([s[q] for q in np.arange(n_par)]) temp_ns_samples = np.array(temp_ns_samples,dtype=np.float64) ns_samples.append(temp_ns_samples) print(fs.log_evidence, fs.log_evidence_error, logZ_noise) temp_logZ = np.array([fs.log_evidence, fs.log_evidence_error, logZ_noise]) logZ.append(temp_logZ) # save nested sampling samples print('saving ns') with open("{}/{}/ns_{}_{}_{}.dat".format(ns_path,run_id,i,j,k), "wb") as f: temp_ns_samples.tofile(f) with open("{}/{}/logZ_{}_{}_{}.dat".format(ns_path,run_id,i,j,k), "wb") as f: temp_logZ.tofile(f) else: fn = '{}/{}/ns_{}_{}_{}.dat'.format(ns_path,run_id,i,j,k) print('loading ns file {}'.format(fn)) with open("{}/{}/ns_{}_{}_{}.dat".format(ns_path,run_id,i,j,k), "rb") as f: ns_samples.append(np.fromfile(f).reshape(-1,n_par).astype(np.float64)) fn = '{}/{}/logZ_{}_{}_{}.dat'.format(ns_path,run_id,i,j,k) print('loading logZ file {}'.format(fn)) with open("{}/{}/logZ_{}_{}_{}.dat".format(ns_path,run_id,i,j,k), "rb") as f: 
logZ.append(np.fromfile(f).reshape(-1,3).astype(np.float64)) j += 1 # restructure the ns samples k = 0 min_len = 1e16 for i in range(n_test): for j in n_max_list: for b in range(2): if ns_samples[k].shape[0] < min_len: min_len = ns_samples[k].shape[0] k += 1 ns_min = min(min_len,n_post) temp_ns_samples = np.zeros((n_test,len(n_max_list),ns_min,2,n_par)) ns_logZ = np.zeros((n_test,len(n_max_list),2,3)) k = 0 for i in range(n_test): c = 0 for j in n_max_list: for b in range(2): idx = np.random.choice(ns_samples[k].shape[0],size=ns_min) temp_ns_samples[i,c,:,b,:] = ns_samples[k][idx,:] ns_logZ[i,c,b,:] = logZ[k] k += 1 c += 1 # increment the n_max index ns_samples = torch.from_numpy(temp_ns_samples).flatten(0,3).reshape(n_test,len(n_max_list),ns_min,2,n_par).cpu().numpy() flow_pars = [] cc_prior_pars = [] #cc_prior_model[0].parameters() for i in range(multiflow_n_max): flow_pars = chain(flow_pars,flow[i].parameters()) cc_prior_pars = chain(cc_prior_pars,cc_prior_model[i].parameters()) all_pars = chain(flow_pars,cc_prior_pars,cc_meas_model.parameters()) optimiser = torch.optim.Adam(all_pars,lr=lr) #optimiser = torch.optim.Adam(chain(flow.parameters(),cc_prior_model.parameters(),cc_meas_model.parameters()),lr=lr) n_loss_avg = 1000 loss = dict(train=[], val=[]) train_loss_smooth = [] val_loss_smooth = [] scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, iterations, eta_min=1e-6, last_epoch=-1) current_n_max = 1 sub_batch = np.zeros(n_max+1) sub_batch[0] = batch_size JS = np.empty((0,n_par*(n_par-1),n_test,len(n_max_list)+1,2)) JS_cv = np.empty((0,n_par*(n_par-1),n_test,len(n_max_list)+1,2)) minJS = [] maxJS = [] medJS = [] ivec = [] JSdata = None jsfile = '{}/js.dat'.format(plot_path) logZ = [] prior_logZ = [] for i in range(iterations+1): # perform inference (while training) on the final measurement data_train_tensor = torch.empty(0,n_par).to(device) # <--- initialize train_cond = torch.empty(0,n_meas_cc_out+n_prior_cc_out+1).to(device) # <--- initialize 
#orig_data_train_tensor, _, dist_train_tensor, priors_train_tensor, _ = make_data(n_data=batch_size,n_max=current_n_max,n_prior=n_prior,flow=flow,cc_prior_model=cc_prior_model,cc_meas_model=cc_meas_model) #start_idx = 0 optimiser.zero_grad() _loss = 0.0 for k in range(current_n_max): kk = k if multiflow else 0 #end_idx = start_idx + int(sub_batch[k]) # all networks are set to eval mode inside the data generation - except for the measurement compression temp_data_train_tensor, _, temp_dist_train_tensor, temp_priors_train_tensor, _, _, _ = make_data(n_data=int(sub_batch[k]),n_max=k+1,n_prior=n_prior,flow=flow,cc_prior_model=cc_prior_model,cc_meas_model=cc_meas_model) #temp_data_train_tensor = orig_data_train_tensor[start_idx:end_idx,:].detach() #temp_dist_train_tensor = dist_train_tensor[start_idx:end_idx,:,:k+1] #temp_priors_train_tensor = priors_train_tensor[start_idx:end_idx,:,:,:k+1].detach() temp_data_train_tensor = temp_data_train_tensor.detach() temp_priors_train_tensor = temp_priors_train_tensor.detach() cc_prior_model[kk].train() compressed_prior = cc_prior_model[kk](temp_priors_train_tensor[:,:,:,-1]) #.flatten(1,2)) # take the last event prior temp_train_cond = torch.cat((temp_dist_train_tensor[:,:,-1],compressed_prior,(k/float(glob_n_max))*torch.ones(size=(int(sub_batch[k]),1)).to(device)),dim=1).to(device) data_train_tensor = torch.cat((data_train_tensor,temp_data_train_tensor[:,:n_par]),dim=0) train_cond = torch.cat((train_cond,temp_train_cond),dim=0) #start_idx += int(sub_batch[k]) flow[kk].train() #optimiser.zero_grad() _loss -= flow[kk].log_prob(data_train_tensor[:,:n_par], conditional=train_cond).mean() #*(sub_batch[k]/float(batch_size)) # compute the loss _loss.backward() optimiser.step() scheduler.step() train_loss = _loss.mean().item() loss["train"].append(train_loss) train_loss_smooth.append(np.median(loss["train"][max(i-n_loss_avg,0):])) # validation if not i % 100: data_val_tensor = torch.empty(0,n_par).to(device) # <--- initialize val_cond = 
torch.empty(0,n_meas_cc_out+n_prior_cc_out+1).to(device) # <--- initialize #orig_data_val_tensor, _, dist_val_tensor, priors_val_tensor, _ = make_data(n_data=batch_size,n_max=current_n_max,n_prior=n_prior,flow=flow,cc_prior_model=cc_prior_model,cc_meas_model=cc_meas_model,valtest=True) #start_idx = 0 _loss = 0.0 for k in range(current_n_max): ii = i if multiflow else 0 #end_idx = start_idx + int(sub_batch[k]) temp_data_val_tensor, _, temp_dist_val_tensor, temp_priors_val_tensor, _, _, _ = make_data(n_data=int(sub_batch[k]),n_max=k+1,n_prior=n_prior,flow=flow,cc_prior_model=cc_prior_model,cc_meas_model=cc_meas_model,valtest=True) #temp_data_val_tensor = orig_data_val_tensor[start_idx:end_idx,:].detach() #temp_dist_val_tensor = dist_val_tensor[start_idx:end_idx,:,:k+1].detach() #temp_priors_val_tensor = priors_val_tensor[start_idx:end_idx,:,:,:k+1].detach() temp_data_val_tensor = temp_data_val_tensor.detach() temp_dist_val_tensor = temp_dist_val_tensor.detach() temp_priors_val_tensor = temp_priors_val_tensor.detach() cc_prior_model[kk].eval() with torch.no_grad(): compressed_prior = cc_prior_model[kk](temp_priors_val_tensor[:,:,:,-1]) #.flatten(1,2)) # take the last event prior temp_val_cond = torch.cat((temp_dist_val_tensor[:,:,-1],compressed_prior,(k/float(glob_n_max))*torch.ones(size=(int(sub_batch[k]),1)).to(device)),dim=1).to(device) data_val_tensor = torch.cat((data_val_tensor,temp_data_val_tensor[:,:n_par]),dim=0) val_cond = torch.cat((val_cond,temp_val_cond),dim=0) #start_idx += int(sub_batch[k]) flow[kk].eval() with torch.no_grad(): _loss -= flow[kk].log_prob(data_val_tensor[:,:n_par], conditional=val_cond).mean() #*(sub_batch[k]/float(batch_size)) # .item() val_loss = _loss.mean().item() loss["val"].append(val_loss) val_loss_smooth.append(val_loss) if np.isnan(loss["val"][-1]): print('validation loss is nan: i={}'.format(i)) print('sub_batch is {}'.format(sub_batch)) #print('last_idx is {}'.format(last_idx)) print('cond_val any nans = 
{}'.format(torch.isnan(val_cond).any())) print('y_val any nans = {}'.format(torch.isnan(orig_data_val_tensor).any())) #print('x_val any nans = {}'.format(torch.isnan(x_val).any())) print('x_comp_val any nans = {}'.format(torch.isnan(dist_val_tensor).any())) print('y_prior_val any nans = {}'.format(torch.isnan(temp_priors_val_tensor).any())) print('compressed_prior any nans = {}'.format(torch.isnan(compressed_prior).any())) print('cond_val any nans = {}'.format(torch.isnan(val_cond).any())) if not i % plot_step: JStemp = np.zeros((1,n_par*(n_par-1),n_test,len(n_max_list)+1,2)) JStemp_cv = np.zeros((1,n_par*(n_par-1),n_test,len(n_max_list)+1,2)) my_lr = scheduler.get_last_lr()[0] print(f"Epoch {i} - train: {loss['train'][-1]:.3f}, val: {loss['val'][-1]:.3f}, lr: {my_lr:.3e}, n_max: {current_n_max}, sub_batch: {sub_batch}") # importance sampling if i==iterations: j = 0 isamples = np.zeros((n_test,n_post,n_par)) for x,d,nd in zip(data_test_tensor,d_test_tensor,test_n_sig.cpu().numpy().astype(int)): samples_temp = np.empty((0,n_par)) urob_logL = np.empty((0)) logL = np.empty((0)) temp_model = cw_model(n_meas,nd,d=d.transpose(1,0).detach().cpu().numpy()) flag = True while flag: _, _, _, _, samples, _, test_cond = make_data(1,n_prior=n_prior,n_max=nd,n_post=n_post,flow=flow,cc_prior_model=cc_prior_model,cc_meas_model=cc_meas_model,meas=d.reshape(1,n_meas,n_max),valtest=True) #print(samples.shape,test_cond.shape,urob_logL.shape) temp = flow[0].log_prob(samples[0,:,:,-1],conditional=test_cond).detach().cpu().numpy() #print(temp.shape,urob_logL.shape) urob_logL = np.concatenate([urob_logL,flow[0].log_prob(samples[0,:,:,-1],conditional=test_cond).detach().cpu().numpy()],axis=0) samples_temp = np.concatenate([samples_temp,samples[0,:,:,-1].detach().cpu().numpy()],axis=0) x = {} for h,nm in enumerate(temp_model.names): x["{}".format(nm)] = samples[0,:,h,-1].detach().cpu().numpy() logL = np.concatenate([logL,temp_model.log_likelihood(x)],axis=0) log_wts = logL - urob_logL 
log_wts -= np.max(log_wts) n_exp = np.sum(np.exp(log_wts)) print('n_exp = {}'.format(n_exp)) if n_exp>2*n_post: flag = False idx = np.argwhere(np.log(uniform.rvs(loc=0,scale=1,size=log_wts.size)) int(iterations/4): sub_batch = (batch_size/n_max)*np.ones(n_max) # if starting training another flow then copy the weights across from the previous one if current_n_max > old_n_max and multiflow: print('updated current_n_max from {} -> {}'.format(old_n_max,current_n_max)) start = old_n_max - 1 end = current_n_max - 1 flow[end].load_state_dict(flow[start].state_dict()) cc_prior_model[end].load_state_dict(cc_prior_model[start].state_dict()) print('copied model {} to model {}'.format(start,end)) old_n_max = current_n_max #if i>0 and not i % 50000 and current_n_max