from glasflow import RealNVP import matplotlib.pyplot as plt import matplotlib.patches as patches import numpy as np import torch import torch.nn as nn from itertools import chain, permutations from datetime import datetime import pickle from scipy.interpolate import RectBivariateSpline from scipy.integrate import quad import time import shutil import seaborn as sns import os from scipy.stats import norm from nessai.flowsampler import FlowSampler from nessai.model import Model from nessai.utils import setup_logger import corner seed = 2 torch.manual_seed(seed) np.random.seed(seed) sns.set_context("notebook") sns.set_palette("colorblind") device = "cuda" date_id = datetime.today().isoformat() plot_path = '/data/www.astro/chrism/uroboros/{}'.format(date_id) ns_path = './' try: os.mkdir('{}'.format(plot_path)) except: print('unable to make output directory {}'.format(plot_path)) exit(1) shutil.copyfile('./uroboros.py', '{}/uroboros.txt'.format(plot_path)) n_test = 4 # number of individual test data samples n_prior = 250 # number of samples used to represent the conditonal prior n_max = 5 # max the number of measurements per sample n_cos = 2 # the number of cosmological params n_cc_out = 16 # the size of compressed prior n_sigma = 0.025 # the scale of the distance noise n_post = 3000 # the number of posterior samples used for plotting lr = 1e-3 # the learning rate run_id = 'new_s{}_nmax{}_ntest{}'.format(seed,n_max,n_test) # define the Flow that will estimate the cosmological parameters conditional on new data and compressed prior information flow = RealNVP( n_inputs=n_cos, # number of cosmological params n_transforms=5, n_conditional_inputs=n_cc_out+1, # size of compressed prior plus size of measurement n_neurons=32, batch_norm_between_transforms=True, ) flow.to(device) print(f"Created flow and sent to {device}") # the compression model that takes samples from the prior and compresses them cc_model = nn.Sequential( nn.Linear((n_cos+1)*n_prior, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, n_cc_out), nn.Sigmoid() ) cc_model.to(device) def dL_slow(Om,z,H0=1.0): """ returns the luminosity distance given the Hubble constant, the matter energy density and the redshift This is a basic expansion of the true relation for low redshift. """ f = lambda zp: (Om*(1+zp)**3 + (1.0 - Om))**(-0.5) y, err = quad(f,0,z) dL = y/H0 return dL def dL_fast(H0,Om,z,LUT=None): if LUT is None: values = np.zeros((Om.size, z.size)) for i,a in enumerate(Om): for j,b in enumerate(z): values[i,j] = dL_slow(a,b) LUT = RectBivariateSpline(Om,z,values) return LUT else: return LUT.ev(Om,z)/H0 # TEST LUT = dL_fast(1.0,np.linspace(0,1,100),np.linspace(0,3,100)) zvec = np.linspace(0,3,100) plt.figure() plt.plot(zvec,dL_fast(0.1,0.3,zvec,LUT=LUT)) plt.plot(zvec,dL_fast(0.2,0.6,zvec,LUT=LUT)) plt.plot(zvec,dL_fast(0.7,0.1,zvec,LUT=LUT)) plt.savefig('{}/test.png'.format(plot_path)) plt.close() def dL(H0,Om,z): """ returns the luminosity distance given the Hubble constant, the matter energy density and the redshift This is a basic expansion of the true relation for low redshift. """ q0 = 1.0 - 0.5*Om dL = (1.0/H0)*(z + 0.5*(1.0-q0)*z**2) return dL class GaussianModel(Model): """A simple two-dimensional Gaussian likelihood.""" def __init__(self,d): # Names of parameters to sample self.nmeas = d.size self.dvec = d self.LUT = dL_fast(1.0,np.linspace(0,1,100),np.linspace(0,3,100)) print('dvec {}'.format(self.dvec)) self.names = ["H0", "Om"] self.bounds = {"H0": [0.2, 1], "Om": [0, 1]} for i in range(self.nmeas): self.names.append('z{}'.format(i)) self.bounds['z{}'.format(i)] = [0,3] def log_prior(self, x): """ Returns log of prior given a live point assuming uniform priors on each parameter. """ # Check if values are in bounds, returns True/False # Then take the log to get 0/-inf and make sure the dtype is float log_p = np.log(self.in_bounds(x), dtype="float") # Iterate through each parameter (x and y) # since the live points are a structured array we can # get each value using just the name #for n in self.names: log_p -= np.log(self.bounds["H0"][1] - self.bounds["H0"][0]) # uniform prior log_p -= np.log(self.bounds["Om"][1] - self.bounds["Om"][0]) # uniform prior for k in self.names[2:]: # a prior for each measurement log_p += np.log(3.0) - 3.0*np.log(self.bounds[k][1]) + 2.0*np.log(x[k]) # p(z) ~ z^2 prior return log_p def log_likelihood(self, x): """ Returns log likelihood of given live point assuming a Gaussian likelihood. """ log_l = 0.0 for d,z in zip(self.dvec,self.names[2:]): d0 = dL_fast(x["H0"],x["Om"],x[z],LUT=self.LUT) log_l += norm.logpdf(d0,loc=d,scale=n_sigma) return log_l def make_data(n_data,n_prior=100,n_max=1,n_post=None,flow=None,cc_model=None,meas=None,LUT=None): """ function to make training data n_data = number of training samples to make n_prior = the number of samples to use from the prior before compression n_post = the number of posterior samples n_max = the max number of measurements per training data sample meas = the distance measurements """ Omega = None z = None # generate measured data if none has been supplied if meas is None: H0 = (0.2 + (0.8*torch.rand(size=(n_data,1)))).to(device) # the true H0 value (bs,1) Om = (torch.rand(size=(n_data,1))).to(device) # the true Om value (bs,1) Omega = (torch.concatenate((H0,Om),axis=1)).to(device) z = 3*(torch.rand(size=(n_data,n_max))**(1.0/3.0)).to(device) # redshift (bs,n_max) n = (n_sigma*torch.randn(size=(n_data,n_max))).to(device) # noise on distance (bs,n_max) meas = dL_fast(torch.tile(H0,(1,n_max)).flatten().cpu(),torch.tile(Om,(1,n_max)).flatten().cpu(),z.flatten().cpu(),LUT=LUT).reshape(n_data,n_max).to(device) + n # distance (bs,n_max) meas = meas.reshape(n_data,n_max).to(device) meas = meas.type(torch.cuda.FloatTensor) # initialise the prior samples tensor # there should be a small set of prior samples (and log probs) for each measurement and for each sample prior_label = torch.zeros(n_data,n_prior,n_cos+1,n_max).to(device) # initialise the prior label tensor #post_label = None #if n_post is not None: #post_label = torch.zeros(n_data,n_post,n_cos+1,n_max).to(device) # initialise the posterior label tensor # PRIOR 1 - for 1st signals sample from uniform prior prior_label[:,:,0,0] = (0.2 + 0.8*torch.rand(size=(n_data,n_prior))).to(device) prior_label[:,:,1,0] = torch.rand(size=(n_data,n_prior)).to(device) prior_label[:,:,2,0] = -1.0*np.log(0.8)*torch.ones(size=(n_data,n_prior)).to(device) flow.eval() for i in range(n_max-1): c_prior = cc_model(prior_label[:,:,:,i].flatten(1,2)).detach() test_cond = torch.cat((meas[:,i].reshape(n_data,1).detach(),c_prior),dim=1).to(device) test_cond = test_cond.tile(n_prior,1).to(device) # has shape (n_data*n_prior,n_cc+n_meas) with torch.no_grad(): prior_samples = flow.sample(n_data*n_prior,conditional=test_cond).to(device) # output shape should be (n_data*n_prior,n_cos) prior_logprob = flow.log_prob(prior_samples,conditional=test_cond).to(device) # output shape should be (n_data*n_prior) prior_label[:,:,0,i+1] = prior_samples[:,0].reshape(n_prior,n_data).transpose(1,0).to(device) prior_label[:,:,1,i+1] = prior_samples[:,1].reshape(n_prior,n_data).transpose(1,0).to(device) prior_label[:,:,2,i+1] = prior_logprob.reshape(n_prior,n_data).transpose(1,0).to(device) post_label = None if n_post is not None: post_label = torch.zeros(n_data,n_post,n_cos+1,n_max).to(device) # initialise the posterior label tensor for i in range(n_max): # POST 1 - compress the uniform prior and add the 1st meas condition c_prior = cc_model(prior_label[:,:,:,i].flatten(1,2)).detach() test_cond = torch.cat((meas[:,i].reshape(n_data,1).detach(),c_prior),dim=1).to(device) test_cond = test_cond.tile(n_post,1).to(device) # has shape (n_data*n_prior,n_cc+n_meas) with torch.no_grad(): prior_samples = flow.sample(n_data*n_post,conditional=test_cond).to(device) # output shape should be (n_data*n_post,n_cos) prior_logprob = flow.log_prob(prior_samples,conditional=test_cond).to(device) # output shape should be (n_data*n_post) post_label[:,:,0,i] = prior_samples[:,0].reshape(n_post,n_data).transpose(1,0).to(device) post_label[:,:,1,i] = prior_samples[:,1].reshape(n_post,n_data).transpose(1,0).to(device) post_label[:,:,2,i] = prior_logprob.reshape(n_post,n_data).transpose(1,0).to(device) return Omega, meas, prior_label, z, post_label # save or load the true params and measurements try: os.mkdir('{}/{}'.format(ns_path,run_id)) except: pass parfile = '{}/{}/testdata_{}.dat'.format(ns_path,run_id,run_id) if os.path.isfile(parfile)==False: # generate test data LUT = dL_fast(1.0,np.linspace(0,1,100),np.linspace(0,3,100)) data_test_tensor, d_test_tensor, prior_test_tensor, test_z, _ = make_data(n_data=n_test,n_max=n_max,n_prior=n_prior,flow=flow,cc_model=cc_model,LUT=LUT) test_n_sig = torch.ones(n_test)*n_max with open(parfile, 'wb') as f: # Python 3: open(..., 'wb') pickle.dump([data_test_tensor.cpu().numpy(), d_test_tensor.cpu().numpy(), test_z.cpu().numpy()], f) print('saved par file {}'.format(parfile)) else: with open(parfile, 'rb') as f: # Python 3: open(..., 'rb') data_test_tensor, d_test_tensor, test_z = pickle.load(f) data_test_tensor = torch.tensor(data_test_tensor).to(device) d_test_tensor = torch.tensor(d_test_tensor).to(device) test_z = torch.tensor(test_z).to(device) test_n_sig = torch.ones(n_test)*n_max print('read in par file {}'.format(parfile)) # The FlowSampler object is used to managed the sampling. Keyword arguments # are passed to the nested sampling. j = 0 ns_samples = [] for x,d in zip(data_test_tensor.cpu().numpy(),d_test_tensor.cpu().numpy()): for i in range(n_max): if os.path.isfile("{}/{}/ns_{}_{}_{}.dat".format(ns_path,run_id,run_id,i,j))==False: print('starting ns') output = "{}/{}/ns_{}_{}_{}/".format(ns_path,run_id,run_id,i,j) logger = setup_logger(output=output) print(n_max,d[:i+1]) fs = FlowSampler(GaussianModel(d[:i+1]), output=output, resume=False, seed=1234) fs.run() temp_ns_samples = [] for s in fs.posterior_samples: temp_ns_samples.append([s[q] for q in np.arange(3)]) temp_ns_samples = np.array(temp_ns_samples,dtype=np.float64) ns_samples.append(temp_ns_samples) print(temp_ns_samples.shape) corner.corner(temp_ns_samples[:,:2],labels=["H0","Om"],truths=[x[0],x[1]]) plt.show() # save nested sampling samples print('saving ns') with open("{}/{}/ns_{}_{}_{}.dat".format(ns_path,run_id,run_id,i,j), "wb") as f: temp_ns_samples.tofile(f) else: fn = '{}/{}/ns_{}_{}_{}.dat'.format(ns_path,run_id,run_id,i,j) print('loading ns file {}'.format(fn)) with open("{}/{}/ns_{}_{}_{}.dat".format(ns_path,run_id,run_id,i,j), "rb") as f: ns_samples.append(np.fromfile(f).reshape(-1,n_cos+1).astype(np.float64)) j += 1 # restructure the ns samples k = 0 min_len = 1e16 for i in range(n_test): for j in range(n_max): if ns_samples[k].shape[0] < min_len: min_len = ns_samples[k].shape[0] k += 1 ns_min = min(min_len,n_post) temp_ns_samples = np.zeros((n_test,n_max,ns_min,n_cos+1)) k = 0 for i in range(n_test): for j in range(n_max): temp_ns_samples[i,j,:,:] = ns_samples[k][:ns_min,:] k += 1 ns_samples = temp_ns_samples print(ns_samples.shape) optimiser = torch.optim.Adam(chain(flow.parameters(),cc_model.parameters()),lr=lr) LUT = dL_fast(1.0,np.linspace(0,1,100),np.linspace(0,3,100)) iterations = 10000 batch_size = 1024 n_loss_avg = 1000 loss = dict(train=[], val=[]) train_loss_smooth = [] val_loss_smooth = [] scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, iterations, eta_min=1e-6, last_epoch=-1) current_n_max = 1 sub_batch = batch_size for i in range(iterations+1): flow.eval() cc_model.eval() data_train_tensor, dist_train_tensor, priors_train_tensor, _, _ = make_data(n_data=sub_batch,n_max=1,n_prior=n_prior,flow=flow,cc_model=cc_model,LUT=LUT) train_cond = torch.cat((dist_train_tensor[:,-1].reshape(-1,1),cc_model(priors_train_tensor[:,:,:,-1].flatten(1,2))),dim=1).to(device) for k in range(2,current_n_max): temp_data_train_tensor, temp_dist_train_tensor, temp_priors_train_tensor, _, _ = make_data(n_data=sub_batch,n_max=k,n_prior=n_prior,flow=flow,cc_model=cc_model,LUT=LUT) temp_train_cond = torch.cat((temp_dist_train_tensor[:,-1].reshape(-1,1),cc_model(temp_priors_train_tensor[:,:,:,-1].flatten(1,2))),dim=1).to(device) data_train_tensor = torch.cat((data_train_tensor,temp_data_train_tensor),dim=0) train_cond = torch.cat((train_cond,temp_train_cond),dim=0) flow.train() cc_model.train() optimiser.zero_grad() _loss = -flow.log_prob(data_train_tensor, conditional=train_cond).mean() # compute the loss _loss.backward() optimiser.step() scheduler.step() train_loss = _loss.item() loss["train"].append(train_loss) train_loss_smooth.append(np.median(loss["train"][max(i-n_loss_avg,0):])) flow.eval() cc_model.eval() with torch.no_grad(): data_val_tensor, dist_val_tensor, priors_val_tensor, _, _ = make_data(n_data=sub_batch,n_max=1,n_prior=n_prior,flow=flow,cc_model=cc_model,LUT=LUT) val_cond = torch.cat((dist_val_tensor[:,-1].reshape(-1,1),cc_model(priors_val_tensor[:,:,:,-1].flatten(1,2))),dim=1).to(device) for k in range(2,current_n_max): with torch.no_grad(): temp_data_val_tensor, temp_dist_val_tensor, temp_priors_val_tensor, _, _ = make_data(n_data=sub_batch,n_max=k,n_prior=n_prior,flow=flow,cc_model=cc_model,LUT=LUT) temp_val_cond = torch.cat((temp_dist_val_tensor[:,-1].reshape(-1,1),cc_model(temp_priors_val_tensor[:,:,:,-1].flatten(1,2))),dim=1).to(device) data_val_tensor = torch.cat((data_val_tensor,temp_data_val_tensor),dim=0) val_cond = torch.cat((val_cond,temp_val_cond),dim=0) with torch.no_grad(): _loss = -flow.log_prob(data_val_tensor, conditional=val_cond).mean().item() val_loss = _loss loss["val"].append(val_loss) val_loss_smooth.append(np.median(loss["val"][max(i-n_loss_avg,0):])) if not i % 500 and i>0: my_lr = scheduler.get_last_lr()[0] print(f"Epoch {i} - train: {loss['train'][-1]:.3f}, val: {loss['val'][-1]:.3f}, lr: {my_lr:.3e}, n_max: {current_n_max}, sub_batch: {sub_batch}") fig, ax = plt.subplots(n_test,n_max+1, figsize=(3*n_max, 16), dpi=100) loss_fig, loss_ax = plt.subplots(1, 1, figsize=(8, 8), dpi=100) j = 0 print(d_test_tensor.device) for x,d,nd in zip(data_test_tensor,d_test_tensor,test_n_sig.cpu().numpy().astype(int)): _, _, old_samples, _, samples = make_data(1,n_prior=n_prior,n_max=nd,n_post=n_post,flow=flow,cc_model=cc_model,meas=d.reshape(1,-1),LUT=LUT) # reshape and plot samples = torch.permute(samples,(0,3,1,2))[0,:,:,:].cpu().numpy().reshape(nd,n_post,n_cos+1) old_samples = torch.permute(old_samples,(0,3,1,2))[0,:,:,:].cpu().numpy().reshape(nd,n_prior,n_cos+1) x = x.cpu().numpy() ax[j,0].plot(old_samples[0,:,0],old_samples[0,:,1],'xc',markersize=10) for k in range(1,nd): prior_s = old_samples[k,:,:].reshape(n_prior,n_cos+1) post_s = samples[k-1,:,:].reshape(n_post,n_cos+1) ax[j,k].plot(post_s[:,0],post_s[:,1],'.b',markersize=1) ax[j,k].plot(ns_samples[j,k-1,:,0],ns_samples[j,k-1,:,1],'.r',markersize=1) ax[j,k].set_xlim([0.2,1.0]) ax[j,k].set_ylim([0.0,1.0]) ax[j,k].plot(x[0],x[1],'xk',label='truth',markersize=10) ax[j,nd].plot(x[0],x[1],'xk',label='truth') ax[j,nd].set_xlim([0.2,1.0]) ax[j,nd].set_ylim([0.0,1.0]) ax[j,nd].legend(loc='upper right') # loop over multiple different orders of the measurements #if n_max<=4: # it = permutations(d) #else: it = iter([d[np.random.permutation(n_max)] for i in range(50)]) new_samples = [] new_n_post = n_post // 50 for new_d in it: _, _, _, _, temp_samples = make_data(1,n_prior=n_prior,n_max=nd,n_post=new_n_post,flow=flow,cc_model=cc_model,meas=new_d.reshape(1,-1),LUT=LUT) new_samples.append(torch.permute(temp_samples,(0,3,1,2))[0,-1,:,:].cpu().numpy().reshape(new_n_post,n_cos+1)) new_samples = np.array(new_samples).reshape(-1,n_cos+1) ax[j,nd].plot(samples[-1,:,0],samples[-1,:,1],'.b',markersize=1) ax[j,nd].plot(new_samples[:,0],new_samples[:,1],'.g',markersize=1) ax[j,nd].plot(ns_samples[j,-1,:,0],ns_samples[j,-1,:,1],'.r',markersize=1) kw = {"plot_datapoints": False, "plot_density": False, "levels": [0.5,0.99]} ax2 = corner.corner(ns_samples[j,-1,:,:2],bins=20,labels=["H0","Om"],truths=[x[0],x[1]],color='r', quantiles=None,**kw) #,'levels=[0.5,0.9]','plot_datapoints=False','smooth=0.02'}) corner.corner(samples[nd-1,:,:2],bins=20,fig=ax2,color='b',**kw) corner.corner(new_samples[:,:2],bins=20,fig=ax2,color='g',**kw) ax2.savefig('{}/training_corner_{}_{}.png'.format(plot_path,i,j)) j += 1 fig.savefig('{}/training_{}.png'.format(plot_path,i)) loss_ax.semilogx(train_loss_smooth, alpha=0.5, label="Train") loss_ax.semilogx(val_loss_smooth, alpha=0.5, label="Val.") loss_ax.set_ylim(-3, 0) loss_ax.set_xlabel("Epoch") loss_ax.set_ylabel("Loss") loss_ax.legend() loss_ax.grid('on') loss_fig.savefig('{}/loss.png'.format(plot_path)) plt.close('all') if i>0 and not i % 500 and current_n_max<=n_max: current_n_max += 1 sub_batch = batch_size // current_n_max print('updated current nmax to {} and sub batch size to {}'.format(current_n_max,sub_batch)) #if i>500 and i<1000: # current_n_max = 2 #elif i>=1000 and i<1500: # current_n_max = 3 #elif i>=1500: # current_n_max = 4 flow.eval() print("Finished training")