from glasflow import RealNVP
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import chain, permutations
from datetime import datetime
import pickle
from scipy.interpolate import RectBivariateSpline
from scipy.integrate import quad
import time
import shutil
# for JS divergence calculation
from scipy.spatial.distance import jensenshannon
from scipy.stats import gaussian_kde
from collections import namedtuple
import seaborn as sns
import os
from scipy.stats import uniform, norm, chi
from scipy.special import expit, log_wright_bessel
from nessai.flowsampler import FlowSampler
from nessai.model import Model
from nessai.utils import setup_logger
from nessai.livepoint import dict_to_live_points
import corner

######################################
# TO-DO LIST (in no particular order)
#
# 1. Widen frequency band - go to longer observation times
# 2. Add declination
# 3. Merge amplitude and alpha into 2 Euclidean parameters (wrap around sky)
# 4. Save JS values to file and plot comparisons
# 5. Add attention head to conditional data compression
# 6. Move to 16, 32, ... segments
# 7. Switch to properly heterodyned data
# 8. Solve issue of annual frequency variation > band
# 9. Do importance sampling for the final epoch on the reordered samples

seed = 1
torch.manual_seed(seed)
np.random.seed(seed)

sns.set_context("notebook")
sns.set_palette("colorblind")

device = "cuda"
date_id = datetime.today().isoformat()
plot_path = '/data/www.astro/chrism/uroboros/{}'.format(date_id)
ns_path = './'
try:
    os.mkdir('{}'.format(plot_path))
except OSError:
    print('unable to make output directory {}'.format(plot_path))
    exit(1)
shutil.copyfile('./uroboros.py', '{}/uroboros.txt'.format(plot_path))

iterations = 200000
batch_size = 1024
plot_step = 10000
n_test = 4               # number of individual test data samples
n_prior = 128            # number of samples used to represent the conditional prior
n_max_list = [1, 2, 4, 8]
n_max = n_max_list[-1]   # the maximum number of measurements per sample
glob_n_max = n_max
n_par = 3                # the number of hyperparameters
n_meas = 64              # the size of a measurement
n_prior_cc_out = 16      # the size of the compressed prior
n_meas_cc_out = 16       # the size of the compressed measurement
n_sigma = 1.0            # the scale of the measurement noise
n_post = 3000            # the number of posterior samples used for plotting
lr = 1e-3                # the learning rate
run_id = 'cw_alpha_s{}_nmax{}_ntest{}'.format(seed, n_max, n_test)
Nreorder = 10
multiflow = False
multiflow_n_max = 1
if multiflow:
    multiflow_n_max = n_max

# define the Flow that will estimate the parameters conditional on new data and compressed prior information
flow = []
for i in range(multiflow_n_max):
    flow.append(RealNVP(
        n_inputs=n_par,          # number of params
        n_transforms=4,
        n_conditional_inputs=n_prior_cc_out + n_meas_cc_out + 1,  # size of compressed prior plus size of compressed measurement plus segment index
        n_neurons=32,
        n_blocks_per_transform=2,
        batch_norm_within_blocks=True,
        linear_transform='permutation',
        batch_norm_between_transforms=True,
    ).to(device))
print(f"Created flow and sent to {device}")
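# Note on the conditional input layout (illustrative comment only, not executed):
# each flow conditional vector is the concatenation of
#   [ compressed measurement (n_meas_cc_out) | compressed prior (n_prior_cc_out) | segment index / glob_n_max (1) ]
# so with the defaults above it has length 16 + 16 + 1 = 33. A minimal sanity check,
# assuming a single flow (multiflow = False), could look like:
#
#   dummy_cond = torch.zeros(5, n_prior_cc_out + n_meas_cc_out + 1, device=device)   # hypothetical name
#   dummy_samples = flow[0].sample(5, conditional=dummy_cond)                         # -> shape (5, n_par)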
class MAB(nn.Module):
    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super(MAB, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)

        A = torch.softmax(Q_.bmm(K_.transpose(1, 2))/np.sqrt(self.dim_V), 2)
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O


class SAB(nn.Module):
    def __init__(self, dim_in, dim_out, num_heads, ln=False):
        super(SAB, self).__init__()
        self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(X, X)


class ISAB(nn.Module):
    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super(ISAB, self).__init__()
        self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X):
        H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
        return self.mab1(X, H)


class PMA(nn.Module):
    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super(PMA, self).__init__()
        self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
        nn.init.xavier_uniform_(self.S)
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        return self.mab(self.S.repeat(X.size(0), 1, 1), X)


class SmallSetTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(
            SAB(dim_in=n_par, dim_out=32, num_heads=4),
            SAB(dim_in=32, dim_out=32, num_heads=4),
        )
        self.dec = nn.Sequential(
            PMA(dim=32, num_heads=4, num_seeds=1),
            nn.Linear(in_features=32, out_features=n_prior_cc_out),
            nn.Sigmoid(),  # I added this
        )

    def forward(self, x):
        x = self.enc(x)
        x = self.dec(x)
        return x.squeeze(-2)


class PI_NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.nn1 = nn.Sequential(
            nn.Linear(n_par, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
        )
        self.nn2 = nn.Sequential(
            nn.Linear(n_par + 64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, n_prior_cc_out),
        )

    def forward(self, x0):
        # the input data has shape (bs, n_prior, n_par)
        # we want to make this (bs*n_prior, n_par)
        x = x0.flatten(0, 1)               # new shape (batch*n_prior, n_par)
        x = self.nn1(x)                    # output shape (batch*n_prior, 64)
        x = x.reshape(-1, n_prior, 64)     # reshape to (batch, n_prior, 64)
        # print(x.shape, x0.shape)         # debug only
        x = torch.concat([x, x0], dim=2)   # concat to get (batch, n_prior, 64+n_par)
        x = torch.mean(x, dim=1)           # take mean to get (batch, 64+n_par)
        x = self.nn2(x)                    # process again to get (batch, n_prior_cc_out)
        return x


cc_prior_model = []
for i in range(multiflow_n_max):
    cc_prior_model.append(SmallSetTransformer().to(device))

# the compression model that takes measurement samples and compresses them
cc_meas_model = nn.Sequential(
    nn.Linear(n_meas, 32),
    nn.ReLU(),
    nn.Linear(32, 32),
    nn.ReLU(),
    #nn.Linear(64, 64),
    #nn.ReLU(),
    nn.Linear(32, n_meas_cc_out),
    nn.Sigmoid()
)
cc_meas_model.to(device)
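# Shape sketch for the two compressors (illustrative comment only, not executed):
# SmallSetTransformer maps a set of prior samples (batch, n_prior, n_par) to a
# permutation-invariant summary (batch, n_prior_cc_out), and cc_meas_model maps a
# single measurement segment (batch, n_meas) to (batch, n_meas_cc_out), e.g.
#
#   prior_set = torch.zeros(8, n_prior, n_par, device=device)   # hypothetical input
#   summary = cc_prior_model[0](prior_set)                       # -> (8, n_prior_cc_out)
#   segment = torch.zeros(8, n_meas, device=device)              # hypothetical input
#   c_seg = cc_meas_model(segment)                                # -> (8, n_meas_cc_out)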
def calc_median_error(jsvalues, quantiles=(0.16, 0.84)):
    quants_to_compute = np.array([quantiles[0], 0.5, quantiles[1]])
    quants = np.percentile(jsvalues, quants_to_compute * 100)
    summary = namedtuple("summary", ["median", "lower", "upper"])
    summary.median = quants[1]
    summary.plus = quants[2] - summary.median
    summary.minus = summary.median - quants[0]
    return summary


def calculate_js(samplesA, samplesB, ntests=10, xsteps=100):
    js_array = np.zeros(ntests)
    for j in range(ntests):
        nsamples = min([len(samplesA), len(samplesB)])
        A = np.random.choice(samplesA, size=nsamples, replace=False)
        B = np.random.choice(samplesB, size=nsamples, replace=False)
        xmin = np.min([np.min(A), np.min(B)])
        xmax = np.max([np.max(A), np.max(B)])
        x = np.linspace(xmin, xmax, xsteps)
        A_pdf = gaussian_kde(A)(x)
        B_pdf = gaussian_kde(B)(x)
        js_array[j] = np.nan_to_num(np.power(jensenshannon(A_pdf, B_pdf), 2))
    return calc_median_error(js_array)


#def sig(Asin,Acos,f0,t,i):
#    """
#    phi0 - the phase normalised between 0 and 1
#    f0 - the frequency normalised in reference to the Nyquist frequency (0-1)
#    N - the number of samples in the timeseries
#    i - the index of the timeseries
#    """
#    f = (0.4 + 0.2*f0)   # fraction of the Nyquist frequency (0.4 - 0.6)
#    #return 0.25*np.sin(2.0*np.pi*(phi0 + f*0.5*(t + i)))
#    phase = 2.0*np.pi*(f*0.5*(t + i))
#    return 0.25*(Asin*np.cos(phase) + Acos*np.sin(phase))


class cw_model(Model):
    """A heterodyned continuous-wave signal model with a Gaussian noise likelihood."""

    def __init__(self, n_meas, n_max, d=None, i_ref=0):
        # Names of parameters to sample
        self.n_meas = n_meas
        self.dvec = d
        self.T = 1800.0
        self.i_ref = i_ref                           # the starting segment index
        self.dt = self.T/self.n_meas
        self.t = torch.arange(self.n_meas)*self.dt   # define time vector
        #self.t = np.arange(self.n_meas)*self.dt
        self.fmax = 0.5/self.dt
        self.fhet = 100.0
        self.n_max = n_max
        self.n_sigma = 1.0
        self.SNR_factor = 0.25 #0.25                 # a scaling that controls the average SNR
        self.n_par = 3                               # number of parameters (A, f0, alpha)
        self.asini = 6e6/3e8
        self.Omega = 2.0*np.pi/86400.0
        self.names = ["A", "f0", "alpha"]
        self.plot_bounds = [[0, 3], [0, 1], [0, 1]]
        self.bounds = {"A": [0, np.inf], "f0": [0, 1], "alpha": [0, 1]}
        #for i in range(n_max):
        #    self.names.append("phase{}".format(i))
        #    self.plot_bounds.append([0,1])
        #    self.bounds["phase{}".format(i)] = [0,1]
        #self.repar_bounds =[[0,1],[0,1],[0,3]]

    #def sig(self,x,i):
    #    """
    #    We assume that the input shape of x is [n_par,N] where N=n_data*n_max
    #    We assume that the input shape of i is [N]
    #    All tensors are expanded to have an extra dimension of n_meas
    #    Output has shape [N,n_meas]
    #    """
    #    N = x.shape[1]   # number of locations
    #    x = torch.tile(x.reshape(self.n_par,N,1),(1,1,self.n_meas))   # convert to (n_par,N,n_meas)
    #    Asin = x[0]   # the phase sin quadrature (N,n_meas)
    #    Acos = x[1]   # the phase cos quadrature (N,n_meas)
    #    f0 = x[2]     # the frequency normalised in reference to the Nyquist frequency (0-1) (N,n_meas)
    #    i = torch.tile(i.reshape(N,1),(1,n_meas))   # i - the index of the timeseries (N,n_meas)
    #
    #    f = (0.4 + 0.2*f0)*self.fmax   # f0 is fraction of the Nyquist frequency (0.4 - 0.6) (N,n_meas)
    #    phase = 2.0*np.pi*f*(torch.tile(self.t.reshape(1,n_meas),(N,1)) + i*self.T)   # shift the initial time in reference to first segment (N,n_meas)
    #    return self.SNR_factor*(Asin*np.cos(phase) + Acos*np.sin(phase))   # return timeseries (N,n_meas)

    def sig(self, A, f0, alpha, phase, t, i):
        """
        A     - the signal amplitude
        f0    - the frequency normalised in reference to the Nyquist frequency (0-1)
        alpha - the source sky position angle (normalised 0-1)
        phase - the initial phase normalised between 0 and 1
        t     - the time samples within a segment
        i     - the index of the timeseries segment
        """
        f = self.fhet + (0.4 + 0.2*f0)*self.fmax   # fraction of the Nyquist frequency (0.4 - 0.6) above the heterodyne
        t = t + i*self.T
        phi = 2.0*np.pi*((f-self.fhet)*t + f*self.asini*np.sin(2.0*np.pi*alpha + self.Omega*t) + phase)
        return self.SNR_factor*A*np.sin(phi), phi
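    # The (heterodyned) signal phase used above, written out for reference (comment only):
    #   f      = fhet + (0.4 + 0.2*f0)*fmax
    #   phi(t) = 2*pi*[ (f - fhet)*t + f*asini*sin(2*pi*alpha + Omega*t) + phase ]
    # with t measured from the start of the first segment (t -> t + i*T), Omega the
    # daily rotation angular frequency (2*pi/86400 s) and asini the projected radius
    # in light travel time (6e6 m / c). The returned strain is
    #   s(t) = SNR_factor * A * sin(phi(t)).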
    def gen_pars(self, n_data):
        Asin = (torch.randn(size=(n_data, 1))).to(device)
        Acos = (torch.randn(size=(n_data, 1))).to(device)
        A = torch.sqrt(Asin**2 + Acos**2)
        phase = torch.remainder(torch.atan2(Asin, Acos), 2*np.pi)/(2.0*np.pi)
        f0 = (torch.rand(size=(n_data, 1))).to(device)      # the true f0 value (bs,1)
        alpha = (torch.rand(size=(n_data, 1))).to(device)
        Omega = (torch.concatenate((A, f0, alpha, phase), axis=1)).to(device)
        return Omega

    def gen_sig(self, n_data, n_max):
        Omega = self.gen_pars(n_data)
        A = Omega[:, 0]
        f0 = Omega[:, 1]
        alpha = Omega[:, 2]
        phase = Omega[:, 3]
        n = (self.n_sigma*torch.randn(size=(n_data, n_meas, n_max))).to(device)   # measurement noise (n_data,n_meas,n_max)

        # generate the signal - the inputs all have shape (n_data,n_meas,n_max). The output has the same shape.
        meas, _ = self.sig(torch.tile(A.reshape(n_data, 1, 1), (1, n_meas, n_max)).flatten().cpu(),
                           torch.tile(f0.reshape(n_data, 1, 1), (1, n_meas, n_max)).flatten().cpu(),
                           torch.tile(alpha.reshape(n_data, 1, 1), (1, n_meas, n_max)).flatten().cpu(),
                           torch.tile(phase.reshape(n_data, 1, 1), (1, n_meas, n_max)).flatten().cpu(),
                           torch.tile(self.dt*torch.arange(n_meas).reshape(1, n_meas, 1), (n_data, 1, n_max)).flatten().cpu(),
                           torch.tile(torch.arange(n_max).reshape(1, 1, n_max), (n_data, n_meas, 1)).flatten().cpu())
        meas = meas.reshape(n_data, n_meas, n_max).to(device) + n
        meas = meas.float()
        return meas, Omega, n

    def new_point(self, N=1):
        """Draw N points.

        This is used for the initial sampling. Points do not need to be drawn
        from the exact prior but the algorithm will be more efficient if they are.
        """
        # There are various ways to create live points in nessai, such as
        # from dictionaries and numpy arrays. See nessai.livepoint for options
        d = {
            "A": chi(2).rvs(size=N),
            "f0": uniform.rvs(loc=self.bounds["f0"][0], scale=self.bounds["f0"][1] - self.bounds["f0"][0], size=N),
            "alpha": uniform.rvs(loc=self.bounds["alpha"][0], scale=self.bounds["alpha"][1] - self.bounds["alpha"][0], size=N),
        }
        #for i in range(self.n_max):
        #    d["phase{}".format(i)] = uniform.rvs(loc=self.bounds["phase{}".format(i)][0], scale=self.bounds["phase{}".format(i)][1] - self.bounds["phase{}".format(i)][0], size=N)
        return dict_to_live_points(d)

    def new_point_log_prob(self, x):
        """Returns the log-probability for a new point.

        Since we have redefined `new_point` we also need to redefine this
        function.
        """
        return self.log_prior(x)

    def log_prior(self, x):
        """
        Returns the log prior for a live point assuming a chi prior on the
        amplitude and uniform priors on the other parameters.
        """
        # Check if values are in bounds, returns True/False
        # Then take the log to get 0/-inf and make sure the dtype is float
        log_p = np.log(self.in_bounds(x), dtype="float")
        # Since the live points are a structured array we can
        # get each value using just the parameter name
        #for n in self.names:
        log_p += chi(2).logpdf(x["A"])
        log_p -= np.log(self.bounds["f0"][1] - self.bounds["f0"][0])         # uniform prior
        log_p -= np.log(self.bounds["alpha"][1] - self.bounds["alpha"][0])
        #for i in range(self.n_max):
        #    log_p -= np.log(self.bounds["phase{}".format(i)][1] - self.bounds["phase{}".format(i)][0])
        return log_p
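    # Prior consistency note (a sketch, not executed): gen_pars draws the two quadratures
    # Asin, Acos ~ N(0,1), so the amplitude A = sqrt(Asin**2 + Acos**2) follows a chi
    # distribution with 2 degrees of freedom and the phase is uniform on [0,1). This is
    # why new_point / log_prior above use scipy.stats.chi(2) for A rather than a uniform
    # prior, e.g.
    #
    #   from scipy.stats import chi
    #   A_samples = chi(2).rvs(size=10000)   # matches sqrt(N(0,1)**2 + N(0,1)**2)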
""" log_l = 0.0 # initialise the log likelihood A = np.array(x["A"]).reshape(-1) f0 = np.array(x["f0"]).reshape(-1) alpha = np.array(x["alpha"]).reshape(-1) N = A.shape[0] for i,d in enumerate(self.dvec): # loop over measurements #Omega = torch.from_numpy(np.array([x["A"],x["f0"],x["phase{}".format(i)]])).reshape(self.n_par,-1) #phase = torch.from_numpy(np.array(x["phase{}".format(i)])) idx = torch.from_numpy(np.ones(N)*i) + self.i_ref _,phi = self.sig(torch.tile(torch.from_numpy(np.array(x["A"])).reshape(-1,1),(1,n_meas)).flatten().cpu(), torch.tile(torch.from_numpy(np.array(x["f0"])).reshape(-1,1),(1,n_meas)).flatten().cpu(), torch.tile(torch.from_numpy(np.array(x["alpha"])).reshape(-1,1),(1,n_meas)).flatten().cpu(), torch.tile(torch.zeros(1).reshape(1,1,1),(N,n_meas)).flatten().cpu(), torch.tile(self.dt*torch.arange(n_meas).reshape(1,n_meas),(N,1)).flatten().cpu(), torch.tile(i*torch.ones(1).reshape(1,1,1),(N,n_meas)).flatten().cpu()) #s = s.reshape(N,n_meas) phi = phi.reshape(N,n_meas).numpy() #s = self.sig(A,f0,phase,torch.arange(self.n_meas),idx) #log_l += np.sum(norm.logpdf(s,loc=d,scale=self.n_sigma)) log_l += -0.25*n_meas*(A*self.SNR_factor/self.n_sigma)**2 + log_wright_bessel(1,1,0.25*((A*self.SNR_factor/self.n_sigma**2)**2)*(np.sum(d*np.sin(phi),axis=-1)**2 + np.sum(d*np.cos(phi),axis=-1)**2)) return log_l #def log_likelihood(self, x): # """ # Returns log likelihood of given live point assuming a Gaussian # likelihood. # """ # log_l = 0.0 # for i,d in enumerate(self.dvec): # s = sig(x["Asin"],x["Acos"],x["f0"],np.arange(self.n_meas),self.n_meas*i + self.i_ref*self.n_meas) # log_l += np.sum(norm.logpdf(s,loc=d,scale=n_sigma)) # return log_l def change_vars(self,samples): return samples # convert samples in [Asin, Acos, f0] to [phase, f0, Amp] #N = samples.shape[0] #new_samples = torch.zeros((N,self.n_par)) #new_samples[:,2] = torch.sqrt(samples[:,0]**2 + samples[:,1]**2) #new_samples[:,1] = samples[:,2] #new_samples[:,0] = torch.remainder(torch.atan2(samples[:,0],samples[:,1]),2*np.pi)/(2.0*np.pi) #return new_samples #def shift_time(self,samples,i): # # convert samples in [Asin, Acos, f0] to a different reference time # N = samples.shape[0] # new_samples = torch.zeros((N,self.n_par)) # phase = torch.remainder(torch.atan2(samples[:,0],samples[:,1]),2*np.pi)/(2.0*np.pi) # f = (0.4 + 0.2*samples[:,2])*0.5 # new_phase = phase + 2.0*np.pi*f*self.T*i # A = torch.sqrt(samples[:,0]**2 + samples[:,1]**2) # new_samples[:,0] = A*torch.sin(new_phase) # new_samples[:,1] = A*torch.cos(new_phase) # new_samples[:,2] = samples[:,2] # return new_samples def scatter_plot(self,ax,samples,ns_samples,reord_samples,x,n_max_list,n_prior,n_post,n_par,change_vars=False): samples = torch.permute(samples,(0,3,1,2))[0,:,:,:].reshape(n_max_list[-1],n_post,n_par) #old_samples = torch.permute(old_samples,(0,3,1,2))[0,:,:,:].reshape(nd,n_prior,n_par) reord_samples = torch.permute(reord_samples,(0,3,1,2))[0,:,:,:].reshape(n_max_list[-1],-1,n_par) if change_vars==True: samples = signal_model.change_vars(samples[:,:,:n_par].flatten(0,1)).reshape(n_max_list[-1],n_post,n_par).cpu().numpy() #old_samples = signal_model.change_vars(old_samples[:,:,:n_par].flatten(0,1)).reshape(nd,n_prior,n_par).cpu().numpy() reord_samples = signal_model.change_vars(reord_samples[:,:,:n_par].flatten(0,1)).reshape(n_max_list[-1],-1,n_par).cpu().numpy() x = signal_model.change_vars(x.reshape(1,n_par)).flatten().cpu().numpy() ns_samples = 
    #def log_likelihood(self, x):
    #    """
    #    Returns log likelihood of given live point assuming a Gaussian
    #    likelihood.
    #    """
    #    log_l = 0.0
    #    for i,d in enumerate(self.dvec):
    #        s = sig(x["Asin"],x["Acos"],x["f0"],np.arange(self.n_meas),self.n_meas*i + self.i_ref*self.n_meas)
    #        log_l += np.sum(norm.logpdf(s,loc=d,scale=n_sigma))
    #    return log_l

    def change_vars(self, samples):
        return samples
        # convert samples in [Asin, Acos, f0] to [phase, f0, Amp]
        #N = samples.shape[0]
        #new_samples = torch.zeros((N,self.n_par))
        #new_samples[:,2] = torch.sqrt(samples[:,0]**2 + samples[:,1]**2)
        #new_samples[:,1] = samples[:,2]
        #new_samples[:,0] = torch.remainder(torch.atan2(samples[:,0],samples[:,1]),2*np.pi)/(2.0*np.pi)
        #return new_samples

    #def shift_time(self,samples,i):
    #    # convert samples in [Asin, Acos, f0] to a different reference time
    #    N = samples.shape[0]
    #    new_samples = torch.zeros((N,self.n_par))
    #    phase = torch.remainder(torch.atan2(samples[:,0],samples[:,1]),2*np.pi)/(2.0*np.pi)
    #    f = (0.4 + 0.2*samples[:,2])*0.5
    #    new_phase = phase + 2.0*np.pi*f*self.T*i
    #    A = torch.sqrt(samples[:,0]**2 + samples[:,1]**2)
    #    new_samples[:,0] = A*torch.sin(new_phase)
    #    new_samples[:,1] = A*torch.cos(new_phase)
    #    new_samples[:,2] = samples[:,2]
    #    return new_samples

    def scatter_plot(self, ax, samples, ns_samples, reord_samples, x, n_max_list, n_prior, n_post, n_par, change_vars=False):
        # NOTE: this relies on the module-level variables a, b (the pair of parameter
        # indices being plotted) and signal_model being defined at call time.
        samples = torch.permute(samples, (0, 3, 1, 2))[0, :, :, :].reshape(n_max_list[-1], n_post, n_par)
        #old_samples = torch.permute(old_samples,(0,3,1,2))[0,:,:,:].reshape(nd,n_prior,n_par)
        reord_samples = torch.permute(reord_samples, (0, 3, 1, 2))[0, :, :, :].reshape(n_max_list[-1], -1, n_par)
        if change_vars == True:
            samples = signal_model.change_vars(samples[:, :, :n_par].flatten(0, 1)).reshape(n_max_list[-1], n_post, n_par).cpu().numpy()
            #old_samples = signal_model.change_vars(old_samples[:,:,:n_par].flatten(0,1)).reshape(nd,n_prior,n_par).cpu().numpy()
            reord_samples = signal_model.change_vars(reord_samples[:, :, :n_par].flatten(0, 1)).reshape(n_max_list[-1], -1, n_par).cpu().numpy()
            x = signal_model.change_vars(x.reshape(1, n_par)).flatten().cpu().numpy()
            ns_samples = signal_model.change_vars(torch.from_numpy(ns_samples).flatten(0, 2)).reshape(len(n_max_list), -1, 2, n_par).cpu().numpy()
        else:
            samples = samples.cpu().numpy()
            #old_samples = old_samples.cpu().numpy()
            reord_samples = reord_samples.cpu().numpy()
            x = x.flatten().cpu().numpy()
            #ns_samples = ns_samples.flatten().reshape(n_max,-1,2,n_par).cpu().numpy()

        idx = 0   # index for the ns samples
        for k in n_max_list:   # index for the uroboros samples
            #prior_s = old_samples[k+1,:,:n_par].reshape(n_prior,n_par)
            post_s = samples[k-1, :, :n_par].reshape(n_post, n_par)
            js_a = 1e3*calculate_js(post_s[:, a], ns_samples[idx, :, 0, a]).median
            js_b = 1e3*calculate_js(post_s[:, b], ns_samples[idx, :, 0, b]).median
            ax[idx].plot(post_s[:, a], post_s[:, b], '.b', alpha=0.5, markersize=1)
            ax[idx].plot(ns_samples[idx, :, 0, a], ns_samples[idx, :, 0, b], '.r', alpha=0.5, markersize=1)
            ax[idx].plot(ns_samples[idx, :, 1, a], ns_samples[idx, :, 1, b], '.g', alpha=0.5, markersize=1)
            ax[idx].annotate('JS = {:.0f},{:.0f}'.format(js_a, js_b), (0.5, 0.9), xycoords='axes fraction')
            ##ax[j,k].set_xlim(signal_model.repar_bounds[a])
            ##ax[j,k].set_ylim(signal_model.repar_bounds[b])
            if change_vars == True:
                ax[idx].set_xlim(self.repar_bounds[a])
                ax[idx].set_ylim(self.repar_bounds[b])
            else:
                ax[idx].set_xlim(self.plot_bounds[a])
                ax[idx].set_ylim(self.plot_bounds[b])
            ax[idx].plot(x[a], x[b], 'xk', label='truth', markersize=10)
            ax[idx].annotate('{}'.format(k), (0.9, 0.05), xycoords='axes fraction')
            idx += 1

        #js_a = 1e3*calculate_js(samples[-1,:,a], ns_samples[-1,:,0,a]).median
        #js_b = 1e3*calculate_js(samples[-1,:,b], ns_samples[-1,:,0,b]).median
        #ax[nd-1].plot(samples[-1,:,a],samples[-1,:,b],'.b',alpha=0.5,markersize=1)
        #ax[nd-1].plot(ns_samples[-1,:,0,a],ns_samples[-1,:,0,b],'.r',alpha=0.5,markersize=1)
        #ax[nd-1].plot(ns_samples[-1,:,1,a],ns_samples[-1,:,1,b],'.g',alpha=0.5,markersize=1)
        #ax[nd-1].plot(x[a],x[b],'xk',markersize=10)
        #ax[nd-1].annotate('JS = {:.0f},{:.0f}'.format(js_a,js_b),(0.5,0.9),xycoords='axes fraction')
        ##ax[j,nd-1].set_xlim(signal_model.repar_bounds[a])
        ##ax[j,nd-1].set_ylim(signal_model.repar_bounds[b])
        #if change_vars==True:
        #    ax[nd-1].set_xlim(self.repar_bounds[a])
        #    ax[nd-1].set_ylim(self.repar_bounds[b])
        #else:
        #    ax[nd-1].set_xlim(self.plot_bounds[a])
        #    ax[nd-1].set_ylim(self.plot_bounds[b])

        js_a = 1e3*calculate_js(reord_samples[-1, :, a], ns_samples[-1, :, 0, a]).median
        js_b = 1e3*calculate_js(reord_samples[-1, :, b], ns_samples[-1, :, 0, b]).median
        ax[-1].plot(reord_samples[-1, :, a], reord_samples[-1, :, b], '.k', alpha=0.5, markersize=1)
        ax[-1].plot(ns_samples[-1, :, 0, a], ns_samples[-1, :, 0, b], '.r', alpha=0.5, markersize=1)
        #ax[nd].plot(ns_samples[-1,:,1,a],ns_samples[-1,:,1,b],'.g',alpha=0.5,markersize=1)
        ax[-1].plot(x[a], x[b], 'xk', markersize=10)
        ax[-1].annotate('JS = {:.0f},{:.0f}'.format(js_a, js_b), (0.5, 0.9), xycoords='axes fraction')
        ##ax[j,nd-1].set_xlim(signal_model.repar_bounds[a])
        ##ax[j,nd-1].set_ylim(signal_model.repar_bounds[b])
        if change_vars == True:
            ax[-1].set_xlim(self.repar_bounds[a])
            ax[-1].set_ylim(self.repar_bounds[b])
        else:
            ax[-1].set_xlim(self.plot_bounds[a])
            ax[-1].set_ylim(self.plot_bounds[b])
        ax[-1].annotate('{}'.format(n_max_list[-1]), (0.9, 0.05), xycoords='axes fraction')
    def plot_test_data(self, d_test_tensor, n_test_tensor, plot_path):
        """
        Function to plot the test data
        """
        d_test = d_test_tensor.cpu().numpy()
        print(d_test.shape)
        n_test = n_test_tensor.cpu().numpy()
        fig, ax = plt.subplots(d_test.shape[0], 1, figsize=(32, 3*d_test.shape[0]), dpi=100)
        i = 0
        for d, n in zip(d_test, n_test):
            t = np.linspace(0, self.T*d.shape[1], d.shape[1]*d.shape[0])
            ax[i].plot(t, d.transpose(1, 0).flatten())
            ax[i].plot(t, d.transpose(1, 0).flatten() - n.transpose(1, 0).flatten())
            i += 1
        plt.savefig('{}/test_data.png'.format(plot_path))
        plt.close()
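# make_data drives the sequential ("uroboros") scheme: for each training example the
# posterior from measurement i, represented by a set of samples, is compressed and fed
# back in as the prior for measurement i+1. The first prior is drawn analytically
# (chi-distributed amplitude, uniform f0 and alpha); subsequent priors are drawn from
# the current state of the flow itself. This is a descriptive summary of the function
# below, not additional executed code.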
def make_data(n_data, n_prior=100, n_max=1, n_post=None, flow=None, cc_prior_model=None, cc_meas_model=None, meas=None, valtest=False):
    """
    function to make training data
    n_data = number of training samples to make
    n_prior = the number of samples to use from the prior before compression
    n_post = the number of posterior samples
    n_max = the max number of measurements per training data sample
    meas = the measurement segments (generated internally if not supplied)
    """
    Omega = None
    n = None
    if meas is None:
        meas, Omega, n = cw_model(n_meas, n_max).gen_sig(n_data, n_max)

    # put everything else into eval mode
    for f, c in zip(flow, cc_prior_model):
        f.eval()
        c.eval()

    # compress the measured data - we should be in training mode for this since
    # it is the only place we do the compression
    if not valtest:
        cc_meas_model.train()
        c_meas = torch.zeros(n_data, n_meas_cc_out, n_max).to(device)
        for i in range(n_max):
            c_meas[:, :, i] = cc_meas_model(meas[:, :, i].detach())
    else:
        cc_meas_model.eval()
        c_meas = torch.zeros(n_data, n_meas_cc_out, n_max).to(device)
        for i in range(n_max):
            with torch.no_grad():
                c_meas[:, :, i] = cc_meas_model(meas[:, :, i].detach())

    # initialise the prior samples tensor
    # there should be a small set of prior samples (and log probs) for each measurement and for each sample
    prior_label = torch.zeros(n_data, n_prior, n_par, n_max).to(device)   # initialise the prior label tensor

    # PRIOR 1 - for the 1st signal, sample from the original prior
    #test = flow.inverse(torch.ones(n_data*n_prior,n_par),conditional=test_cond).to(device)
    #prior_label[:,:,0,0] = test[:,0].reshape(n_prior,n_data).transpose(1,0).to(device)
    prior_label[:, :, 0, 0] = torch.sqrt(torch.randn(size=(n_data, n_prior))**2 + torch.randn(size=(n_data, n_prior))**2).to(device)
    prior_label[:, :, 1, 0] = torch.rand(size=(n_data, n_prior)).to(device)
    prior_label[:, :, 2, 0] = torch.rand(size=(n_data, n_prior)).to(device)
    #prior_label[:,:,3,0] = (-0.5*prior_label[:,:,0,0]**2 - 0.5*prior_label[:,:,1,0]**2 - np.log(2.0*np.pi)).to(device)   #torch.zeros(size=(n_data,n_prior)).to(device)

    # sort the prior samples by log-lik order
    #temp, idx = torch.sort(prior_label[:,:,3,0],dim=1)
    #prior_label[:,:,3,0] = temp

    ###############################################################
    # we can stop here IF ONLY 1 measurement is being considered
    # Otherwise we need to put the 1st (nth) measurement through the flow to get a new prior for the 2nd (n+1) measurement

    #flow.eval()            # we DO NOT train the flow in the data generation step
    #cc_prior_model.eval()  # we DO NOT train the prior compression in the data generation step
    for i in range(n_max-1):   # loop over each event from the zeroth to the n-1'th (we don't want to do the last one)
        ii = i if multiflow else 0
        with torch.no_grad():
            c_prior = cc_prior_model[ii](prior_label[:, :, :, i]).detach()   #.flatten(1,2)).detach() # compress the i'th prior data
        test_cond = torch.cat((c_meas[:, :, i].detach(), c_prior, (i/float(glob_n_max))*torch.ones(size=(n_data, 1)).to(device)), dim=1).to(device)   # combine the compressed measurement and prior and the measurement index
        test_cond = test_cond.tile(n_prior, 1).to(device)   # tile it to generate n_prior samples for each of n_data (n_data*n_prior,n_cc+n_meas)
        with torch.no_grad():
            # run the current flow state to generate new posterior -> prior samples and log-likelihoods
            prior_samples = flow[ii].sample(n_data*n_prior, conditional=test_cond).to(device)   # output shape should be (n_data*n_prior,n_par)
            #prior_samples, _ = flow.inverse(torch.tile(z_prior,(n_data,1)),conditional=test_cond)
            #prior_logprob = flow[ii].log_prob(prior_samples,conditional=test_cond).to(device)   # output shape should be (n_data*n_prior)

        # fill in the prior labels - these are now the priors for the NEXT measurement
        prior_label[:, :, 0, i+1] = prior_samples[:, 0].reshape(n_prior, n_data).transpose(1, 0).to(device)
        prior_label[:, :, 1, i+1] = prior_samples[:, 1].reshape(n_prior, n_data).transpose(1, 0).to(device)
        prior_label[:, :, 2, i+1] = prior_samples[:, 2].reshape(n_prior, n_data).transpose(1, 0).to(device)
        #prior_label[:,:,3,i+1] = prior_logprob.reshape(n_prior,n_data).transpose(1,0).to(device)

    # If we want the iteratively generated posteriors -> priors for plotting then we generate more samples
    # we still use the fixed lower number of samples generated above for each stage
    # and we now do compute the posterior after the final measurement
    post_label = None
    if n_post is not None:
        post_label = torch.zeros(n_data, n_post, n_par, n_max).to(device)   # initialise the posterior label tensor
        #flow.eval()            # we DO NOT train the flow in the data generation step
        #cc_prior_model.eval()  # we DO NOT train the prior compression in the data generation step
        for i in range(n_max):
            ii = i if multiflow else 0
            # POST 1 - compress the uniform prior and add the 1st meas condition
            with torch.no_grad():
                c_prior = cc_prior_model[ii](prior_label[:, :, :, i]).detach()   #.flatten(1,2)).detach()
            test_cond = torch.cat((c_meas[:, :, i].detach(), c_prior, (i/float(glob_n_max))*torch.ones(size=(n_data, 1)).to(device)), dim=1).to(device)
            test_cond = test_cond.tile(n_post, 1).to(device)   # has shape (n_data*n_post,n_cc+n_meas)

            # keep doing this stage until we have the desired number of samples AFTER importance sampling
            #flag = True
            #rsum = 0
            #prior_samples = np.zeros((n_post,n_par))
            #prior_logprob = np.zeros(n_post)
            #while flag:
            with torch.no_grad():
                temp_prior_samples = flow[ii].sample(n_data*n_post, conditional=test_cond).to(device)   # output shape should be (n_data*n_post,n_par)
                #temp_prior_logprob = flow[ii].log_prob(temp_prior_samples,conditional=test_cond).to(device)   # output shape should be (n_data*n_post)

            # fill in the posteriors - these are the posteriors AFTER each event
            post_label[:, :, 0, i] = temp_prior_samples[:, 0].reshape(n_post, n_data).transpose(1, 0).to(device)
            post_label[:, :, 1, i] = temp_prior_samples[:, 1].reshape(n_post, n_data).transpose(1, 0).to(device)
            post_label[:, :, 2, i] = temp_prior_samples[:, 2].reshape(n_post, n_data).transpose(1, 0).to(device)
            #post_label[:,:,3,i] = temp_prior_logprob.reshape(n_post,n_data).transpose(1,0).to(device)

    return Omega, meas, c_meas, prior_label, post_label, n
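# Minimal usage sketch for make_data (commented out; shapes assume the defaults above):
#
#   pars, meas, c_meas, priors, posts, noise = make_data(
#       n_data=16, n_prior=n_prior, n_max=4, n_post=None,
#       flow=flow, cc_prior_model=cc_prior_model, cc_meas_model=cc_meas_model, valtest=True)
#   # pars   : (16, 4) true [A, f0, alpha, phase]
#   # meas   : (16, n_meas, 4) noisy segments
#   # c_meas : (16, n_meas_cc_out, 4) compressed segments
#   # priors : (16, n_prior, n_par, 4) per-segment prior sample sets
#   # posts  : None here since n_post was not requested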
# save or load the true params and measurements
try:
    os.mkdir('{}/{}'.format(ns_path, run_id))
except OSError:
    pass

parfile = '{}/{}/testdata_{}.dat'.format(ns_path, run_id, run_id)
print(parfile)
if os.path.isfile(parfile) == False:
    # generate test data
    print('making test data')
    data_test_tensor, d_test_tensor, _, _, _, n_test_tensor = make_data(n_data=n_test,
                                                                        n_max=n_max,
                                                                        n_prior=n_prior,
                                                                        flow=flow,
                                                                        cc_prior_model=cc_prior_model,
                                                                        cc_meas_model=cc_meas_model,
                                                                        valtest=True)
    test_n_sig = torch.ones(n_test)*n_max
    print('made test data')

    # and save it to file
    with open(parfile, 'wb') as f:   # Python 3: open(..., 'wb')
        pickle.dump([data_test_tensor.cpu().numpy(), d_test_tensor.cpu().numpy(), n_test_tensor.cpu().numpy()], f)
    print('saved par file {}'.format(parfile))
else:
    # open the parfile and read the data and parameters
    with open(parfile, 'rb') as f:   # Python 3: open(..., 'rb')
        data_test_tensor, d_test_tensor, n_test_tensor = pickle.load(f)
    data_test_tensor = torch.tensor(data_test_tensor).to(device)
    d_test_tensor = torch.tensor(d_test_tensor).to(device)
    n_test_tensor = torch.tensor(n_test_tensor).to(device)
    test_n_sig = torch.ones(n_test)*n_max
    print('read in par file {}'.format(parfile))
    print('read in test params')

# plot test data
print('plotting test data')
signal_model = cw_model(n_meas, n_max)
signal_model.plot_test_data(d_test_tensor, n_test_tensor, plot_path)
print('plotted test data')

# The FlowSampler object is used to manage the sampling. Keyword arguments
# are passed to the nested sampler.
j = 0
ns_samples = []
print(data_test_tensor.shape, d_test_tensor.shape)
for x, d in zip(data_test_tensor.cpu().numpy(), d_test_tensor.cpu().numpy()):
    for i in n_max_list:
        for k in range(2):
            if os.path.isfile("{}/{}/ns_{}_{}_{}.dat".format(ns_path, run_id, i, j, k)) == False:
                output = "{}/{}/ns_{}_{}_{}/".format(ns_path, run_id, i, j, k)
                print('starting ns and outputting to {}'.format(output))
                logger = setup_logger(output=output)
                if k == 0:
                    fs = FlowSampler(cw_model(n_meas, n_max, d=d.transpose(1, 0)[:i, :]),
                                     output=output,
                                     resume=False,
                                     reset_flow=8,
                                     volume_fraction=0.98,
                                     seed=1234)
                else:
                    fs = FlowSampler(cw_model(n_meas, n_max, d=d.transpose(1, 0)[i-1, :].reshape(1, -1), i_ref=i-1),
                                     output=output,
                                     reset_flow=8,
                                     volume_fraction=0.98,
                                     resume=False,
                                     seed=1234)
                fs.run()
                temp_ns_samples = []
                for s in fs.posterior_samples:
                    temp_ns_samples.append([s[q] for q in np.arange(n_par)])
                temp_ns_samples = np.array(temp_ns_samples, dtype=np.float64)
                ns_samples.append(temp_ns_samples)
                # save nested sampling samples
                print('saving ns')
                with open("{}/{}/ns_{}_{}_{}.dat".format(ns_path, run_id, i, j, k), "wb") as f:
                    temp_ns_samples.tofile(f)
            else:
                fn = '{}/{}/ns_{}_{}_{}.dat'.format(ns_path, run_id, i, j, k)
                print('loading ns file {}'.format(fn))
                with open("{}/{}/ns_{}_{}_{}.dat".format(ns_path, run_id, i, j, k), "rb") as f:
                    ns_samples.append(np.fromfile(f).reshape(-1, n_par).astype(np.float64))
    j += 1

# restructure the ns samples
k = 0
min_len = 1e16
for i in range(n_test):
    for j in n_max_list:
        for b in range(2):
            if ns_samples[k].shape[0] < min_len:
                min_len = ns_samples[k].shape[0]
            k += 1
ns_min = min(min_len, n_post)
temp_ns_samples = np.zeros((n_test, len(n_max_list), ns_min, 2, n_par))
k = 0
for i in range(n_test):
    c = 0
    for j in n_max_list:
        for b in range(2):
            idx = np.random.choice(ns_samples[k].shape[0], size=ns_min)
            temp_ns_samples[i, c, :, b, :] = ns_samples[k][idx, :]
            k += 1
        c += 1   # increment the n_max index
ns_samples = torch.from_numpy(temp_ns_samples).flatten(0, 3).reshape(n_test, len(n_max_list), ns_min, 2, n_par).cpu().numpy()

flow_pars = []
cc_prior_pars = []   #cc_prior_model[0].parameters()
for i in range(multiflow_n_max):
    flow_pars = chain(flow_pars, flow[i].parameters())
    cc_prior_pars = chain(cc_prior_pars, cc_prior_model[i].parameters())
all_pars = chain(flow_pars, cc_prior_pars, cc_meas_model.parameters())
optimiser = torch.optim.Adam(all_pars, lr=lr)
#optimiser = torch.optim.Adam(chain(flow.parameters(),cc_prior_model.parameters(),cc_meas_model.parameters()),lr=lr)

n_loss_avg = 1000
loss = dict(train=[], val=[])
train_loss_smooth = []
val_loss_smooth = []
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, iterations, eta_min=1e-6, last_epoch=-1)

current_n_max = 1
sub_batch = np.zeros(n_max+1)
sub_batch[0] = batch_size
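# Training loop overview (comment only): each iteration draws fresh training data for
# every active stage k (1..current_n_max segments), conditions the flow on the
# compressed k'th measurement, the compressed prior from stage k-1 and the segment
# index, and minimises the negative conditional log-likelihood. The per-stage terms
# are weighted by sub_batch[k]/batch_size so the total effective batch stays close to
# batch_size while the curriculum (updated at the end of the loop) gradually switches
# on later stages.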
for i in range(iterations+1):

    # perform inference (while training) on the final measurement
    data_train_tensor = torch.empty(0, n_par).to(device)                          # <--- initialize
    train_cond = torch.empty(0, n_meas_cc_out+n_prior_cc_out+1).to(device)        # <--- initialize
    #orig_data_train_tensor, _, dist_train_tensor, priors_train_tensor, _ = make_data(n_data=batch_size,n_max=current_n_max,n_prior=n_prior,flow=flow,cc_prior_model=cc_prior_model,cc_meas_model=cc_meas_model)
    #start_idx = 0
    optimiser.zero_grad()
    _loss = 0.0
    for k in range(current_n_max):
        kk = k if multiflow else 0
        #end_idx = start_idx + int(sub_batch[k])
        # all networks are set to eval mode inside the data generation - except for the measurement compression
        temp_data_train_tensor, _, temp_dist_train_tensor, temp_priors_train_tensor, _, _ = make_data(n_data=int(sub_batch[k]), n_max=k+1, n_prior=n_prior, flow=flow, cc_prior_model=cc_prior_model, cc_meas_model=cc_meas_model)
        #temp_data_train_tensor = orig_data_train_tensor[start_idx:end_idx,:].detach()
        #temp_dist_train_tensor = dist_train_tensor[start_idx:end_idx,:,:k+1]
        #temp_priors_train_tensor = priors_train_tensor[start_idx:end_idx,:,:,:k+1].detach()
        temp_data_train_tensor = temp_data_train_tensor.detach()
        temp_priors_train_tensor = temp_priors_train_tensor.detach()

        cc_prior_model[kk].train()
        compressed_prior = cc_prior_model[kk](temp_priors_train_tensor[:, :, :, -1])   #.flatten(1,2)) # take the last event prior
        temp_train_cond = torch.cat((temp_dist_train_tensor[:, :, -1], compressed_prior, (k/float(glob_n_max))*torch.ones(size=(int(sub_batch[k]), 1)).to(device)), dim=1).to(device)
        data_train_tensor = torch.cat((data_train_tensor, temp_data_train_tensor[:, :n_par]), dim=0)
        train_cond = torch.cat((train_cond, temp_train_cond), dim=0)
        #start_idx += int(sub_batch[k])

        flow[kk].train()
        #optimiser.zero_grad()
        _loss -= flow[kk].log_prob(temp_data_train_tensor[:, :n_par], conditional=temp_train_cond).mean()*(sub_batch[k]/float(batch_size))   # compute the loss

    _loss.backward()
    optimiser.step()
    scheduler.step()

    train_loss = _loss.mean().item()
    loss["train"].append(train_loss)
    train_loss_smooth.append(np.median(loss["train"][max(i-n_loss_avg, 0):]))
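    # The accumulated _loss above is a sub-batch-weighted sum of
    # -E[log p(theta | segment_k, prior_k, k)] over the active stages; the validation
    # pass below repeats the same computation with the networks in eval mode and
    # gradients disabled. (Explanatory comment only.)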
    # validation
    if not i % 100:

        data_val_tensor = torch.empty(0, n_par).to(device)                         # <--- initialize
        val_cond = torch.empty(0, n_meas_cc_out+n_prior_cc_out+1).to(device)       # <--- initialize
        #orig_data_val_tensor, _, dist_val_tensor, priors_val_tensor, _ = make_data(n_data=batch_size,n_max=current_n_max,n_prior=n_prior,flow=flow,cc_prior_model=cc_prior_model,cc_meas_model=cc_meas_model,valtest=True)
        #start_idx = 0
        _loss = 0.0
        for k in range(current_n_max):
            kk = k if multiflow else 0
            #end_idx = start_idx + int(sub_batch[k])
            temp_data_val_tensor, _, temp_dist_val_tensor, temp_priors_val_tensor, _, _ = make_data(n_data=int(sub_batch[k]), n_max=k+1, n_prior=n_prior, flow=flow, cc_prior_model=cc_prior_model, cc_meas_model=cc_meas_model, valtest=True)
            #temp_data_val_tensor = orig_data_val_tensor[start_idx:end_idx,:].detach()
            #temp_dist_val_tensor = dist_val_tensor[start_idx:end_idx,:,:k+1].detach()
            #temp_priors_val_tensor = priors_val_tensor[start_idx:end_idx,:,:,:k+1].detach()
            temp_data_val_tensor = temp_data_val_tensor.detach()
            temp_dist_val_tensor = temp_dist_val_tensor.detach()
            temp_priors_val_tensor = temp_priors_val_tensor.detach()

            cc_prior_model[kk].eval()
            with torch.no_grad():
                compressed_prior = cc_prior_model[kk](temp_priors_val_tensor[:, :, :, -1])   #.flatten(1,2)) # take the last event prior
            temp_val_cond = torch.cat((temp_dist_val_tensor[:, :, -1], compressed_prior, (k/float(glob_n_max))*torch.ones(size=(int(sub_batch[k]), 1)).to(device)), dim=1).to(device)
            data_val_tensor = torch.cat((data_val_tensor, temp_data_val_tensor[:, :n_par]), dim=0)
            val_cond = torch.cat((val_cond, temp_val_cond), dim=0)
            #start_idx += int(sub_batch[k])

            flow[kk].eval()
            with torch.no_grad():
                _loss -= flow[kk].log_prob(temp_data_val_tensor[:, :n_par], conditional=temp_val_cond).mean()*(sub_batch[k]/float(batch_size))   # .item()

        val_loss = _loss.item()
        loss["val"].append(val_loss)
        val_loss_smooth.append(val_loss)

        if np.isnan(loss["val"][-1]):
            print('validation loss is nan: i={}'.format(i))
            print('sub_batch is {}'.format(sub_batch))
            #print('last_idx is {}'.format(last_idx))
            print('cond_val any nans = {}'.format(torch.isnan(val_cond).any()))
            print('y_val any nans = {}'.format(torch.isnan(data_val_tensor).any()))
            #print('x_val any nans = {}'.format(torch.isnan(x_val).any()))
            print('x_comp_val any nans = {}'.format(torch.isnan(temp_dist_val_tensor).any()))
            print('y_prior_val any nans = {}'.format(torch.isnan(temp_priors_val_tensor).any()))
            print('compressed_prior any nans = {}'.format(torch.isnan(compressed_prior).any()))

    if not i % plot_step:
        my_lr = scheduler.get_last_lr()[0]
        print(f"Epoch {i} - train: {loss['train'][-1]:.3f}, val: {loss['val'][-1]:.3f}, lr: {my_lr:.3e}, n_max: {current_n_max}, sub_batch: {sub_batch}")

        for a in range(n_par):
            for b in range(a+1, n_par):
                fig1, ax1 = plt.subplots(n_test, len(n_max_list)+1, figsize=(4*(len(n_max_list)+1), 4*n_test), dpi=100)
                #fig2, bx2 = plt.subplots(n_test,n_max+1, figsize=(4*(n_max+1),4*n_test), dpi=100)
                loss_fig, loss_ax = plt.subplots(1, 1, figsize=(8, 8), dpi=100)
                j = 0
                for x, d, nd in zip(data_test_tensor, d_test_tensor, test_n_sig.cpu().numpy().astype(int)):
                    _, _, _, _, samples, _ = make_data(1, n_prior=n_prior, n_max=nd, n_post=n_post, flow=flow, cc_prior_model=cc_prior_model, cc_meas_model=cc_meas_model, meas=d.reshape(1, n_meas, n_max), valtest=True)

                    # reorder observations
                    #reord_old_samples = torch.empty(1,0,n_par,n_max).to(device)
                    reord_samples = torch.empty(1, 0, n_par, n_max).to(device)
                    for _ in range(Nreorder):
                        idx = torch.randperm(n_max)
                        _, _, _, _, temp_samples, _ = make_data(1, n_prior=n_prior, n_max=nd, n_post=int(n_post/float(Nreorder)), flow=flow, cc_prior_model=cc_prior_model, cc_meas_model=cc_meas_model, meas=d.reshape(1, n_meas, n_max)[:, :, idx], valtest=True)
                        #reord_old_samples = torch.cat((reord_old_samples,temp_old_samples),dim=1)
                        reord_samples = torch.cat((reord_samples, temp_samples), dim=1)

                    # reshape and plot
                    signal_model.scatter_plot(ax1[j], samples, ns_samples[j], reord_samples, x, n_max_list, n_prior, n_post, n_par)
                    #signal_model.scatter_plot(bx2[j],samples,ns_samples[j],reord_samples,x,nd,n_prior,n_post,n_par,change_vars=True)
                    j += 1

                fig1.savefig('{}/training_{}_{}{}.png'.format(plot_path, i, a, b))
                #fig2.savefig('{}/training_cv_{}_{}{}.png'.format(plot_path,i,a,b))

                loss_ax.semilogx(train_loss_smooth, alpha=0.5, label="Train")
                loss_ax.semilogx(np.arange(len(val_loss_smooth))*100, val_loss_smooth, alpha=0.5, label="Val.")
                #loss_ax.semilogx(val_loss_smooth, alpha=0.5, label="Val.")
                loss_ax.set_ylim(np.min(train_loss_smooth)-0.1, np.percentile(np.array(train_loss_smooth), 90))
                loss_ax.set_xlim([1000, iterations])
                loss_ax.set_xlabel("Epoch")
                loss_ax.set_ylabel("Loss")
                loss_ax.legend()
                loss_ax.grid('on')
                loss_fig.savefig('{}/loss.png'.format(plot_path))
                plt.close('all')
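    # Curriculum update (explanatory comment only): sub_idx ramps from ~0 to 4*n_max over
    # training, frac is a decreasing ramp over the stage index, and sub_batch = -2*diff(frac)
    # converts it into per-stage batch sizes that sum to roughly batch_size. current_n_max is
    # the first stage with zero allocation, so stages switch on one at a time. For example,
    # with batch_size=1024, n_max=8 and sub_idx=2 the allocation is [512, 512, 0, ...] (plus
    # the appended trailing 0), giving current_n_max=2. After iterations/4 the allocation is
    # frozen at equal shares of batch_size/n_max.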
    frac = np.zeros(n_max+1)
    sub_idx = (i+1)*(4*n_max)/float(iterations)
    for q in range(n_max+1):
        frac[q] = int(0.5*batch_size*max(((sub_idx-q)/sub_idx), 0))
    sub_batch = -2*(np.diff(frac))
    sub_batch = np.append(sub_batch, 0)
    old_n_max = current_n_max
    current_n_max = int(np.argwhere(sub_batch == 0)[0])
    if old_n_max != current_n_max:
        print('updated current nmax to {} and sub batch size to {}'.format(current_n_max, sub_batch))
    if i > int(iterations/4):
        sub_batch = (batch_size/n_max)*np.ones(n_max)

    # if starting training another flow then copy the weights across from the previous one
    if current_n_max > old_n_max and multiflow:
        print('updated current_n_max from {} -> {}'.format(old_n_max, current_n_max))
        start = old_n_max - 1
        end = current_n_max - 1
        flow[end].load_state_dict(flow[start].state_dict())
        cc_prior_model[end].load_state_dict(cc_prior_model[start].state_dict())
        print('copied model {} to model {}'.format(start, end))
        old_n_max = current_n_max

    #if i>0 and not i % 50000 and current_n_max