from glasflow import RealNVP, CouplingNSF
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Uniform
from itertools import chain, permutations
from datetime import datetime
import pickle
from scipy.interpolate import RectBivariateSpline
from scipy.integrate import quad
from scipy.special import logsumexp
import time
import shutil
import seaborn as sns
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
from scipy.stats import uniform, norm, chi
from scipy.special import expit, log_wright_bessel
from nessai.flowsampler import FlowSampler
from nessai.model import Model
from nessai.utils import setup_logger
from nessai.livepoint import dict_to_live_points
import corner

from uroboros_network import MAB, SAB, ISAB, PMA, SmallSetTransformer, PI_NeuralNetwork, data_NeuralNetwork
from uroboros_utils import calc_median_error, calculate_js
from uroboros_signal import cw_model
from uroboros_core import make_data, make_data_pretrain, sample_from_gmm_diagonal

#torch.autograd.set_detect_anomaly(True)
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

import configparser
config = configparser.ConfigParser()
config.read('/home/chrism/uroboros/config.ini')

# testdata
run_id = config['general']['run_id']
ns_path = config['testdata']['ns_path']
values_str = config['testdata']['n_max_list']
n_max_list = [int(x.strip()) for x in values_str.split(',')]
n_par = config.getint('testdata', 'n_par')
n_meas = config.getint('general', 'n_meas')
n_sig = config.getfloat('testdata', 'n_sig')

# training
pretrain_network_path = config['training']['network_path']   # (overridden by the pretrain value below)

# pretrain
seed = config.getint('pretrain', 'seed')
test_seed = config.getint('testdata', 'seed')
n_test = config.getint('testdata', 'n_test')
n_mix = config.getint('pretrain', 'n_mix')
iterations = config.getint('pretrain', 'iterations')
batch_size = config.getint('training', 'batch_size')
lr = config.getfloat('pretrain', 'lr')               # the learning rate
mode_scale = config.getfloat('pretrain', 'mode_scale')
#ns_path = config['pretrain']['ns_path']
pretrain_network_path = config['pretrain']['network_path']
train_network_path = config['training']['network_path']

# network
n_prior = config.getint('network', 'n_prior')                # number of samples used to represent the conditional prior
n_prior_cc_out = config.getint('network', 'n_prior_cc_out')  # the size of the compressed prior
n_meas_cc_out = config.getint('network', 'n_meas_cc_out')    # the size of the compressed measurement

# postprocessing
Nreorder = config.getint('postprocess', 'Nreorder')

# plotting
plot_step = config.getint('plotting', 'train_plot_step')
n_post = config.getint('plotting', 'n_post')   # the number of posterior samples used for plotting
plot_path = config['plotting']['plot_path']

np.random.seed(seed)
torch.manual_seed(seed)
device = "cuda"
torch.cuda.set_device(0)

n_max = n_max_list[-1]   # the maximum number of measurements per sample
glob_n_max = n_max

test_run_id = '{}_s{}_nmax{}_ntest{}'.format(run_id, test_seed, n_max, n_test)
run_id = '{}_s{}_nmax{}_ntest{}_train'.format(run_id, seed, n_max, n_test)
date_id = datetime.today().isoformat()
plot_path = '{}/{}_train'.format(plot_path, date_id)
try:
    os.mkdir(plot_path)
except OSError:
    print('unable to make output directory {}'.format(plot_path))
    exit(1)
shutil.copyfile('./uroboros_train.py', '{}/uroboros_train.txt'.format(plot_path))
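# For reference, the config.ini consumed above is expected to provide at least
# the sections and keys read in this script. The values shown here are purely
# illustrative placeholders (not taken from any actual run); only the
# section/key names are grounded in the reads above:
#
#   [general]
#   run_id = uroboros
#   n_meas = 32
#
#   [testdata]
#   ns_path = /path/to/ns_results
#   n_max_list = 1, 2, 4, 8
#   n_par = 3
#   n_sig = 1.0
#   seed = 2
#   n_test = 4
#
#   [training]
#   network_path = /path/to/trained_network
#   batch_size = 512
#
#   [pretrain]
#   seed = 1
#   n_mix = 8
#   iterations = 100000
#   lr = 0.001
#   mode_scale = 1.0
#   network_path = /path/to/pretrained_network
#
#   [network]
#   n_prior = 100
#   n_prior_cc_out = 16
#   n_meas_cc_out = 16
#
#   [postprocess]
#   Nreorder = 4
#
#   [plotting]
#   train_plot_step = 1000
#   n_post = 2000
#   plot_path = /path/to/plots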
# define the Flow that will estimate the parameters conditional on
# new data and compressed prior information
flow = CouplingNSF(
    n_inputs=n_par,                                      # number of params
    n_transforms=10,
    n_conditional_inputs=n_prior_cc_out+n_meas_cc_out,   # size of compressed prior plus size of compressed measurement
    n_neurons=128,
    num_bins=10,                                         # for NSF flow only
    n_blocks_per_transform=4,
    batch_norm_within_blocks=True,
    linear_transform='permutation',
    batch_norm_between_transforms=True,
).to(device)
print(f"Created flow and sent to {device}")

cc_prior_model = SmallSetTransformer(n_par, n_prior_cc_out).to(device)
cc_meas_model = data_NeuralNetwork(n_meas, n_meas_cc_out).to(device)

# load pretrained model weights
flow.load_state_dict(torch.load('{}/flow_weights.pth'.format(pretrain_network_path)))
cc_meas_model.load_state_dict(torch.load('{}/cc_meas_model_weights.pth'.format(pretrain_network_path)))
cc_prior_model.load_state_dict(torch.load('{}/cc_prior_model_weights.pth'.format(pretrain_network_path)))
print('loaded in pretrained network parameters')

# load the true test data params and measurements
parfile = '{}/{}/testdata_{}.dat'.format(ns_path, test_run_id, test_run_id)

# open the parfile and read the data and parameters
with open(parfile, 'rb') as f:
    data_test_tensor, d_test_tensor, n_test_tensor = pickle.load(f)
data_test_tensor = torch.tensor(data_test_tensor).to(device)
d_test_tensor = torch.tensor(d_test_tensor).to(device)
n_test_tensor = torch.tensor(n_test_tensor).to(device)
test_n_sig = n_max*torch.ones(n_test)
print('read in par file {}'.format(parfile))
print('read in test params')

# plot test data
print('plotting test data')
signal_model = cw_model(n_meas, n_max)
signal_model.plot_test_data(d_test_tensor, n_test_tensor, plot_path)
print('plotted test data')
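# Illustrative sanity check (not part of the original pipeline): the flow is
# conditioned on the concatenation of the compressed prior summary and the
# compressed measurement summary, which is why n_conditional_inputs is set to
# n_prior_cc_out + n_meas_cc_out above. The _dummy_* tensors below are
# hypothetical stand-ins for the outputs of cc_prior_model and cc_meas_model.
with torch.no_grad():
    _dummy_prior_summary = torch.randn(2, n_prior_cc_out, device=device)
    _dummy_meas_summary = torch.randn(2, n_meas_cc_out, device=device)
    _dummy_conditional = torch.cat([_dummy_prior_summary, _dummy_meas_summary], dim=1)
    assert _dummy_conditional.shape[1] == n_prior_cc_out + n_meas_cc_out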
# load the nested sampling results (posterior samples and log-evidences)
# produced separately for each test case, measurement count, and realisation
ns_samples = []
logZ = []
for j in range(n_test):
    for i in n_max_list:
        for k in range(2):
            fn = '{}/{}/ns_{}_{}_{}.dat'.format(ns_path, test_run_id, i, j, k)
            if os.path.isfile(fn):
                print('loading ns file {}'.format(fn))
                with open(fn, "rb") as f:
                    ns_samples.append(np.fromfile(f).reshape(-1, n_par).astype(np.float64))
                fn = '{}/{}/logZ_{}_{}_{}.dat'.format(ns_path, test_run_id, i, j, k)
                print('loading logZ file {}'.format(fn))
                with open(fn, "rb") as f:
                    logZ.append(np.fromfile(f).reshape(-1, 3).astype(np.float64))
            else:
                ns_samples.append(np.zeros((n_post, n_par)))
                logZ.append(np.zeros((1, 3)))

# restructure the ns samples so that every set has the same length
k = 0
min_len = int(1e16)
for i in range(n_test):
    for j in n_max_list:
        for b in range(2):
            if ns_samples[k].shape[0] < min_len:
                min_len = ns_samples[k].shape[0]
            k += 1
ns_min = min(min_len, n_post)

temp_ns_samples = np.zeros((n_test, len(n_max_list), ns_min, 2, n_par))
ns_logZ = np.zeros((n_test, len(n_max_list), 2, 3))
k = 0
for i in range(n_test):
    c = 0
    for j in n_max_list:
        for b in range(2):
            idx = np.random.choice(ns_samples[k].shape[0], size=ns_min)
            temp_ns_samples[i, c, :, b, :] = ns_samples[k][idx, :]
            ns_logZ[i, c, b, :] = logZ[k]
            k += 1
        c += 1   # increment the n_max index
ns_samples = torch.from_numpy(temp_ns_samples).flatten(0, 3).reshape(n_test, len(n_max_list), ns_min, 2, n_par).cpu().numpy()

all_pars = chain(flow.parameters(), cc_prior_model.parameters(), cc_meas_model.parameters())
optimiser = torch.optim.Adam(all_pars, lr=lr)

n_loss_avg = 1000
loss = dict(train=[], val=[])
train_loss_smooth = []
val_loss_smooth = []
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, iterations, eta_min=1e-6, last_epoch=-1)

#current_n_max = 1
#sub_batch = np.zeros(n_max+1)
#sub_batch[0] = batch_size
#js = np.empty((0,len(n_max_list),n_test,n_par))
ivec = []
#jsfile = '{}/js.dat'.format(plot_path)
logZ = []          # reset: these now hold the per-checkpoint flow evidence estimates
prior_logZ = []
logZ_err = []
prior_logZ_err = []
model = cw_model(n_meas, n_max)

# make dictionary of network stuff
net = {}
net['flow'] = flow
net['cc_prior_model'] = cc_prior_model
net['cc_meas_model'] = cc_meas_model
net['n_meas'] = n_meas
net['n_par'] = n_par
net['n_prior_cc_out'] = n_prior_cc_out
net['n_meas_cc_out'] = n_meas_cc_out

for i in range(iterations+1):

    torch.cuda.synchronize()
    optimiser.zero_grad()

    # make a batch of training data
    x_train, _, y_train, priors_train, _, condition_train = make_data(batch_size, model=model, n_prior=n_prior, n_max=n_max, net=net, device=device)

    # compute loss
    flow.train()
    _loss = -flow.log_prob(x_train[:, :n_par].to(device), conditional=condition_train).to(device).mean()
    _loss.backward()
    optimiser.step()
    scheduler.step()
    train_loss = _loss.item()
    loss["train"].append(train_loss)
    train_loss_smooth.append(np.median(loss["train"][max(i-n_loss_avg, 0):]))

    # validation
    if not i % 100:
        # run the analysis on the validation data
        x_val, _, y_val, priors_val, _, condition_val = make_data(batch_size, model=model, n_prior=n_prior, n_max=n_max, net=net, valtest=True, device=device)
        # compute loss
        flow.eval()
        _loss = -flow.log_prob(x_val[:, :n_par].to(device), conditional=condition_val).to(device).mean()
        val_loss = _loss.item()
        loss["val"].append(val_loss)
        val_loss_smooth.append(val_loss)
        print('epoch {} time {}'.format(i, time.asctime()))
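    # The periodic block below checkpoints the networks and importance-samples
    # the evidence, using the flow posterior q(theta|d) as the proposal:
    #
    #   log Z ~= logsumexp_n[ log pi(theta_n) + log L(theta_n) - log q(theta_n|d) ] - log N,
    #   with theta_n ~ q(theta|d),
    #
    # and assigns a rough error bar from the scatter of the same estimator over
    # 10 random sub-batches of the posterior samples. The prior-based reference
    # value uses log Z ~= logsumexp_n[ log L(theta_n) ] - log N with theta_n ~ pi.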
    if not i % plot_step:   # and i>0:

        my_lr = scheduler.get_last_lr()[0]
        print(f"Epoch {i} - train: {loss['train'][-1]:.3f}, val: {loss['val'][-1]:.3f}, lr: {my_lr:.3e}, n_max: {n_max}")

        # save models
        torch.save(flow.state_dict(), '{}/flow_weights.pth'.format(train_network_path))
        torch.save(cc_meas_model.state_dict(), '{}/cc_meas_model_weights.pth'.format(train_network_path))
        torch.save(cc_prior_model.state_dict(), '{}/cc_prior_model_weights.pth'.format(train_network_path))

        # evidence and JS calculation
        j = 0
        temp_js = np.zeros((1, len(n_max_list), n_test, n_par))
        temp_logZ = np.zeros(n_test)
        temp_logZ_err = np.zeros(n_test)
        temp_prior_logZ = np.zeros(n_test)
        temp_prior_logZ_err = np.zeros(n_test)
        for x, d, nd in zip(data_test_tensor, d_test_tensor, test_n_sig.cpu().numpy().astype(int)):

            temp_model = cw_model(n_meas, nd, d=d.transpose(1, 0).detach().cpu().numpy())
            _, _, _, samples, _, test_cond = make_data(1, model=temp_model, n_prior=n_prior, n_max=nd, n_post=n_post, net=net, meas=d.reshape(1, n_meas, n_max), valtest=True, device=device)

            # loop over n_max and then parameters; the transpose gives shape (n_max, n_par, n_post)
            for idx, s in enumerate(np.transpose(samples.detach().cpu()[0, :, :, np.array(n_max_list, dtype=np.int64)-1], axes=(2, 1, 0))):
                for r, q in enumerate(s):
                    if torch.sum(torch.isnan(q)) == 0:
                        temp_js[0, idx, j, r] = 1e3*calculate_js(q, ns_samples[j, idx, :, 0, r]).median   # ns_samples has shape (n_test, len(n_max_list), ns_min, 2, n_par)

            logpost = 0.0
            if torch.sum(torch.isnan(samples[0, :, :, -1])) == 0:
                logpost = flow.log_prob(samples[0, :, :, -1], conditional=test_cond.tile(n_post, 1)).detach().cpu().numpy()
            x = {}
            for h, nm in enumerate(temp_model.names):
                x["{}".format(nm)] = samples[0, :, h, -1].detach().cpu().numpy()
            logpi = temp_model.log_prior(x)
            logL = temp_model.log_likelihood(x).detach().cpu().numpy()
            temp_logZ[j] = logsumexp(logpi + logL - logpost) - np.log(n_post)
            mini_logZ = np.zeros(10)
            for e in range(10):
                idx = np.random.randint(0, n_post, int(n_post//10))
                mini_logZ[e] = logsumexp(logpi[idx] + logL[idx] - logpost[idx]) - np.log(n_post//10)
            temp_logZ_err[j] = np.sqrt(np.var(mini_logZ)/10)

            # compare with samples drawn just from the prior
            prior_x = temp_model.new_point(N=n_post)
            prior_logL = temp_model.log_likelihood(prior_x).detach().cpu().numpy()
            temp_prior_logZ[j] = logsumexp(prior_logL) - np.log(n_post)
            mini_logZ = np.zeros(10)
            for e in range(10):
                idx = np.random.randint(0, n_post, int(n_post//10))
                mini_logZ[e] = logsumexp(prior_logL[idx]) - np.log(n_post//10)
            temp_prior_logZ_err[j] = np.sqrt(np.var(mini_logZ)/10)
            j += 1

        logZ.append(temp_logZ)
        logZ_err.append(temp_logZ_err)
        prior_logZ.append(temp_prior_logZ)
        prior_logZ_err.append(temp_prior_logZ_err)
        print('log evidence (UROB) = {}'.format(logZ))
        print('log evidence (NESS) = {}'.format(ns_logZ))
        print('log evidence (PRIO) = {}'.format(prior_logZ))

        # plot the Bayes factor
        cw_model(n_meas, glob_n_max).plot_Bayesfactor(ns_logZ[:, -1, 0, :].reshape(n_test, 3),
                                                      np.array(prior_logZ).reshape(-1, n_test).transpose(1, 0),
                                                      np.array(prior_logZ_err).reshape(-1, n_test).transpose(1, 0),
                                                      np.array(logZ).reshape(-1, n_test).transpose(1, 0),
                                                      np.array(logZ_err).reshape(-1, n_test).transpose(1, 0),
                                                      plot_path)
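        # The plotting section below draws, for every parameter pair (a, b), a
        # grid of scatter plots comparing the flow posterior against the nested
        # sampling posterior for each test case and each entry in n_max_list.
        # The Nreorder loop re-runs inference on randomly permuted orderings of
        # the same measurements: since the measurement set carries no intrinsic
        # order, the pooled re-ordered posteriors should agree with the
        # originals if inference is insensitive to measurement ordering.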
        cnt = 0
        for a in range(n_par):
            for b in range(a+1, n_par):

                fig1, ax1 = plt.subplots(n_test, len(n_max_list)+1, figsize=(4*(len(n_max_list)+1), 4*n_test), dpi=100)
                fig2, bx2 = plt.subplots(n_test, len(n_max_list)+1, figsize=(4*(len(n_max_list)+1), 4*n_test), dpi=100)
                j = 0
                for x, d, nd in zip(data_test_tensor, d_test_tensor, test_n_sig.cpu().numpy().astype(int)):

                    print('data shape {}, nd {}'.format(d.shape, nd))
                    _, _, _, samples, _, _ = make_data(1, model=model, n_prior=n_prior, n_max=nd, n_post=n_post, net=net, meas=d.reshape(1, n_meas, nd), valtest=True, device=device)
                    samples = samples[:, :, :, 1:]   # ignore original priors

                    # reorder observations
                    #reord_old_samples = torch.empty(1,0,n_par,n_max).to(device)
                    reord_samples = torch.empty(1, 0, n_par, n_max).to(device)
                    for _ in range(Nreorder):
                        idx = torch.randperm(n_max)
                        _, _, _, temp_samples, _, _ = make_data(1, model=model, n_prior=n_prior, n_max=nd, n_post=int(n_post/float(Nreorder)), net=net, meas=d.reshape(1, n_meas, n_max)[:, :, idx], valtest=True, device=device)
                        #reord_old_samples = torch.cat((reord_old_samples,temp_old_samples),dim=1)
                        print(reord_samples.shape, temp_samples.shape)
                        reord_samples = torch.cat((reord_samples, temp_samples[:, :, :, 1:]), dim=1)

                    # reshape and plot
                    signal_model.scatter_plot(ax1[j], a, b, samples, ns_samples[j], reord_samples, x, n_max_list, n_prior, n_post, n_par, ns_logZ[j, -1, 0, :], 0, 0)
                    signal_model.scatter_plot(bx2[j], a, b, samples, ns_samples[j], reord_samples, x, n_max_list, n_prior, n_post, n_par, ns_logZ[j, -1, 0, :], 0, 0, change_vars=True)
                    j += 1

                fig1.savefig('{}/training_{}_{}{}.png'.format(plot_path, i, a, b))
                fig2.savefig('{}/training_cv_{}_{}{}.png'.format(plot_path, i, a, b))
                cnt += 1

        loss_fig, loss_ax = plt.subplots(1, 1, figsize=(8, 8), dpi=100)
        loss_ax.semilogx(train_loss_smooth, alpha=0.5, label="Train")
        loss_ax.semilogx(np.arange(len(val_loss_smooth))*100, val_loss_smooth, alpha=0.5, label="Val.")
        #loss_ax.semilogx(val_loss_smooth, alpha=0.5, label="Val.")
        loss_ax.set_ylim(np.min(train_loss_smooth)-0.1, np.percentile(np.array(train_loss_smooth), 99))
        loss_ax.set_xlim([1000, iterations])
        loss_ax.set_xlabel("Epoch")
        loss_ax.set_ylabel("Loss")
        loss_ax.legend()
        loss_ax.grid(True)
        loss_fig.savefig('{}/loss.png'.format(plot_path))
        plt.close('all')

#flow.eval()
#cc_prior_model.eval()
#cc_meas_model.eval()
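# To reuse the trained networks elsewhere (e.g. in a standalone test script),
# the checkpoints written above can be reloaded the same way the pretrained
# weights were loaded at the top of this script. A minimal sketch, assuming the
# same network constructors and config values as above:
#
#   flow.load_state_dict(torch.load('{}/flow_weights.pth'.format(train_network_path)))
#   cc_meas_model.load_state_dict(torch.load('{}/cc_meas_model_weights.pth'.format(train_network_path)))
#   cc_prior_model.load_state_dict(torch.load('{}/cc_prior_model_weights.pth'.format(train_network_path)))
#   flow.eval(); cc_prior_model.eval(); cc_meas_model.eval()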