import os
import pickle
import shutil
import time
from datetime import datetime
from itertools import chain, permutations

import corner
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from glasflow import RealNVP
from nessai.flowsampler import FlowSampler
from nessai.model import Model
from nessai.utils import setup_logger
from scipy.integrate import quad
from scipy.interpolate import RectBivariateSpline
from scipy.stats import norm

# ---------------------------------------------------------------------------
# Global configuration
# ---------------------------------------------------------------------------
seed = 4
torch.manual_seed(seed)
np.random.seed(seed)
sns.set_context("notebook")
sns.set_palette("colorblind")
device = "cuda"

date_id = datetime.today().isoformat()
plot_path = '/data/www.astro/chrism/uroboros/{}'.format(date_id)
ns_path = './'
try:
    os.mkdir('{}'.format(plot_path))
except OSError:
    # Only filesystem errors are expected here; anything else should surface.
    print('unable to make output directory {}'.format(plot_path))
    raise SystemExit(1)
# keep a snapshot of the script next to the plots it produced
shutil.copyfile('./uroboros.py', '{}/uroboros.txt'.format(plot_path))

n_test = 2              # number of individual test data samples
n_prior = 128           # number of samples used to represent the conditional prior
n_max = 3               # max number of measurements per sample
n_par = 3               # the number of hyperparameters
n_meas = 64             # the size of a measurement
n_prior_cc_out = 32     # the size of the compressed prior
n_meas_cc_out = 16      # the size of the compressed measurement
n_marg = 0              # the number of parameters per measurement to be marginalised
n_post = 3000           # the number of posterior samples used for plotting
lr = 1e-3               # the learning rate
run_id = 'cw_amp_s{}_nmax{}_ntest{}'.format(seed, n_max, n_test)

# The flow that estimates the parameters conditional on new data and the
# compressed prior information.
flow = RealNVP(
    n_inputs=n_par,  # number of params
    n_transforms=7,
    # size of compressed prior + size of compressed measurement + segment index
    n_conditional_inputs=n_prior_cc_out + n_meas_cc_out + 1,
    n_neurons=128,
    n_blocks_per_transform=4,
    batch_norm_within_blocks=True,
    linear_transform='permutation',
    batch_norm_between_transforms=True,
)
flow.to(device)
print(f"Created flow and sent to {device}")


class MAB(nn.Module):
    """Multi-head attention block: queries ``Q`` attend to keys/values from ``K``."""

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        # split the feature axis into heads and stack the heads on the batch axis
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        A = torch.softmax(Q_.bmm(K_.transpose(1, 2))/np.sqrt(self.dim_V), 2)
        # undo the head stacking: back to (batch, n_query, dim_V)
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O


class SAB(nn.Module):
    """Self-attention block: an ``MAB`` whose queries and keys are the same set."""

    def __init__(self, dim_in, dim_out, num_heads, ln=False):
        super(SAB, self).__init__()
        self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(X, X)


class ISAB(nn.Module):
    """Self-attention routed through ``num_inds`` learned inducing points."""

    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super(ISAB, self).__init__()
        self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X):
        H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
        return self.mab1(X, H)


class PMA(nn.Module):
    """Attention pooling: ``num_seeds`` learned seed vectors attend to the set."""

    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(self.S.repeat(X.size(0), 1, 1), X)


class SmallSetTransformer(nn.Module):
    """Compress a set of (n_par + 1)-dimensional prior samples to n_prior_cc_out features."""

    def __init__(self,):
        super().__init__()
        self.enc = nn.Sequential(
            SAB(dim_in=n_par+1, dim_out=64, num_heads=8),
            SAB(dim_in=64, dim_out=64, num_heads=8),
        )
        self.dec = nn.Sequential(
            PMA(dim=64, num_heads=8, num_seeds=1),
            nn.Linear(in_features=64, out_features=n_prior_cc_out),
        )

    def forward(self, x):
        x = self.enc(x)
        x = self.dec(x)
        # drop the singleton seed axis left by PMA(num_seeds=1)
        return x.squeeze(-2)


# network that compresses the recycled prior samples
cc_prior_model = SmallSetTransformer()
cc_prior_model.to(device)

# network that compresses a single measurement (time series)
cc_meas_model = nn.Sequential(
    nn.Linear(n_meas, 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Linear(64, n_meas_cc_out),
    nn.Sigmoid()
)
cc_meas_model.to(device)


class cw_model(Model):
    """Sinusoidal signal in Gaussian noise, observed over consecutive segments.

    The three parameters are the two amplitude quadratures (``Asin``,
    ``Acos``) and a normalised frequency ``f0``.  The class doubles as the
    nessai ``Model`` used for the nested-sampling reference runs and as the
    simulator used to draw training data.
    """

    def __init__(self, n_meas, d=None):
        self.n_meas = n_meas                            # the length of the time series
        self.dvec = d                                   # the noisy data, shape [n_obs, n_meas]
        self.T = 1.0                                    # observation time fixed to 1 sec
        self.dt = self.T/self.n_meas                    # sampling time
        self.t = torch.arange(self.n_meas)*self.dt      # time vector
        self.fmax = 0.5/self.dt                         # Nyquist frequency
        self.n_sigma = 1.0                              # the noise standard deviation
        self.SNR_factor = 0.25                          # a scaling that controls the average SNR
        self.n_par = 3                                  # number of parameters
        # names and prior ranges of the parameters for nested sampling
        self.names = ["Asin", "Acos", "f0"]
        self.bounds = {"Asin": [-10.0, 10.0], "Acos": [-10.0, 10.0], "f0": [0, 1]}

    def sig(self, x, i):
        """Return the noise-free time series for each parameter set.

        x - parameters, shape [n_par, N] (rows are Asin, Acos, f0)
        i - segment index of each time series, shape [N]

        Output has shape [N, n_meas].  The phase of segment ``i`` is
        referenced to the start of segment 0, i.e. its time axis is shifted
        by ``i*T``.
        """
        N = x.shape[1]
        # use the instance's n_meas (not the module-level constant) so the
        # model also works when built with a different series length
        x = torch.tile(x.reshape(self.n_par, N, 1), (1, 1, self.n_meas))
        Asin = x[0]     # first amplitude quadrature (N, n_meas)
        Acos = x[1]     # second amplitude quadrature (N, n_meas)
        f0 = x[2]       # normalised frequency in [0, 1] (N, n_meas)
        i = torch.tile(i.reshape(N, 1), (1, self.n_meas))
        f = (0.4 + 0.2*f0)*self.fmax    # maps f0 onto 0.4 - 0.6 of the Nyquist frequency
        phase = 2.0*np.pi*f*(torch.tile(self.t.reshape(1, self.n_meas), (N, 1)) + i*self.T)
        # torch ops keep everything a tensor rather than relying on numpy's
        # ufunc dispatch on torch tensors
        return self.SNR_factor*(Asin*torch.cos(phase) + Acos*torch.sin(phase))

    def gen_pars(self, n_data):
        """Draw ``n_data`` parameter vectors from the prior.

        Returns a tensor of shape (n_data, n_par) on ``device``: standard
        normal quadratures and a uniform [0, 1) normalised frequency.
        """
        Asin = (torch.randn(size=(n_data, 1))).to(device)
        Acos = (torch.randn(size=(n_data, 1))).to(device)
        f0 = (torch.rand(size=(n_data, 1))).to(device)
        return torch.cat((Asin, Acos, f0), dim=1).to(device)

    def gen_data(self, n_data, n_max):
        """Generate noisy measurements for parameters drawn from the prior.

        n_data - the number of separate instances
        n_max  - the number of measurements (segments) per instance

        Returns ``(meas, Omega)`` where ``meas`` has shape
        (n_data, n_meas, n_max) with ``meas[d, :, k]`` the k-th segment of
        instance d, and ``Omega`` has shape (n_data, n_par).
        """
        Omega = self.gen_pars(n_data)
        n = (self.n_sigma*torch.randn(size=(n_data, self.n_meas, n_max))).to(device)

        # replicate each parameter vector once per segment: (n_par, n_data*n_max)
        pars = torch.tile(Omega.transpose(1, 0).reshape(self.n_par, n_data, 1),
                          (1, 1, n_max)).flatten(1, 2).cpu()
        seg = torch.tile(torch.arange(n_max).reshape(1, n_max), (n_data, 1)).flatten().cpu()

        # sig() returns (n_data*n_max, n_meas) ordered as (instance, segment).
        # Reshaping that directly to (n_data, n_meas, n_max) would interleave
        # time samples across segments, so unflatten first and then swap axes.
        signal = self.sig(pars, seg).reshape(n_data, n_max, self.n_meas).transpose(1, 2)
        meas = signal.to(device) + n
        return meas.to(dtype=torch.float32).contiguous(), Omega

    def change_vars(self, samples):
        """Convert samples in [Asin, Acos, f0] to [phase, f0, Amp].

        The phase is returned as a fraction of a cycle in [0, 1).
        """
        N = samples.shape[0]
        new_samples = torch.zeros((N, self.n_par))
        new_samples[:, 2] = torch.sqrt(samples[:, 0]**2 + samples[:, 1]**2)
        new_samples[:, 1] = samples[:, 2]
        new_samples[:, 0] = torch.remainder(torch.atan2(samples[:, 0], samples[:, 1]), 2*np.pi)/(2.0*np.pi)
        return new_samples

    def plot_test_data(self, d_test_tensor, plot_path):
        """Plot each test instance as one concatenated time series.

        d_test_tensor - tensor of shape (n_test, n_meas, n_max)
        plot_path     - directory that receives ``test_data.png``
        """
        d_test = d_test_tensor.cpu().numpy()
        n_test = d_test.shape[0]
        # squeeze=False keeps ``ax`` two-dimensional even for a single test case
        fig, ax = plt.subplots(n_test, 1, figsize=(32, 3*n_test), dpi=100, squeeze=False)
        for i, d in enumerate(d_test):
            t = np.linspace(0, self.T*d.shape[1], d.shape[1]*d.shape[0])
            # d is (n_meas, n_max); transpose so the segments are laid end to end
            ax[i, 0].plot(t, d.T.flatten())
        plt.savefig('{}/test_data.png'.format(plot_path))
        plt.close()

    def log_prior_fast(self, x):
        """Tensor version of the log prior for samples of shape (N, n_par).

        Normalised unit Gaussians on the two quadratures; the uniform prior
        on f0 over [0, 1] contributes zero.
        """
        return -0.5*x[:, 0]**2 - 0.5*x[:, 1]**2 - np.log(2.0*np.pi)

    def log_prior(self, x):
        """Return the log prior of a live point (nested-sampling interface).

        Gaussian priors on the quadratures (unnormalised) and a uniform prior
        on the frequency; points outside ``self.bounds`` get -inf.
        """
        # in_bounds gives True/False; the log turns that into 0/-inf
        log_p = np.log(self.in_bounds(x), dtype="float")
        log_p -= 0.5*(x["Asin"]**2 + x["Acos"]**2)                      # Gaussian priors on quadratures
        log_p -= np.log(self.bounds["f0"][1] - self.bounds["f0"][0])    # uniform prior on frequency
        return log_p

    def log_likelihood(self, x):
        """Return the Gaussian log likelihood of a live point given ``self.dvec``.

        Each row of ``self.dvec`` is treated as one segment, with the row
        index used as the segment index passed to ``sig``.
        """
        log_l = 0.0
        for i, d in enumerate(self.dvec):   # loop over measurements
            Omega = torch.from_numpy(np.array([x["Asin"], x["Acos"], x["f0"]])).reshape(self.n_par, -1)
            idx = torch.from_numpy(np.ones(Omega.shape[1])*i)
            s = self.sig(Omega, idx)
            log_l += np.sum(norm.logpdf(s, loc=d, scale=self.n_sigma))
        return log_l


# initialise the signal model
signal_model = cw_model(n_meas)


def make_data(n_data,n_prior=100,n_max=1,n_post=None,flow=None,signal_model=None,cc_prior_model=None,cc_meas_model=None,meas=None,valtest=False):
    """Make a batch of training / validation / test data.

    n_data  = number of samples to make
    n_prior = number of samples representing the prior before compression
    n_max   = number of measurements per sample
    n_post  = number of posterior samples per measurement (None to skip)
    meas    = pre-existing measurements of shape (n_data, n_meas, n_max);
              when None, new data are simulated with ``signal_model``
    valtest = True puts the measurement compressor in eval mode

    Returns ``(Omega, meas, c_meas, prior_label, post_label)``:
      Omega       - true parameters (n_data, n_par), or None if ``meas`` was given
      meas        - measurements (n_data, n_meas, n_max)
      c_meas      - compressed measurements (n_data, n_meas_cc_out, n_max)
      prior_label - prior samples and their log-probs (n_data, n_prior, n_par+1, n_max)
      post_label  - posterior samples and log-probs (n_data, n_post, n_par+1, n_max), or None
    """
    # generate measured data if none has been supplied
    Omega = None
    if meas is None:
        meas, Omega = signal_model.gen_data(n_data, n_max)

    # The measurement compression is the only network that is trained through
    # this function, so it alone follows the train/eval switch.
    if not valtest:
        cc_meas_model.train()
    else:
        cc_meas_model.eval()

    # Compress every segment.  The segment axis is moved next to the batch
    # axis before flattening (and moved back afterwards) so that
    # c_meas[:, :, i] really is the compression of meas[:, :, i].
    c_meas = cc_meas_model(
        meas.transpose(1, 2).reshape(n_data*n_max, n_meas)
    ).reshape(n_data, n_max, n_meas_cc_out).transpose(1, 2)

    # a small set of prior samples (and log probs) for each measurement and sample
    prior_label = torch.zeros(n_data, n_prior, n_par+1, n_max).to(device)

    # PRIOR 1 - for the first measurement, sample from the original prior
    prior_label[:, :, :n_par, 0] = signal_model.gen_pars(n_data*n_prior).reshape(n_data, n_prior, n_par).to(device)
    prior_label[:, :, n_par, 0] = signal_model.log_prior_fast(
        prior_label[:, :, :n_par, 0].reshape(-1, n_par)
    ).reshape(n_data, n_prior).to(device)

    # With more than one measurement, the posterior after measurement i
    # (drawn from the current flow) becomes the prior for measurement i+1.
    flow.eval()             # the flow is NOT trained in the data generation step
    cc_prior_model.eval()   # neither is the prior compression
    for i in range(n_max-1):
        c_prior = cc_prior_model(prior_label[:, :, :, i]).detach()
        # conditional = compressed measurement + compressed prior + segment index
        test_cond = torch.cat((c_meas[:, :, i].detach(), c_prior, i*torch.ones(size=(n_data, 1)).to(device)), dim=1).to(device)
        # tile to draw n_prior samples per instance; rows are ordered (sample, instance)
        test_cond = test_cond.tile(n_prior, 1).to(device)
        with torch.no_grad():
            prior_samples = flow.sample(n_data*n_prior, conditional=test_cond).to(device)
            prior_logprob = flow.log_prob(prior_samples, conditional=test_cond).to(device)

        # these are now the priors for the NEXT measurement
        prior_label[:, :, :n_par, i+1] = prior_samples.reshape(n_prior, n_data, n_par).transpose(1, 0).to(device)
        prior_label[:, :, n_par, i+1] = prior_logprob.reshape(n_prior, n_data).transpose(1, 0).to(device)

    # Optionally draw a larger set of posterior samples after every
    # measurement (including the last) for plotting.  The conditioning still
    # uses the fixed n_prior-sized sets generated above.
    post_label = None
    if n_post is not None:
        post_label = torch.zeros(n_data, n_post, n_par+1, n_max).to(device)
        flow.eval()
        cc_prior_model.eval()
        for i in range(n_max):
            c_prior = cc_prior_model(prior_label[:, :, :, i]).detach()
            test_cond = torch.cat((c_meas[:, :, i].detach(), c_prior, i*torch.ones(size=(n_data, 1)).to(device)), dim=1).to(device)
            test_cond = test_cond.tile(n_post, 1).to(device)
            with torch.no_grad():
                temp_prior_samples = flow.sample(n_data*n_post, conditional=test_cond).to(device)
                temp_prior_logprob = flow.log_prob(temp_prior_samples, conditional=test_cond).to(device)

            # the posteriors AFTER each measurement
            post_label[:, :, :n_par, i] = temp_prior_samples.reshape(n_post, n_data, n_par).transpose(1, 0).to(device)
            post_label[:, :, n_par, i] = temp_prior_logprob.reshape(n_post, n_data).transpose(1, 0).to(device)

    return Omega, meas, c_meas, prior_label, post_label


# save or load the true params and measurements
os.makedirs('{}/{}'.format(ns_path, run_id), exist_ok=True)
parfile = '{}/{}/testdata_{}.dat'.format(ns_path, run_id, run_id)
# NOTE: test-data and nested-sampling caches written before the segment-layout
# fix in gen_data/make_data hold interleaved segments and should be regenerated.
if not os.path.isfile(parfile):
    # generate test data
    data_test_tensor, d_test_tensor, c_d_test_tensor, prior_test_tensor, _ = make_data(n_data=n_test,n_max=n_max,n_prior=n_prior,signal_model=signal_model,flow=flow,cc_prior_model=cc_prior_model,cc_meas_model=cc_meas_model,valtest=True)
    test_n_sig = torch.ones(n_test)*n_max
    # and save it to file
    with open(parfile, 'wb') as f:
        pickle.dump([data_test_tensor.cpu().numpy(), d_test_tensor.cpu().numpy()], f)
    print('saved par file {}'.format(parfile))
else:
    # read the data and parameters back (a cache this script wrote itself)
    with open(parfile, 'rb') as f:
        data_test_tensor, d_test_tensor = pickle.load(f)
    data_test_tensor = torch.tensor(data_test_tensor).to(device)
    d_test_tensor = torch.tensor(d_test_tensor).to(device)
    test_n_sig = torch.ones(n_test)*n_max
    print('read in par file {}'.format(parfile))

# plot test data
signal_model.plot_test_data(d_test_tensor,plot_path)

# Reference posteriors from nested sampling: for every test case j, one run
# per cumulative number of measurements (the first i+1 segments).  Finished
# runs are cached on disk as flat float64 files and re-used.
j = 0
ns_samples = []
for x, d in zip(data_test_tensor.cpu().numpy(), d_test_tensor.cpu().numpy()):
    for i in range(n_max):
        fn = "{}/{}/ns_{}_{}_{}.dat".format(ns_path,run_id,run_id,i,j)

        if os.path.isfile(fn):
            # cached result: restore the (n_samples, n_par+n_marg) table
            print('loading ns file {}'.format(fn))
            with open(fn, "rb") as f:
                ns_samples.append(np.fromfile(f).reshape(-1,n_par+n_marg).astype(np.float64))
            continue

        print('starting ns')
        output = "{}/{}/ns_{}_{}_{}/".format(ns_path,run_id,run_id,i,j)
        logger = setup_logger(output=output)
        # keyword arguments of FlowSampler are passed on to the nested sampler
        fs = FlowSampler(cw_model(n_meas,d=d[:,:i+1].transpose(1,0)), output=output, resume=False, seed=1234)
        fs.run()

        # keep only the leading n_par+n_marg fields of every posterior sample
        temp_ns_samples = np.array(
            [[s[q] for q in np.arange(n_par+n_marg)] for s in fs.posterior_samples],
            dtype=np.float64,
        )
        ns_samples.append(temp_ns_samples)

        # save nested sampling samples
        print('saving ns')
        with open(fn, "wb") as f:
            temp_ns_samples.tofile(f)
    j += 1

# Restructure the ns samples: every run yields a different number of samples,
# so resample (with replacement) to a common length of at most n_post.
min_len = min([1e16] + [ns_samples[k].shape[0] for k in range(n_test*n_max)])
ns_min = min(min_len,n_post)
temp_ns_samples = np.zeros((n_test,n_max,ns_min,n_par+n_marg))
for i in range(n_test):
    for j in range(n_max):
        k = i*n_max + j
        idx = np.random.choice(ns_samples[k].shape[0],size=ns_min)
        temp_ns_samples[i,j,:,:] = ns_samples[k][idx,:]

# convert nested samples to amplitude and phase
# express the nested-sampling samples as [phase, f0, Amp] for plotting
ns_samples = signal_model.change_vars(torch.from_numpy(temp_ns_samples).flatten(0,2)).reshape(n_test,n_max,ns_min,n_par).cpu().numpy()

# one optimiser over the flow and both compression networks
optimiser = torch.optim.Adam(chain(flow.parameters(),cc_prior_model.parameters(),cc_meas_model.parameters()),lr=lr)
LUT = None
iterations = 100000
batch_size = 1024
n_loss_avg = 1000
loss = dict(train=[], val=[])
train_loss_smooth = []
val_loss_smooth = []
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, iterations, eta_min=1e-6, last_epoch=-1)

# Curriculum state: training starts with a single measurement per sample;
# sub_batch[k] is the number of batch entries trained on measurement k.
current_n_max = 1
sub_batch = np.zeros(n_max+1)
sub_batch[0] = batch_size


def _to_phase_f0_amp(raw):
    """Map [Asin, Acos, f0, ...] on the last axis to [phase, f0, Amp, 0, ...].

    The phase is a fraction of a cycle in [0, 1).  Columns beyond the first
    three are left at zero.  ``np.arctan2`` is used rather than ``np.atan2``
    because the latter only exists in NumPy >= 2.0.
    """
    out = np.zeros(raw.shape)
    out[..., 2] = np.sqrt(raw[..., 0]**2 + raw[..., 1]**2)
    out[..., 1] = raw[..., 2]
    out[..., 0] = np.remainder(np.arctan2(raw[..., 0], raw[..., 1]), 2*np.pi)/(2.0*np.pi)
    return out


def _assemble_batch(valtest):
    """Simulate a batch and return ``(targets, conditionals)`` for the flow.

    The batch is cut into consecutive slices of ``sub_batch[k]`` entries;
    slice k is conditioned on the compressed k-th measurement, the compressed
    prior for that measurement and the segment index k.  Reads the current
    values of the module-level ``current_n_max`` and ``sub_batch``.
    """
    labels, _, c_meas, priors, _ = make_data(n_data=batch_size,n_max=current_n_max,n_prior=n_prior,signal_model=signal_model,flow=flow,cc_prior_model=cc_prior_model,cc_meas_model=cc_meas_model,valtest=valtest)

    # make_data leaves the prior compression in eval mode; it is trained here
    if valtest:
        cc_prior_model.eval()
    else:
        cc_prior_model.train()

    targets = [torch.empty(0,n_par).to(device)]
    conds = [torch.empty(0,n_meas_cc_out+n_prior_cc_out+1).to(device)]
    start_idx = 0
    for k in range(current_n_max):
        n_k = int(sub_batch[k])
        end_idx = start_idx + n_k
        compressed_prior = cc_prior_model(priors[start_idx:end_idx,:,:,k])
        seg_index = k*torch.ones(size=(n_k,1)).to(device)
        conds.append(torch.cat((c_meas[start_idx:end_idx,:,k],compressed_prior,seg_index),dim=1).to(device))
        targets.append(labels[start_idx:end_idx,:])
        start_idx = end_idx
    return torch.cat(targets,dim=0), torch.cat(conds,dim=0)


for i in range(iterations+1):

    # ---- training step ----------------------------------------------------
    data_train_tensor, train_cond = _assemble_batch(valtest=False)
    flow.train()
    optimiser.zero_grad()
    _loss = -flow.log_prob(data_train_tensor, conditional=train_cond).mean()
    _loss.backward()
    optimiser.step()
    scheduler.step()
    train_loss = _loss.item()
    loss["train"].append(train_loss)
    # running median over the last n_loss_avg iterations
    train_loss_smooth.append(np.median(loss["train"][max(i-n_loss_avg,0):]))

    # ---- validation (every 100 iterations) --------------------------------
    if not i % 100:
        data_val_tensor, val_cond = _assemble_batch(valtest=True)
        flow.eval()
        with torch.no_grad():
            _loss = -flow.log_prob(data_val_tensor, conditional=val_cond).mean().item()
        val_loss = _loss
        loss["val"].append(val_loss)
        val_loss_smooth.append(val_loss)

    # ---- diagnostics (every 2000 iterations) ------------------------------
    if not i % 2000 and i>0:
        my_lr = scheduler.get_last_lr()[0]
        print(f"Epoch {i} - train: {loss['train'][-1]:.3f}, val: {loss['val'][-1]:.3f}, lr: {my_lr:.3e}, n_max: {current_n_max}, sub_batch: {sub_batch}")

        # squeeze=False keeps ``ax`` two-dimensional even when n_test == 1
        fig, ax = plt.subplots(n_test,n_max+2, figsize=(4*(n_max+2),4*n_test), dpi=100, squeeze=False)
        loss_fig, loss_ax = plt.subplots(1, 1, figsize=(8, 8), dpi=100)
        j = 0
        for x,d,nd in zip(data_test_tensor,d_test_tensor,test_n_sig.cpu().numpy().astype(int)):

            # run the test measurements through the current networks
            _, _, _, old_samples, samples = make_data(1,n_prior=n_prior,n_max=nd,n_post=n_post,signal_model=signal_model,flow=flow,cc_prior_model=cc_prior_model,cc_meas_model=cc_meas_model,meas=d.reshape(1,n_meas,n_max),valtest=True)

            # (measurement, sample, n_par+1) arrays in the raw parameterisation
            temp_samples = torch.permute(samples,(0,3,1,2))[0,:,:,:].cpu().numpy().reshape(nd,n_post,n_par+1)
            temp_old_samples = torch.permute(old_samples,(0,3,1,2))[0,:,:,:].cpu().numpy().reshape(nd,n_prior,n_par+1)

            # ... and in [phase, f0, Amp] for plotting
            samples = _to_phase_f0_amp(temp_samples)
            old_samples = _to_phase_f0_amp(temp_old_samples)
            x = _to_phase_f0_amp(x.cpu().numpy())

            # column 0: the initial prior samples
            ax[j,0].plot(old_samples[0,:,0],old_samples[0,:,1],'xc',markersize=10,label='SNR={:.2f}'.format(signal_model.SNR_factor*x[2]*np.sqrt(n_meas)))
            ax[j,0].legend(loc='upper right')

            # columns 1..nd-1: flow posterior (blue) vs nested sampling (red)
            # after measurement k-1
            for k in range(1,nd):
                post_s = samples[k-1,:,:].reshape(n_post,n_par+1)
                ax[j,k].plot(post_s[:,0],post_s[:,1],'.b',markersize=1)
                ax[j,k].plot(ns_samples[j,k-1,:,0],ns_samples[j,k-1,:,1],'.r',markersize=1)
                ax[j,k].set_xlim([0.0,1.0])
                ax[j,k].set_ylim([0.0,1.0])
                ax[j,k].plot(x[0],x[1],'xk',label='truth',markersize=10)

            # column nd: posterior after the final measurement
            ax[j,nd].plot(samples[-1,:,0],samples[-1,:,1],'.b',markersize=1)
            ax[j,nd].plot(ns_samples[j,-1,:,0],ns_samples[j,-1,:,1],'.r',markersize=1)
            ax[j,nd].plot(x[0],x[1],'xk',markersize=10)
            ax[j,nd].set_xlim([0.0,1.0])
            ax[j,nd].set_ylim([0.0,1.0])

            # importance sampling: re-weight the final flow samples by
            # (likelihood x prior) / flow density
            fs = cw_model(n_meas,d=d[:,:].transpose(1,0).cpu().numpy())
            logL = np.zeros((n_post,2))
            t1 = time.time()
            for k,s in enumerate(temp_samples[-1,:,:]):
                a = {"Asin": s[0], "Acos": s[1], "f0": s[2]}
                logL[k,0] = fs.log_likelihood(a) + fs.log_prior(a)
                logL[k,1] = s[-1]   # log-density the flow assigned to this sample
            logw = logL[:,0] - logL[:,1]
            logw = logw - np.max(logw)  # guard against overflow in exp
            w = np.exp(logw)
            nw = w/(np.sum(w))
            ESS = 1.0/np.sum(nw**2)     # effective sample size
            sumw = np.sum(w)
            print(ESS,sumw,n_post)
            w_samples = samples[-1,np.random.choice(np.arange(n_post),n_post,p=nw),:]

            # column nd+1: importance-resampled posterior (black) vs nested sampling
            ax[j,nd+1].plot(ns_samples[j,-1,:,0],ns_samples[j,-1,:,1],'.r',markersize=1)
            ax[j,nd+1].plot(w_samples[:,0],w_samples[:,1],'.k',markersize=1,label='ESS/n={:.2f}'.format(ESS/n_post))
            ax[j,nd+1].plot(x[0],x[1],'xk',markersize=10)
            ax[j,nd+1].set_xlim([0.0,1.0])
            ax[j,nd+1].set_ylim([0.0,1.0])
            ax[j,nd+1].legend(loc='upper right')

            # corner plot of the final posterior: nested sampling (red) vs flow (blue)
            kw = {"plot_datapoints": False, "plot_density": False, "levels": [0.5,0.99]}
            labels = ["phi0","f0","A"]
            ax2 = corner.corner(ns_samples[j,-1,:,:3],bins=50,smooth=0.05,labels=labels,truths=[x[0],x[1],x[2]],color='r',range=[(0.0,1.0),(0.0,1.0),(0.0,1.0)],quantiles=None,**kw)
            corner.corner(samples[nd-1,:,:3],bins=50,smooth=0.05,range=[(0.0,1.0),(0.0,1.0),(0.0,1.0)],fig=ax2,color='b',**kw)
            ax2.savefig('{}/training_corner_{}_{}.png'.format(plot_path,i,j))

            j += 1

        fig.savefig('{}/training_{}.png'.format(plot_path,i))

        # loss curves; validation is recorded every 100 iterations
        loss_ax.semilogx(train_loss_smooth, alpha=0.5, label="Train")
        loss_ax.semilogx(np.arange(len(val_loss_smooth))*100,val_loss_smooth, alpha=0.5, label="Val.")
        loss_ax.set_ylim(np.min(train_loss_smooth)-0.1, np.percentile(np.array(train_loss_smooth),90))
        loss_ax.set_xlim([1000,iterations])
        loss_ax.set_xlabel("Epoch")
        loss_ax.set_ylabel("Loss")
        loss_ax.legend()
        loss_ax.grid('on')
        loss_fig.savefig('{}/loss.png'.format(plot_path))
        plt.close('all')

    # ---- curriculum update ------------------------------------------------
    # Over the first quarter of training the batch is gradually shared out
    # over more measurements; afterwards it is split evenly over all n_max.
    frac = np.zeros(n_max+1)
    sub_idx = (i+1)*(4*n_max)/float(iterations)
    for q in range(n_max+1):
        frac[q] = int(batch_size*max(((sub_idx-q)/sub_idx),0))
    sub_batch = -(np.diff(frac))
    sub_batch = np.append(sub_batch,0)
    old_n_max = current_n_max
    # index of the first empty sub-batch = number of measurements in use;
    # [0, 0] extracts a scalar (int() on a 1-element array is deprecated)
    current_n_max = int(np.argwhere(sub_batch==0)[0, 0])
    if old_n_max != current_n_max:
        print('updated current nmax to {} and sub batch size to {}'.format(current_n_max,sub_batch))
    if i > int(iterations/4):
        sub_batch = (batch_size/n_max)*np.ones(n_max)

    #if i>0 and not i % 50000 and current_n_max