from glasflow import RealNVP, CouplingNSF
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Uniform
from itertools import chain, permutations
from datetime import datetime
import pickle
from scipy.interpolate import RectBivariateSpline
from scipy.integrate import quad
from scipy.special import logsumexp
import time
import shutil
import seaborn as sns
import os
from scipy.stats import uniform, norm, chi
from scipy.special import expit, log_wright_bessel
from nessai.flowsampler import FlowSampler
from nessai.model import Model
from nessai.utils import setup_logger
from nessai.livepoint import dict_to_live_points
import corner

from uroboros_network import MAB, SAB, ISAB, PMA, SmallSetTransformer, PI_NeuralNetwork, data_NeuralNetwork
#from uroboros_utils import calc_median_error, calculate_js
from uroboros_signal import cw_model
from uroboros_core import make_data, make_data_pretrain, sample_from_gmm_diagonal
#torch.autograd.set_detect_anomaly(True)

import configparser
config = configparser.ConfigParser()
config.read('/home/chrism/uroboros/config.ini')

# testdata
run_id = config['testdata']['run_id']
ns_path = config['testdata']['ns_path']
values_str = config['testdata']['n_max_list']
#n_max_list = [int(x.strip()) for x in values_str.split(',')]
n_par = config.getint('testdata', 'n_par')
n_meas = config.getint('testdata', 'n_meas')
n_sig = config.getfloat('testdata', 'n_sig')

# pretrain
seed = config.getint('pretrain', 'seed')
n_test = config.getint('pretrain', 'n_test')
n_mix = config.getint('pretrain', 'n_mix')
iterations = config.getint('pretrain', 'iterations')
batch_size = config.getint('pretrain', 'batch_size')
lr = config.getfloat('pretrain', 'lr')  # the learning rate
mode_scale = config.getfloat('pretrain', 'mode_scale')

# network
n_prior = config.getint('network', 'n_prior')  # number of samples used to represent the conditional prior
n_prior_cc_out = config.getint('network', 'n_prior_cc_out')  # the size of the compressed prior
n_meas_cc_out = config.getint('network', 'n_meas_cc_out')  # the size of the compressed measurement

# postprocessing
Nreorder = config.getint('postprocess', 'Nreorder')

# plotting
plot_step = config.getint('plotting', 'plot_step')
n_post = config.getint('plotting', 'n_post')  # the number of posterior samples used for plotting
plot_path = config['plotting']['plot_path']

np.random.seed(seed)
torch.manual_seed(seed)
device = "cuda"
n_max = 1  # the max number of measurements per sample
#n_max = n_max_list[-1]
#glob_n_max = n_max

run_id = '{}_s{}_ntest{}_pretrain'.format(run_id, seed, n_test)
date_id = datetime.today().isoformat()
plot_path = '{}/{}_pretrain'.format(plot_path, date_id)
try:
    os.mkdir('{}'.format(plot_path))
except OSError:
    print('unable to make output directory {}'.format(plot_path))
    exit(1)
shutil.copyfile('./uroboros_pretrain.py', '{}/uroboros_pretrain.txt'.format(plot_path))

# define the flow that will estimate the parameters conditional on new data and compressed prior information
flow = RealNVP(  #CouplingNSF(
    n_inputs=n_par,  # number of parameters
    n_transforms=6,
    n_conditional_inputs=n_prior_cc_out + n_meas_cc_out,  # size of compressed prior plus size of compressed measurement
    n_neurons=128,
    n_blocks_per_transform=4,
    batch_norm_within_blocks=True,
    linear_transform='permutation',
    batch_norm_between_transforms=True,
).to(device)
print(f"Created flow and sent to {device}")

cc_prior_model = SmallSetTransformer(n_par, n_prior_cc_out).to(device)
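
# Optional sanity check (a minimal sketch added here, not part of the original
# script): confirm that the flow accepts a (batch, n_par) parameter tensor with
# a (batch, n_prior_cc_out + n_meas_cc_out) conditional tensor, mirroring the
# flow.log_prob(..., conditional=...) call used in the training loop below.
# The dummy tensors are illustrative assumptions, not real data.
flow.eval()
with torch.no_grad():
    _theta = torch.randn(4, n_par, device=device)
    _cond = torch.randn(4, n_prior_cc_out + n_meas_cc_out, device=device)
    assert flow.log_prob(_theta, conditional=_cond).shape == (4,)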
cc_meas_model = data_NeuralNetwork(n_meas, n_meas_cc_out).to(device)

# load the true test data params and measurements
parfile = '{}/{}/testdata_{}.dat'.format(ns_path, run_id, run_id)

# open the parfile and read the data and parameters
with open(parfile, 'rb') as f:
    data_test_tensor, d_test_tensor, n_test_tensor, means, variances, weights, flag = pickle.load(f)
data_test_tensor = torch.tensor(data_test_tensor).to(device)
d_test_tensor = torch.tensor(d_test_tensor).to(device)
print(d_test_tensor.shape)
for a in d_test_tensor:
    # debug: print the first and last measurement value of each test case
    print(a[:, 0])
    print(a[:, -1])
n_test_tensor = torch.tensor(n_test_tensor).to(device)
means = torch.tensor(means)
variances = torch.tensor(variances)
weights = torch.tensor(weights)
test_n_sig = torch.ones(n_test)
print('read in par file {}'.format(parfile))
print('read in test params')

# plot test data
print('plotting test data')
signal_model = cw_model(n_meas, n_max)
signal_model.plot_test_data(d_test_tensor, n_test_tensor, plot_path)
print('plotted test data')

# generate prior samples for plotting by drawing from the GMMs
prior_samples = sample_from_gmm_diagonal(means, variances, weights, n_post)

# load the nested-sampling posterior samples and log-evidence estimates
# produced for each test case (generated elsewhere with a FlowSampler run);
# missing files are replaced with zero-filled placeholders
ns_samples = []
logZ = []
print(data_test_tensor.shape, d_test_tensor.shape)
for j in range(n_test):
    fn = '{}/{}/ns_{}.dat'.format(ns_path, run_id, j)
    if os.path.exists(fn):
        print('loading ns file {}'.format(fn))
        with open(fn, "rb") as f:
            ns_samples.append(np.fromfile(f).reshape(-1, n_par).astype(np.float64))
        fn = '{}/{}/logZ_{}.dat'.format(ns_path, run_id, j)
        print('loading logZ file {}'.format(fn))
        with open(fn, "rb") as f:
            logZ.append(np.fromfile(f).reshape(-1, 3).astype(np.float64))
    else:
        ns_samples.append(np.zeros((n_post, n_par)))
        logZ.append(np.zeros((1, 3)))

# restructure the ns samples: subsample every set down to the shortest length
min_len = min(s.shape[0] for s in ns_samples)
ns_min = min(min_len, n_post)
temp_ns_samples = np.zeros((n_test, ns_min, n_par))
ns_logZ = np.zeros((n_test, 3))
for i in range(n_test):
    idx = np.random.choice(ns_samples[i].shape[0], size=ns_min)
    temp_ns_samples[i, :, :] = ns_samples[i][idx, :]
    ns_logZ[i, :] = logZ[i]
ns_samples = temp_ns_samples.reshape(n_test, ns_min, n_par)

# optimise the flow and both compression networks jointly
all_pars = chain(flow.parameters(), cc_prior_model.parameters(), cc_meas_model.parameters())
optimiser = torch.optim.Adam(all_pars, lr=lr)

n_loss_avg = 1000
loss = dict(train=[], val=[])
train_loss_smooth = []
val_loss_smooth = []
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, iterations, eta_min=1e-6, last_epoch=-1)

current_n_max = 1
sub_batch = np.zeros(n_max + 1)
sub_batch[0] = batch_size
#js = np.empty((0,len(n_max_list),n_test,n_par))
ivec = []
#jsfile = '{}/js.dat'.format(plot_path)
logZ = []
prior_logZ = []
logZ_err = []
prior_logZ_err = []
model = cw_model(n_meas, 1)
test_model = cw_model(n_meas, n_max)

# make a dictionary of the networks and their dimensions
net = {}
net['flow'] = flow
net['cc_prior_model'] = cc_prior_model
net['cc_meas_model'] = cc_meas_model
net['n_meas'] = n_meas
net['n_par'] = n_par
net['n_prior_cc_out'] = n_prior_cc_out
net['n_meas_cc_out'] = n_meas_cc_out

def get_anneal(i, iterations):
    """Annealing factor: 0 for the first 10k iterations, a linear ramp to 1 between 10k and 20k, then 1."""
    if i < 10000:
        return 0.0
    elif i < 20000:
        return (i - 10000) / 10000.0
    else:
        return 1.0
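
# Quick sanity check of the schedule defined above (an addition, not part of
# the original script): zero before 10k iterations, 0.5 at the ramp midpoint,
# and 1.0 once the ramp has finished.
assert get_anneal(0, iterations) == 0.0
assert get_anneal(15000, iterations) == 0.5
assert get_anneal(25000, iterations) == 1.0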

for i in range(iterations + 1):

    optimiser.zero_grad()
    anneal = get_anneal(i, iterations)

    # make all the data for current_n_max cycles
    x_train, _, y_train, priors_train, _, condition_train = make_data_pretrain(
        n_data=batch_size, model=model, n_prior=n_prior, n_mix=n_mix,
        mode_scale=mode_scale, net=net, device=device, anneal=anneal)

    # compute the loss
    flow.train()
    _loss = -flow.log_prob(x_train[:, :n_par].to(device), conditional=condition_train).mean()
    _loss.backward()
    optimiser.step()
    scheduler.step()

    train_loss = _loss.item()
    loss["train"].append(train_loss)
    train_loss_smooth.append(np.median(loss["train"][max(i - n_loss_avg, 0):]))

    # validation
    if not i % 100:
        # run the analysis on the validation data
        x_val, _, y_val, priors_val, _, condition_val = make_data_pretrain(
            n_data=batch_size, model=model, n_prior=n_prior, n_mix=n_mix,
            mode_scale=mode_scale, net=net, valtest=True, device=device, anneal=anneal)

        # compute the validation loss
        flow.eval()
        _loss = -flow.log_prob(x_val[:, :n_par].to(device), conditional=condition_val).mean()
        val_loss = _loss.item()
        loss["val"].append(val_loss)
        val_loss_smooth.append(val_loss)

    if not i % plot_step:
        my_lr = scheduler.get_last_lr()[0]
        print(f"Epoch {i} - train: {loss['train'][-1]:.3f}, val: {loss['val'][-1]:.3f}, lr: {my_lr:.3e}, n_max: {current_n_max}, sub_batch: {sub_batch}")

        prior_label = torch.zeros(1, n_post, n_par, 2).to(device)
        cnt = 0
        for a in range(n_par):
            for b in range(a + 1, n_par):
                rt_ntest = int(np.sqrt(n_test) + 1.0 - 1e-12)  # grid size: ceil(sqrt(n_test))
                print(n_test, rt_ntest)
                fig1, ax1 = plt.subplots(rt_ntest, rt_ntest, figsize=(4 * rt_ntest, 4 * rt_ntest), dpi=100)
                fig2, bx2 = plt.subplots(rt_ntest, rt_ntest, figsize=(4 * rt_ntest, 4 * rt_ntest), dpi=100)
                j = 0
                for x, d, nd, m, v, w in zip(data_test_tensor, d_test_tensor, test_n_sig.cpu().numpy().astype(int), means, variances, weights):
                    prior_label[:, :, :, 0] = sample_from_gmm_diagonal(
                        m.unsqueeze(0), v.unsqueeze(0), w.unsqueeze(0), n_post).reshape(1, n_post, n_par).to(device)
                    _, _, _, samples, _, _ = make_data_pretrain(
                        1, model=model, n_prior=n_prior, n_mix=n_mix, n_post=n_post,
                        mode_scale=mode_scale, net=net, meas=d.reshape(1, n_meas, 1),
                        prior_label=prior_label, device=device, valtest=True)
                    print(samples.shape, nd, n_max)
                    samples = samples[:, :, :, 1:]  # ignore the original priors

                    # reshape and plot; the grid indices are named row/col to
                    # avoid shadowing the GMM means `m` from the zip above
                    row = j % rt_ntest
                    col = j // rt_ntest
                    ps = prior_samples[j] if flag[j] == 1 else None
                    signal_model.scatter_plot_pretrain(ax1[row, col], a, b, samples, ns_samples[j], ps, x, n_prior, n_post, n_par, ns_logZ[j, :], -1, -1)
                    signal_model.scatter_plot_pretrain(bx2[row, col], a, b, samples, ns_samples[j], ps, x, n_prior, n_post, n_par, ns_logZ[j, :], -1, -1, change_vars=True)
                    j += 1
                fig1.savefig('{}/pretraining_{}_{}{}.png'.format(plot_path, i, a, b))
                fig2.savefig('{}/pretraining_cv_{}_{}{}.png'.format(plot_path, i, a, b))
                cnt += 1

        # plot the smoothed training loss and the (per-100-iteration) validation loss
        loss_fig, loss_ax = plt.subplots(1, 1, figsize=(8, 8), dpi=100)
        loss_ax.semilogx(train_loss_smooth, alpha=0.5, label="Train")
        loss_ax.semilogx(np.arange(len(val_loss_smooth)) * 100, val_loss_smooth, alpha=0.5, label="Val.")
        loss_ax.set_ylim(np.min(train_loss_smooth) - 0.1, np.percentile(np.array(train_loss_smooth), 99))
        loss_ax.set_xlim([1000, iterations])
        loss_ax.set_xlabel("Epoch")
        loss_ax.set_ylabel("Loss")
        loss_ax.legend()
        loss_ax.grid(True)
        loss_fig.savefig('{}/loss.png'.format(plot_path))
        plt.close('all')

#flow.eval()
#cc_prior_model.eval()
#cc_meas_model.eval()
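
# Checkpointing sketch (an addition, not part of the original script): persist
# the trained networks with standard PyTorch state_dict serialization so a
# later run can reload them. The output filename is a hypothetical choice.
torch.save({
    'flow': flow.state_dict(),
    'cc_prior_model': cc_prior_model.state_dict(),
    'cc_meas_model': cc_meas_model.state_dict(),
}, '{}/pretrained_networks.pt'.format(plot_path))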