import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
import time
from lal import GreenwichMeanSiderealTime
from astropy.time import Time
from astropy import coordinates as coord
import corner
import os
import shutil
import h5py
import json
from universal_divergence import estimate
import bilby

def get_param_index(all_pars,pars,sky_extra=None):
    """ 
    Get the list index of requested source parameter types
    """
    # identify which entries of all_pars belong to the requested parameter set
    mask = []
    idx = []

    # loop over all source parameters
    for i,p in enumerate(all_pars):

        # flag the parameter if it is one of the requested types
        flag = any(p == q for q in pars)

        # record the true/false value for this parameter
        if flag:
            mask.append(True)
            idx.append(i)
        else:
            mask.append(False)
    if sky_extra is not None:
        if sky_extra:
            mask.append(True)
            idx.append(len(all_pars))
        else:
            mask.append(False)

    return mask, idx, np.sum(mask)

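# Illustrative (hypothetical) usage of get_param_index, not executed:
#   get_param_index(['mass_1', 'mass_2', 'ra'], ['ra'])
#   -> mask = [False, False, True], idx = [2], count = 1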
def convert_ra_to_hour_angle(data, params, pars, single=False):
    """
    Converts right ascension to hour angle and back again
    """

    greenwich = coord.EarthLocation.of_site('greenwich')
    t = Time(params['ref_geocent_time'], format='gps', location=greenwich)
    t = t.sidereal_time('mean', 'greenwich').radian

    # compute single instance
    if single:
        return t - data

    for i,k in enumerate(pars):
        if k == 'ra':
            ra_idx = i

    # Check whether RA is one of the sampled parameters
    try:
        ra_idx
    except NameError:
        print('...... RA is fixed. Not converting RA to hour angle.')
    else:
        # Iterate over all training samples and convert to hour angle
        for i in range(data.shape[0]):
            data[i,ra_idx] = t - data[i,ra_idx]

    return data

def convert_hour_angle_to_ra(data, params, pars, single=False):
    """
    Converts right ascension to hour angle and back again
    """
    greenwich = coord.EarthLocation.of_site('greenwich')
    t = Time(params['ref_geocent_time'], format='gps', location=greenwich)
    t = t.sidereal_time('mean', 'greenwich').radian

    # compute single instance
    if single:
        return np.remainder(t - data,2.0*np.pi)

    for i,k in enumerate(pars):
        if k == 'ra':
            ra_idx = i

    # Check whether RA is one of the sampled parameters
    try:
        ra_idx
    except NameError:
        print('...... RA is fixed. Not converting hour angle to RA.')
    else:
        # Iterate over all samples and convert hour angle back to RA
        for i in range(data.shape[0]):
            data[i,ra_idx] = np.remainder(t - data[i,ra_idx],2.0*np.pi)

    return data

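# The two conversions above are inverses of each other (modulo 2*pi): for a fixed
# reference time with Greenwich mean sidereal time GMST, hour angle = GMST - RA and
# RA = (GMST - hour angle) mod 2*pi.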
def load_data(params,bounds,fixed_vals,input_dir,inf_pars,test_data=False,silent=False):
    """ Function to load either training or testing data.
    """

    # Get list of all training/testing files and define dictionary to store values in files
    if type("%s" % input_dir) is str:
        dataLocations = ["%s" % input_dir]
    else:
        print('ERROR: input directory not a string')
        exit(0)

    # Sort files from first generated to last generated
    filenames = sorted(os.listdir(dataLocations[0]))
    print(filenames)

    # If loading by chunks, randomly shuffle list of training/testing filenames
    if params['load_by_chunks'] == True and not test_data:
        nfiles = np.min([int(params['load_chunk_size']/float(params['tset_split'])),len(filenames)])
        files_idx = np.random.randint(0,len(filenames),nfiles) 
        filenames= np.array(filenames)[files_idx]
        if not silent:
            print('...... shuffled filenames since we are loading in by chunks')

    # Iterate over all training/testing files and store source parameters, time series and SNR info in dictionary
    data={'x_data': [], 'y_data_noisefree': [], 'y_data_noisy': [], 'rand_pars': [], 'snrs': []}
    for filename in filenames:
        try:
            # open each file once and close it again when done
            with h5py.File(dataLocations[0]+'/'+filename, 'r') as h5f:
                data['x_data'].append(h5f['x_data'][:])
                data['y_data_noisefree'].append(h5f['y_data_noisefree'][:])
                if test_data:
                    data['y_data_noisy'].append(h5f['y_data_noisy'][:])
                data['rand_pars'] = h5f['rand_pars'][:]
                data['snrs'].append(h5f['snrs'][:])
            if not silent:
                print('...... Loaded file ' + dataLocations[0] + '/' + filename)
        except OSError:
            print('Could not load requested file {}'.format(filename))
            continue
    if np.array(data['y_data_noisefree']).ndim == 3:
        data['y_data_noisefree'] = np.expand_dims(np.array(data['y_data_noisefree']),axis=0)
        data['y_data_noisy'] = np.expand_dims(np.array(data['y_data_noisy']),axis=0)
    data['x_data'] = np.concatenate(np.array(data['x_data']), axis=0).squeeze()
    data['y_data_noisefree'] = np.transpose(np.concatenate(np.array(data['y_data_noisefree']), axis=0),[0,2,1])
    if test_data:
        data['y_data_noisy'] = np.transpose(np.concatenate(np.array(data['y_data_noisy']), axis=0),[0,2,1])
    data['snrs'] = np.concatenate(np.array(data['snrs']), axis=0)
    # convert ra to hour angle
    data['x_data'] = convert_ra_to_hour_angle(data['x_data'], params, params['rand_pars'])

    # Normalise the source parameters
    for i,k in enumerate(data['rand_pars']):
        par_min = k.decode('utf-8') + '_min'
        par_max = k.decode('utf-8') + '_max'
        # normalize by bounds
        data['x_data'][:,i]=(data['x_data'][:,i] - bounds[par_min]) / (bounds[par_max] - bounds[par_min])

    # extract inference parameters from all source parameters loaded earlier
    idx = []
    infparlist = ''
    for k in inf_pars:
        infparlist = infparlist + k + ', '
        for i,q in enumerate(data['rand_pars']):
            m = q.decode('utf-8')
            if k==m:
                idx.append(i)
    data['x_data'] = tf.cast(data['x_data'][:,idx],dtype=tf.float32)
    if not silent:
        print('...... {} will be inferred'.format(infparlist))
   
    print(data['x_data'].shape)
    print(data['y_data_noisefree'].shape)

    return data['x_data'], tf.cast(data['y_data_noisefree'],dtype=tf.float32), tf.cast(data['y_data_noisy'],dtype=tf.float32), data['snrs']

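# Given the transposes above, load_data is expected to return x_data with shape
# (N, n_inf_pars) and the timeseries arrays with shape (N, y_dim, n_channels);
# the shapes are printed just before returning as a sanity check.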
def load_samples(params):
    """
    read in pre-computed posterior samples
    """
    if type("%s" % params['pe_dir']) is str:
        # load generated samples back in
        dataLocations = '%s_%s' % (params['pe_dir'],params['samplers'][1])
        print('... looking in {} for posterior samples'.format(dataLocations))
    else:
        print('ERROR: input samples directory not a string')
        exit(0)

    print(dataLocations)
    exit(0)

    # Iterate over requested number of testing samples to use
    for i in range(params['r']):

        filename = '%s/%s_%d.h5py' % (dataLocations,params['bilby_results_label'],i)
        if not os.path.isfile(filename):
            print('... unable to find file {}. Exiting.'.format(filename))
            exit(0)

        print('... Loading test sample -> ' + filename)
        data_temp = {}
        n = 0

        # Retrieve all source parameters to do inference on
        for q in params['bilby_pars']:
            p = q + '_post'
            par_min = q + '_min'
            par_max = q + '_max'
            data_temp[p] = h5py.File(filename, 'r')[p][:]
            if p == 'geocent_time_post':
                data_temp[p] = data_temp[p] - params['ref_geocent_time']
            data_temp[p] = (data_temp[p] - bounds[par_min]) / (bounds[par_max] - bounds[par_min])
            Nsamp = data_temp[p].shape[0]
            n = n + 1
        print('... read in {} samples from {}'.format(Nsamp,filename))

        # place retrieved source parameters in numpy array rather than dictionary
        j = 0
        XS = np.zeros((Nsamp,n))
        for p,d in data_temp.items():
            XS[:,j] = d
            j += 1
        print('... put the samples in an array')

        # Append test sample posteriors to existing array of other test sample posteriors
        rand_idx_posterior = np.random.randint(0,Nsamp,params['n_samples'])
        if i == 0:
            XS_all = np.expand_dims(XS[rand_idx_posterior,:], axis=0)
        else:
            XS_all = np.vstack((XS_all,np.expand_dims(XS[rand_idx_posterior,:], axis=0)))
        print('... appended {} samples to the total'.format(params['n_samples']))
    return XS_all

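# load_samples stacks one block of posterior samples per test event, so XS_all has
# shape (params['r'], params['n_samples'], n_bilby_pars).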
def plot_losses(train_loss, val_loss, epoch, run='testing'):
    """
    plots the losses
    """
    plt.figure()
    plt.semilogx(np.arange(1,epoch+1),train_loss[:epoch,0],'b',label='RECON')
    plt.semilogx(np.arange(1,epoch+1),train_loss[:epoch,1],'r',label='KL')
    plt.semilogx(np.arange(1,epoch+1),train_loss[:epoch,2],'g',label='TOTAL')
    plt.semilogx(np.arange(1,epoch+1),val_loss[:epoch,0],'--b',alpha=0.5)
    plt.semilogx(np.arange(1,epoch+1),val_loss[:epoch,1],'--r',alpha=0.5)
    plt.semilogx(np.arange(1,epoch+1),val_loss[:epoch,2],'--g',alpha=0.5)
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    plt.grid()
    plt.ylim([np.min(1.1*train_loss[int(0.1*epoch):epoch,:]),np.max(1.1*train_loss[int(0.1*epoch):epoch,:])])
    plt.savefig('/data/www.astro/chrism/Heni_results/%s/loss.png' % (run))
    plt.close()

 
def plot_KL(KL_samples, step, run='testing'):
    """
    plots the KL evolution
    """
    N = KL_samples.shape[0]
    plt.figure()
    for kl in np.transpose(KL_samples):
        plt.semilogx(np.arange(1,N+1)*step,kl)
    plt.xlabel('epoch')
    plt.ylabel('KL')
    plt.grid()
    plt.ylim([-0.2,1.0])
    plt.savefig('/data/www.astro/chrism/Heni_results/%s/kl.png' % (run))
    plt.close()


def plot_posterior(samples,x_truth,epoch,idx,run='testing',other_samples=None):
    """
    plots the posteriors
    """

    # trim samples from outside the cube
    mask = []
    for s in samples:
        if (np.all(s>=0.0) and np.all(s<=1.0)):
            mask.append(True)
        else:
            mask.append(False)
    samples = tf.boolean_mask(samples,mask,axis=0)
    print('identified {} good samples'.format(samples.shape[0]))
    if samples.shape[0]<100:
        return -1.0

    if other_samples is not None:
        true_post = np.zeros([other_samples.shape[0],bilby_ol_len])
        true_x = np.zeros(inf_ol_len)
        true_XS = np.zeros([samples.shape[0],inf_ol_len])
        ol_pars = []
        cnt = 0
        for inf_idx,bilby_idx in zip(inf_ol_idx,bilby_ol_idx):
            inf_par = params['inf_pars'][inf_idx]
            bilby_par = params['bilby_pars'][bilby_idx]
            true_XS[:,cnt] = (samples[:,inf_idx] * (bounds[inf_par+'_max'] - bounds[inf_par+'_min'])) + bounds[inf_par+'_min']
            true_post[:,cnt] = (other_samples[:,bilby_idx] * (bounds[bilby_par+'_max'] - bounds[bilby_par+'_min'])) + bounds[bilby_par + '_min']
            true_x[cnt] = (x_truth[inf_idx] * (bounds[inf_par+'_max'] - bounds[inf_par+'_min'])) + bounds[inf_par + '_min']
            ol_pars.append(inf_par)
            cnt += 1
        parnames = []
        for k_idx,k in enumerate(params['rand_pars']):
            if np.isin(k, ol_pars):
                parnames.append(params['corner_labels'][k])

        # convert to RA
        true_XS = convert_hour_angle_to_ra(true_XS,params,ol_pars)
        true_x = convert_hour_angle_to_ra(np.reshape(true_x,[1,true_XS.shape[1]]),params,ol_pars).flatten()

        # compute KL estimate
        idx1 = np.random.randint(0,true_XS.shape[0],1000)
        idx2 = np.random.randint(0,true_post.shape[0],1000)
        try:
            KL_est = estimate(true_XS[idx1,:],true_post[idx2,:])
        except Exception:
            KL_est = -1.0

    else:
        # Get corner parnames to use in plotting labels
        parnames = []
        for k_idx,k in enumerate(params['rand_pars']):
            if np.isin(k, params['inf_pars']):
                parnames.append(params['corner_labels'][k])
        # un-normalise full inference parameters
        full_true_x = np.zeros(len(params['inf_pars']))
        new_samples = np.zeros([samples.shape[0],len(params['inf_pars'])])
        for inf_par_idx,inf_par in enumerate(params['inf_pars']):
            new_samples[:,inf_par_idx] = (samples[:,inf_par_idx] * (bounds[inf_par+'_max'] - bounds[inf_par+'_min'])) + bounds[inf_par+'_min']
            full_true_x[inf_par_idx] = (x_truth[inf_par_idx] * (bounds[inf_par+'_max'] - bounds[inf_par+'_min'])) + bounds[inf_par + '_min']
        new_samples = convert_hour_angle_to_ra(new_samples,params,params['inf_pars'])
        full_true_x = convert_hour_angle_to_ra(np.reshape(full_true_x,[1,samples.shape[1]]),params,params['inf_pars']).flatten()       
        KL_est = -1.0

    # define general plotting arguments
    defaults_kwargs = dict(
                    bins=50, smooth=0.9, label_kwargs=dict(fontsize=16),
                    title_kwargs=dict(fontsize=16),
                    truth_color='tab:orange', quantiles=[0.16, 0.84],
                    levels=(0.68,0.90,0.95), density=True,
                    plot_density=False, plot_datapoints=True,
                    max_n_ticks=3)

    # 1-d hist kwargs for normalisation
    hist_kwargs = dict(density=True,color='tab:red')
    hist_kwargs_other = dict(density=True,color='tab:blue')

    if other_samples is None:
        figure = corner.corner(new_samples,**defaults_kwargs,labels=parnames,
                           color='tab:red',
                           fill_contours=True, truths=x_truth,
                           show_titles=True, hist_kwargs=hist_kwargs)
        print('Shapes of the corner plot inputs')
        #plotting the actual true values of the parameters
        # Extract the axes
        ndim = 4    # NOTE: the axes grid below assumes 4 inference parameters
        axes = np.array(figure.axes).reshape((ndim, ndim))
        print('idx')
        print(idx)
        trueval=h5py.File('/data/wiay/2263373r/testdata/LISAtest_7par_set%d.hdf5' % (idx), "r")
        true_coord=trueval['x_data']
        print(np.shape(true_coord))
        for i in range(ndim):
            ax = axes[i, i]
            ax.axvline(true_coord[0,i], color="blue")
        for yi in range(ndim):
            for xi in range(yi):
                ax = axes[yi, xi]
                ax.plot(true_coord[0,xi], true_coord[0,yi], "bo") 
        plt.savefig('/data/www.astro/chrism/Heni_results/%s/full_posterior_epoch_%d_event_%d.png' % (run,epoch,idx))
        plt.close()
    else:
        figure = corner.corner(true_post, **defaults_kwargs,labels=parnames,
                           color='tab:blue',
                           show_titles=True, hist_kwargs=hist_kwargs_other)
        corner.corner(true_XS,**defaults_kwargs,
                           color='tab:red',
                           fill_contours=True, truths=true_x,
                           show_titles=True, fig=figure, hist_kwargs=hist_kwargs)
        plt.annotate('KL = {:.3f}'.format(KL_est),(0.2,0.95),xycoords='figure fraction',fontsize=18)
        plt.savefig('/data/www.astro/chrism/Heni_results/%s/comp_posterior_epoch_%d_event_%d.png' % (run,epoch,idx))
        plt.close()
    return KL_est

def plot_latent(mu_r1, z_r1, mu_q, z_q, epoch, idx, run='testing'):
    """
    plots corner plots of the latent space samples from the r1 and q encoders
    """

    # define general plotting arguments
    defaults_kwargs = dict(
                    bins=50, smooth=0.9, label_kwargs=dict(fontsize=16),
                    title_kwargs=dict(fontsize=16),
                    truth_color='tab:orange', quantiles=[0.16, 0.84],
                    levels=(0.68,0.90,0.95), density=True,
                    plot_density=False, plot_datapoints=True,
                    max_n_ticks=3)

    # 1-d hist kwargs for normalisation
    hist_kwargs = dict(density=True,color='tab:red')
    hist_kwargs_other = dict(density=True,color='tab:blue')

    figure = corner.corner(z_q, **defaults_kwargs,
                           color='tab:blue',
                           show_titles=True, hist_kwargs=hist_kwargs_other)
    corner.corner(z_r1,**defaults_kwargs,
                           color='tab:red',
                           fill_contours=True,
                           show_titles=True, fig=figure, hist_kwargs=hist_kwargs)
    # Extract the axes
    z_dim = z_r1.shape[1]
    axes = np.array(figure.axes).reshape((z_dim, z_dim))

    # Loop over the histograms
    for yi in range(z_dim):
        for xi in range(yi):
            ax = axes[yi, xi]
            ax.plot(mu_r1[0,:,xi], mu_r1[0,:,yi], "sr")
            ax.plot(mu_q[0,xi], mu_q[0,yi], "sb")
    plt.savefig('/data/www.astro/chrism/Heni_results/%s/latent_epoch_%d_event_%d.png' % (run,epoch,idx))
    plt.close()

# paths to the run configuration files (these variables are re-bound to the loaded dicts below)
params = '/home/chrism/Heni_codes/params_heni.json'
bounds = '/home/chrism/Heni_codes/bounds_heni.json'
fixed_vals = '/home/chrism/Heni_codes/fixed_vals_heni.json'
#run = time.ctime().replace(' ', '-')
run = time.strftime('%y-%m-%d-%X-%Z')
EPS = 1e-3

path = os.path.join('/data/www.astro/chrism/Heni_results/', run) 
os.mkdir(path) 
shutil.copy('./vitamin_c_heni.py',path)
shutil.copy('./params_heni.json',path)

# Load parameters files
with open(params, 'r') as fp:
    params = json.load(fp)
with open(bounds, 'r') as fp:
    bounds = json.load(fp)
with open(fixed_vals, 'r') as fp:
    fixed_vals = json.load(fp)

# if doing hour angle, use hour angle bounds on RA
bounds['ra_min'] = convert_ra_to_hour_angle(bounds['ra_min'],params,None,single=True)
bounds['ra_max'] = convert_ra_to_hour_angle(bounds['ra_max'],params,None,single=True)
print('... converted RA bounds to hour angle')
inf_ol_mask, inf_ol_idx, inf_ol_len = get_param_index(params['inf_pars'],params['bilby_pars'])
bilby_ol_mask, bilby_ol_idx, bilby_ol_len = get_param_index(params['bilby_pars'],params['inf_pars'])

# identify the indices of different sets of physical parameters
vonmise_mask, vonmise_idx_mask, vonmise_len = get_param_index(params['inf_pars'],params['vonmise_pars'])
gauss_mask, gauss_idx_mask, gauss_len = get_param_index(params['inf_pars'],params['gauss_pars'])
sky_mask, sky_idx_mask, sky_len = get_param_index(params['inf_pars'],params['sky_pars'])
ra_mask, ra_idx_mask, ra_len = get_param_index(params['inf_pars'],['ra'])
dec_mask, dec_idx_mask, dec_len = get_param_index(params['inf_pars'],['dec'])
m1_mask, m1_idx_mask, m1_len = get_param_index(params['inf_pars'],['mass_1'])
m2_mask, m2_idx_mask, m2_len = get_param_index(params['inf_pars'],['mass_2'])
#idx_mask = np.argsort(gauss_idx_mask + vonmise_idx_mask + m1_idx_mask + m2_idx_mask + sky_idx_mask) # + dist_idx_mask)
idx_mask = np.argsort(m1_idx_mask + m2_idx_mask + gauss_idx_mask + vonmise_idx_mask) # + sky_idx_mask)
dist_mask, dist_idx_mask, dist_len = get_param_index(params['inf_pars'],['luminosity_distance'])
xyz_mask, xyz_idx_mask, xyz_len = get_param_index(params['inf_pars'],['luminosity_distance','ra','dec'])
not_xyz_mask, not_xyz_idx_mask, not_xyz_len = get_param_index(params['inf_pars'],['mass_1','mass_2','geocent_time'])
idx_xyz_mask = np.argsort(xyz_idx_mask + not_xyz_idx_mask)

not_dist_mask, not_dist_idx_mask, not_dist_len = get_param_index(params['inf_pars'],['mass_1','mass_2','psi','phase','geocent_time','theta_jn','ra','dec','a_1','a_2','tilt_1','tilt_2','phi_12','phi_jl'])
idx_dist_mask = np.argsort(not_dist_idx_mask + dist_idx_mask)

print(xyz_mask)
print(not_xyz_mask)
print(idx_xyz_mask)
#masses_len = m1_len + m2_len
print(params['inf_pars'])
print(vonmise_mask,vonmise_idx_mask)
print(gauss_mask,gauss_idx_mask)
print(m1_mask,m1_idx_mask)
print(m2_mask,m2_idx_mask)
print(sky_mask,sky_idx_mask)
print(idx_mask)

# define which gpu to use during training
gpu_num = str(params['gpu_num'])   
os.environ["CUDA_VISIBLE_DEVICES"]=gpu_num
print('... running on GPU {}'.format(gpu_num))

# Let GPU consumption grow as needed
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)
print('... letting GPU consumption grow as needed')

# load the data
x_data_train, y_data_train, _, snrs_train = load_data(params,bounds,fixed_vals,params['train_set_dir'],params['inf_pars'])
x_data_val, y_data_val, _, snrs_val = load_data(params,bounds,fixed_vals,params['val_set_dir'],params['inf_pars'])
x_data_test, y_data_test_noisefree, y_data_test, snrs_test = load_data(params,bounds,fixed_vals,params['test_set_dir'],params['inf_pars'],test_data=True)

# load precomputed samples
if params['doPE']:
    bilby_samples = load_samples(params)

train_size = params['load_chunk_size']
batch_size = 32
val_size = 1000
test_size = 5

train_dataset = (tf.data.Dataset.from_tensor_slices((x_data_train,y_data_train))
                 .shuffle(train_size).batch(batch_size))
val_dataset = (tf.data.Dataset.from_tensor_slices((x_data_val,y_data_val))
                .shuffle(val_size).batch(batch_size))
test_dataset = (tf.data.Dataset.from_tensor_slices((x_data_test,y_data_test))
                .batch(1))
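# training and validation sets are shuffled (buffer sized to the loaded chunk) and
# batched; the test set is iterated one timeseries at a time so that each test event
# can be plotted individually.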

class CVAE(tf.keras.Model):
    """Convolutional variational autoencoder."""

    def __init__(self, x_dim, y_dim, n_channels, z_dim, n_modes):
        super(CVAE, self).__init__()
        self.z_dim = z_dim
        self.n_modes = n_modes
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.n_channels = n_channels

        r1_input_y = tf.keras.Input(shape=(self.y_dim, self.n_channels))
 
        ### 1st layer
        #layer_1 = tf.keras.layers.Conv1D(filters=8, kernel_size=1, padding='same', activation='relu')(r1_input_y)
        #layer_1 = tf.keras.layers.Conv1D(filters=8, kernel_size=3, padding='same', activation='relu')(layer_1)
        #layer_2 = tf.keras.layers.Conv1D(filters=8, kernel_size=1, padding='same', activation='relu')(r1_input_y)
        #layer_2 = tf.keras.layers.Conv1D(filters=8, kernel_size=5, padding='same', activation='relu')(layer_2)
        #layer_3 = tf.keras.layers.MaxPooling1D(pool_size=3, strides=1, padding='same')(r1_input_y)
        #layer_3 = tf.keras.layers.Conv1D(filters=8, kernel_size=1, padding='same', activation='relu')(layer_3)
        #layer_4 = tf.keras.layers.Conv1D(filters=10, kernel_size=1, padding='same', activation='relu')(r1_input_y)
        #mid_1 = tf.keras.layers.concatenate([layer_1, layer_2, layer_3], axis=2)

        #layer_1b = tf.keras.layers.Conv1D(filters=10, kernel_size=1, padding='same', activation='relu')(mid_1)
        #layer_1b = tf.keras.layers.Conv1D(filters=10, kernel_size=3, padding='same', activation='relu')(layer_1b)
        #layer_2b = tf.keras.layers.Conv1D(filters=10, kernel_size=1, padding='same', activation='relu')(mid_1)
        #layer_2b = tf.keras.layers.Conv1D(filters=10, kernel_size=5, padding='same', activation='relu')(layer_2b)
        #layer_3b = tf.keras.layers.MaxPooling1D(pool_size=3, strides=1, padding='same')(mid_1)
        #layer_3b = tf.keras.layers.Conv1D(filters=10, kernel_size=1, padding='same', activation='relu')(layer_3b)
        #layer_4b = tf.keras.layers.Conv1D(filters=10, kernel_size=1, padding='same', activation='relu')(mid_1)
        #mid_2 = tf.keras.layers.concatenate([layer_1b, layer_2b, layer_3b, layer_4b], axis=2)

        #a = tf.keras.layers.Flatten()(mid_1)

        # the r1 encoder network
        r1_input_y = tf.keras.Input(shape=(self.y_dim, self.n_channels))
        r1_input_red = tf.keras.layers.MaxPooling1D(pool_size=3, strides=2, padding='same')(r1_input_y)
        a1 = tf.keras.layers.Conv1D(filters=8, kernel_size=1, strides=1, padding='same', activation='relu')(r1_input_red)
        a1 = tf.keras.layers.Conv1D(filters=8, kernel_size=3, strides=1, padding='same', activation='relu')(a1)
        a2 = tf.keras.layers.Conv1D(filters=8, kernel_size=1, strides=1, padding='same', activation='relu')(r1_input_red)
        a2 = tf.keras.layers.Conv1D(filters=8, kernel_size=5, strides=1, padding='same', activation='relu')(a2)
        a3 = tf.keras.layers.MaxPooling1D(pool_size=3, strides=1, padding='same')(r1_input_red)
        a3 = tf.keras.layers.Conv1D(filters=8, kernel_size=1, strides=1, padding='same', activation='relu')(a3)
        a4 = tf.keras.layers.Conv1D(filters=8, kernel_size=1, padding='same', activation='relu')(r1_input_red)
        in1 = tf.keras.layers.concatenate([a1, a2, a3, a4], axis=2)
        in1 = tf.keras.layers.MaxPooling1D(pool_size=3, strides=2, padding='same')(in1)
        
        a5 = tf.keras.layers.Conv1D(filters=8, kernel_size=1, strides=1, padding='same', activation='relu')(in1)
        a5 = tf.keras.layers.Conv1D(filters=8, kernel_size=3, strides=1, padding='same', activation='relu')(a5)
        a6 = tf.keras.layers.Conv1D(filters=8, kernel_size=1, strides=1, padding='same', activation='relu')(in1)
        a6 = tf.keras.layers.Conv1D(filters=8, kernel_size=5, strides=1, padding='same', activation='relu')(a6)
        a7 = tf.keras.layers.MaxPooling1D(pool_size=3, strides=1, padding='same')(in1)
        a7 = tf.keras.layers.Conv1D(filters=8, kernel_size=1, strides=1, padding='same', activation='relu')(a7)
        a8 = tf.keras.layers.Conv1D(filters=8, kernel_size=1, padding='same', activation='relu')(in1)
        in2 = tf.keras.layers.concatenate([a5, a6, a7, a8], axis=2)
        a = tf.keras.layers.MaxPooling1D(pool_size=3, strides=2, padding='same')(in2)

        #a = tf.keras.layers.MaxPooling1D(pool_size=2,strides=2)(a)
        #a = tf.keras.layers.Conv1D(filters=16, kernel_size=5, strides=1, activation='relu')(a)
        #a = tf.keras.layers.Conv1D(filters=16, kernel_size=5, strides=1, activation='relu')(a)
        #a = tf.keras.layers.Conv1D(filters=16, kernel_size=5, strides=1, activation='relu')(a)
        #a = tf.keras.layers.MaxPooling1D(pool_size=2,strides=2)(a)        
        #a = tf.keras.layers.Conv1D(filters=16, kernel_size=5, strides=2, activation='relu')(a)
        #a = tf.keras.layers.Conv1D(filters=16, kernel_size=5, strides=2, activation='relu')(a)
        #a = tf.keras.layers.Conv1D(filters=16, kernel_size=5, strides=2, activation='relu')(a)
        #a = tf.keras.layers.MaxPooling1D(pool_size=2,strides=2)(a)
        #a = tf.keras.layers.Conv1D(filters=8, kernel_size=5, strides=2, activation='relu')(a)
        #a = tf.keras.layers.Conv1D(filters=8, kernel_size=5, strides=2, activation='relu')(a)
        #a = tf.keras.layers.Conv1D(filters=8, kernel_size=5, strides=2, activation='relu')(a)
        #a = tf.keras.layers.Conv1D(filters=32, kernel_size=5, strides=1, activation='relu')(a)
        #a = tf.keras.layers.Conv1D(filters=24, kernel_size=5, strides=1, activation='relu')(a)
        #a = tf.keras.layers.MaxPooling1D(pool_size=2,strides=2)(a)
        a = tf.keras.layers.Flatten()(a)
        a2 = tf.keras.layers.Dense(512,activation='relu')(a)
        a2 = tf.keras.layers.Dense(512,activation='relu')(a2)
        #a2 = tf.keras.layers.Dense(1024,activation='relu')(a2)
        a2 = tf.keras.layers.Dense(2*self.z_dim*self.n_modes + self.n_modes)(a2)
        self.encoder_r1 = tf.keras.Model(inputs=r1_input_y, outputs=a2)
        print(self.encoder_r1.summary())

        # the q encoder network
        #q_input_y = tf.keras.Input(shape=(self.y_dim, self.n_channels))
        q_input_x = tf.keras.Input(shape=(self.x_dim,))
        #b = tf.keras.layers.Conv1D(filters=48, kernel_size=5, strides=1, activation='relu')(q_input_y)
        #b = tf.keras.layers.Conv1D(filters=48, kernel_size=7, strides=1, activation='relu')(b)
        #b = tf.keras.layers.MaxPooling1D(pool_size=2,strides=2)(b)
        #b = tf.keras.layers.Conv1D(filters=48, kernel_size=7, strides=1, activation='relu')(b)
        #b = tf.keras.layers.Conv1D(filters=48, kernel_size=5, strides=1, activation='relu')(b)
        #b = tf.keras.layers.MaxPooling1D(pool_size=2,strides=2)(b)
        #b = tf.keras.layers.Flatten()(b)
        c = tf.keras.layers.Flatten()(q_input_x)
        d = tf.keras.layers.concatenate([a,c])        
        e = tf.keras.layers.Dense(512,activation='relu')(d)
        e = tf.keras.layers.Dense(512,activation='relu')(e)
        #e = tf.keras.layers.Dense(1024,activation='relu')(e)
        e = tf.keras.layers.Dense(2*self.z_dim)(e)
        self.encoder_q = tf.keras.Model(inputs=[r1_input_y, q_input_x], outputs=e)
        print(self.encoder_q.summary())

        # the r2 decoder network
        #r2_input_y = tf.keras.Input(shape=(self.y_dim, self.n_channels))
        r2_input_z = tf.keras.Input(shape=(self.z_dim,))
        #f = tf.keras.layers.Conv1D(filters=48, kernel_size=5, strides=1, activation='relu')(r2_input_y)
        #f = tf.keras.layers.Conv1D(filters=48, kernel_size=7, strides=1, activation='relu')(f)
        #f = tf.keras.layers.MaxPooling1D(pool_size=2,strides=2)(f)
        #f = tf.keras.layers.Conv1D(filters=48, kernel_size=7, strides=1, activation='relu')(f)
        #f = tf.keras.layers.Conv1D(filters=48, kernel_size=5, strides=1, activation='relu')(f)
        #f = tf.keras.layers.MaxPooling1D(pool_size=2,strides=2)(f) 
        #f = tf.keras.layers.Flatten()(f)
        g = tf.keras.layers.Flatten()(r2_input_z)
        h = tf.keras.layers.concatenate([a,g])
        i = tf.keras.layers.Dense(512,activation='relu')(h)
        i = tf.keras.layers.Dense(512,activation='relu')(i)
        #i = tf.keras.layers.Dense(1024,activation='relu')(i)
        j = tf.keras.layers.Dense(2*self.x_dim)(i)   # means and log-variances for each inference parameter
        #j = tf.keras.activations.relu(j, max_value=3.0) - 1.0    # limit the location parameters to span the full range +/-1
        #k = -1.0*tf.keras.layers.Dense(self.x_dim,activation='relu')(i)   # should be one less dimension for sky variance but it's easier to code like this
        #m = tf.keras.layers.concatenate([j,k])    # make log variances negative only
        self.decoder_r2 = tf.keras.Model(inputs=[r1_input_y, r2_input_z], outputs=j)
        print(self.decoder_r2.summary())

    def encode_r1(self, y=None):
        mean, logvar, weight = tf.split(self.encoder_r1(y), num_or_size_splits=[self.z_dim*self.n_modes, self.z_dim*self.n_modes,self.n_modes], axis=1)
        return tf.reshape(mean,[-1,self.n_modes,self.z_dim]), tf.reshape(logvar,[-1,self.n_modes,self.z_dim]), tf.reshape(weight,[-1,self.n_modes])
        #return tf.split(self.encoder_r1(y), num_or_size_splits=[self.z_dim*self.n_modes, self.z_dim*self.n_modes,self.n_modes], axis=1)

    def encode_q(self, x=None, y=None):
        #mean, logvar = tf.split(self.encoder_q([y,x]), num_or_size_splits=[self.z_dim,self.z_dim], axis=1)
        #return tf.reshape(mean,[-1,self.z_dim]), tf.reshape(logvar,[-1,self.z_dim])
        return tf.split(self.encoder_q([y,x]), num_or_size_splits=[self.z_dim,self.z_dim], axis=1)

    def decode_r2(self, y=None, z=None, apply_sigmoid=False):
        return tf.split(self.decoder_r2([y,z]), num_or_size_splits=[self.x_dim,self.x_dim], axis=1)
        #return mean, logvar

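# Output layouts of the three networks defined above (as split in encode_r1,
# encode_q and decode_r2): the r1 encoder emits 2*z_dim*n_modes + n_modes values
# (per-mode means and log-variances plus mixture weights of a Gaussian mixture),
# the q encoder emits 2*z_dim values (mean and log-variance of a diagonal Gaussian),
# and the r2 decoder emits 2*x_dim values (mean and log-variance for each inference
# parameter).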
optimizer = tf.keras.optimizers.Adam(1e-4)

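# compute_loss returns the two terms that train_step combines as
#   loss = recon + ramp * KL,
# where recon = -E[ log N(x | mu_r2(z,y), sigma_r2(z,y)) ] with z drawn from q, and
# the KL term is the single-sample Monte Carlo estimate -H[q] - E_q[ log r1(z|y) ].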
def compute_loss(model, x, y, ramp=1.0, noiseamp=1.0):

    # randomise distance
    old_d = bounds['luminosity_distance_min'] + tf.boolean_mask(x,dist_mask,axis=1)*(bounds['luminosity_distance_max'] - bounds['luminosity_distance_min'])
    new_x = tf.random.uniform(shape=tf.shape(old_d), minval=0.0, maxval=1.0, dtype=tf.dtypes.float32)
    new_d = bounds['luminosity_distance_min'] + new_x*(bounds['luminosity_distance_max'] - bounds['luminosity_distance_min'])
    x = tf.gather(tf.concat([tf.reshape(tf.boolean_mask(x,not_dist_mask,axis=1),[-1,tf.shape(x)[1]-1]), tf.reshape(new_x,[-1,1])],axis=1),tf.constant(idx_dist_mask),axis=1)
    dist_scale = tf.tile(tf.expand_dims(old_d/new_d,axis=1),(1,tf.shape(y)[1],1))

    # add noise and randomise distance again
    y = (y*dist_scale + noiseamp*tf.random.normal(shape=tf.shape(y), mean=0.0, stddev=1.0, dtype=tf.float32))/params['y_normscale']
    #y = (y + noiseamp*tf.random.normal(shape=tf.shape(y), mean=0.0, stddev=1.0, dtype=tf.float32))/params['y_normscale']
    mean_r1, logvar_r1, logweight_r1 = model.encode_r1(y=y)
    scale_r1 = EPS + tf.sqrt(tf.exp(logvar_r1))
    gm_r1 = tfd.MixtureSameFamily(mixture_distribution=tfd.Categorical(logits=logweight_r1),
            components_distribution=tfd.MultivariateNormalDiag(
            loc=mean_r1,
            scale_diag=scale_r1))
    mean_q, logvar_q = model.encode_q(x=x,y=y)
    scale_q = EPS + tf.sqrt(tf.exp(logvar_q))
    mvn_q = tfp.distributions.MultivariateNormalDiag(
                          loc=mean_q,
                          scale_diag=scale_q)
    #mvn_q = tfd.Normal(loc=mean_q,scale=scale_q)
    z_samp = mvn_q.sample()
    mean_r2, logvar_r2 = model.decode_r2(z=z_samp,y=y)
    scale_r2 = EPS + tf.sqrt(tf.exp(logvar_r2))

    # SIMPLE RECON LOSS
    mvn_r2 = tfp.distributions.MultivariateNormalDiag(
                          loc=tf.slice(mean_r2,[0,0],[-1,model.x_dim]),
                          scale_diag=scale_r2)

    # LOCALISATION
    #ra = 2*np.pi*tf.reshape(tf.boolean_mask(x,ra_mask,axis=1),[-1,1])       # convert the scaled 0->1 true RA value back to radians
    #dec = np.pi*(tf.reshape(tf.boolean_mask(x,dec_mask,axis=1),[-1,1]) - 0.5) # convert the scaled 0>1 true dec value back to radians
    #dist = tf.boolean_mask(x,dist_mask,axis=1)
    #xyz_actual = tf.reshape(tf.concat([dist*tf.cos(ra)*tf.cos(dec),dist*tf.sin(ra)*tf.cos(dec),dist*tf.sin(dec)],axis=1),[-1,3])   # construct the true parameter unit vector   
    #new_x = tf.gather(tf.concat([xyz_actual, tf.boolean_mask(x,not_xyz_mask,axis=1)],axis=1),tf.constant(idx_xyz_mask),axis=1)
    simple_cost_recon = -1.0*tf.reduce_mean(mvn_r2.log_prob(x))

    selfent_q = -1.0*tf.reduce_mean(mvn_q.entropy())
    log_r1_q = gm_r1.log_prob(z_samp)   # evaluate the log prob of r1 at the q samples
    cost_KL = selfent_q - tf.reduce_mean(log_r1_q)
    return simple_cost_recon, cost_KL
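    # NOTE: everything below this return is never executed; it is retained as a
    # reference implementation of the structured (truncated Gaussian, conditional
    # mass and von Mises) reconstruction losses.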

    # TRUNCATED GAUSSIAN
    loc_gauss = tf.boolean_mask(mean_r2,gauss_mask,axis=1)
    scale_gauss = tf.boolean_mask(scale_r2,gauss_mask,axis=1)
    delta = tf.convert_to_tensor(10.0*(1.0-ramp), dtype=tf.float32)    # evolve boundaries 
    gauss = tfd.TruncatedNormal(loc_gauss,scale_gauss,-1.0*delta,1.0 + delta)   # shrink the truncation with the ramp
    log_prob_gauss = tf.reduce_sum(gauss.log_prob(tf.boolean_mask(x,gauss_mask,axis=1)),axis=1)

    # CONDITIONAL MASS
    mean_m1 = tf.boolean_mask(mean_r2,m1_mask,axis=1)
    mean_m2 = tf.boolean_mask(mean_r2,m2_mask,axis=1)
    scale_m1 = tf.boolean_mask(scale_r2,m1_mask,axis=1)
    scale_m2 = tf.boolean_mask(scale_r2,m2_mask,axis=1)
    delta = tf.convert_to_tensor(10.0*(1.0-ramp), dtype=tf.float32)     # evolve boundaries
    joint = tfd.JointDistributionSequential([    # shrink the truncation with the ramp
               tfd.TruncatedNormal(mean_m1,scale_m1,-1.0*delta,1.0 + delta,validate_args=True,allow_nan_stats=False), #reinterpreted_batch_ndims=None),  # m1
               lambda m1: tfd.TruncatedNormal(mean_m2,scale_m2,-1.0*delta,m1 + delta,validate_args=True,allow_nan_stats=False)],    # m2
        validate_args=True)
    log_prob_masses = joint.log_prob((tf.boolean_mask(x,m1_mask,axis=1),tf.boolean_mask(x,m2_mask,axis=1)))

    # VON-MISES
    delta = tf.convert_to_tensor(0.1 + 0.9*ramp, dtype=tf.float32)
    mean_vonmise = (tf.boolean_mask(mean_r2,vonmise_mask,axis=1) - 0.5)*delta + 0.5
    scale_vonmise = tf.boolean_mask(scale_r2,vonmise_mask,axis=1)*delta 
    scaled_x = (tf.boolean_mask(x,vonmise_mask,axis=1) - 0.5)*delta + 0.5    # shrink the parameter space into the centre of the VM
    con = tf.reshape(tf.square(tf.math.reciprocal(scale_vonmise)),[-1,vonmise_len])   # modelling wrapped scale output as log variance
    vonmise = tfp.distributions.VonMises(loc=2.0*np.pi*mean_vonmise, concentration=con)
    log_prob_vonmise = tf.reduce_sum(vonmise.log_prob(2.0*np.pi*tf.reshape(scaled_x,[-1,vonmise_len])),axis=1)

    # SKY
    #mean_sky = tf.boolean_mask(mean_r2,sky_mask,axis=1) - 0.5   # centre the domain around zero
    #scale_sky = tf.boolean_mask(scale_r2,ra_mask[:-1],axis=1)   # just use the ra mask
    #con = tf.reshape(tf.square(tf.math.reciprocal(scale_sky)),[-1])   # modelling wrapped scale output as log variance - only 1 concentration parameter for all sky
    #von_mises_fisher = tfp.distributions.VonMisesFisher(
    #                      mean_direction=tf.math.l2_normalize(tf.reshape(mean_sky,[-1,3]),axis=1),
    #                      concentration=con)   # define p_vm(2*pi*mu,con=1/sig^2)
    #ra = 2*np.pi*tf.reshape(tf.boolean_mask(x,ra_mask[:-1],axis=1),[-1,1])       # convert the scaled 0->1 true RA value back to radians
    #dec = np.pi*(tf.reshape(tf.boolean_mask(x,dec_mask[:-1],axis=1),[-1,1]) - 0.5) # convert the scaled 0>1 true dec value back to radians
    #xyz_unit = tf.reshape(tf.concat([tf.cos(ra)*tf.cos(dec),tf.sin(ra)*tf.cos(dec),tf.sin(dec)],axis=1),[-1,3])   # construct the true parameter unit vector
    #log_prob_sky = von_mises_fisher.log_prob(tf.math.l2_normalize(xyz_unit,axis=1))   # normalise it for safety (should already be normalised) and compute the logprob   
    #gauss_sky = tfp.distributions.MultivariateNormalDiag(
    #                      loc=tf.boolean_mask(mean_r2,xyz_mask,axis=1),
    #                      scale_diag=tf.boolean_mask(scale_r2,xyz_mask,axis=1))
    #log_prob_sky = gauss_sky.log_prob(xyz_actual)

    cost_recon = -1.0*tf.reduce_mean(log_prob_gauss + log_prob_masses + log_prob_vonmise) # + log_prob_sky)
    #cost_recon = simple_cost_recon*(1.0-ramp) + adv_cost_recon*ramp     # transition from simple to advanced
    #log_q_q = mvn_q.log_prob(z_samp)
    selfent_q = -1.0*tf.reduce_mean(mvn_q.entropy())
    log_r1_q = gm_r1.log_prob(z_samp)   # evaluate the log prob of r1 at the q samples
    #cost_KL = tf.reduce_mean(log_q_q - log_r1_q)
    cost_KL = selfent_q - tf.reduce_mean(log_r1_q)
    return cost_recon, cost_KL

def gen_samples(model, y, ramp=1.0, nsamples=1000, max_samples=32):
    """
    generate posterior samples in the physical parameter space for a single timeseries y
    """
    y = y/params['y_normscale']
    y = tf.tile(y,(max_samples,1,1))
    samp_iterations = int(nsamples/max_samples) + 1
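    # note: this yields samp_iterations*max_samples draws in total, which is always
    # a little more than the requested nsamples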
    for i in range(samp_iterations):
        mean_r1, logvar_r1, logweight_r1 = model.encode_r1(y=y)
        scale_r1 = EPS + tf.sqrt(tf.exp(logvar_r1))
        gm_r1 = tfd.MixtureSameFamily(mixture_distribution=tfd.Categorical(logits=logweight_r1),
            components_distribution=tfd.MultivariateNormalDiag(
            loc=mean_r1,
            scale_diag=scale_r1))
        z_samp = gm_r1.sample()
        mean_r2, logvar_r2 = model.decode_r2(z=z_samp,y=y)
        scale_r2 = EPS + tf.sqrt(tf.exp(logvar_r2))

        mvn_r2 = tfp.distributions.MultivariateNormalDiag(
                          loc=tf.slice(mean_r2,[0,0],[-1,model.x_dim]),
                          scale_diag=scale_r2)
        if i==0:
            x_sample = mvn_r2.sample()
        else:
            x_sample = tf.concat([x_sample,mvn_r2.sample()],axis=0)
    return x_sample
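    # NOTE: everything below this return is never executed; it is retained as a
    # reference for sampling from the structured (truncated Gaussian, conditional
    # mass and von Mises) output distributions.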

    # LOCALISATION
    #xyz = tf.reshape(tf.boolean_mask(x_sample,xyz_mask,axis=1),[-1,3])          # sample the distribution
    #samp_dist = tf.reshape(tf.math.reduce_euclidean_norm(xyz,axis=1),[-1,1])
    #normed_xyz = tf.math.l2_normalize(xyz,axis=1)
    #samp_ra = tf.math.floormod(tf.atan2(tf.slice(normed_xyz,[0,1],[-1,1]),tf.slice(normed_xyz,[0,0],[-1,1])),2.0*np.pi)/(2.0*np.pi)   # convert to the rescaled 0->1 RA from the unit vector
    #samp_dec = (tf.asin(tf.slice(normed_xyz,[0,2],[-1,1])) + 0.5*np.pi)/np.pi                       # convert to the rescaled 0->1 dec from the unit vector
    #x_samp_sky = tf.reshape(tf.concat([samp_dist,samp_ra,samp_dec],axis=1),[-1,3])             # group the sky samples
    #new_x = tf.gather(tf.concat([x_samp_sky, tf.boolean_mask(x_sample,not_xyz_mask,axis=1)],axis=1),tf.constant(idx_xyz_mask),axis=1)
    #return new_x
    
    # GAUSSIAN PARAMS
    loc_gauss = tf.boolean_mask(mean_r2,gauss_mask,axis=1)
    scale_gauss = tf.boolean_mask(scale_r2,gauss_mask,axis=1)
    delta = tf.convert_to_tensor(10.0*(1.0-ramp), dtype=tf.float32)
    gauss = tfd.TruncatedNormal(loc_gauss,scale_gauss,-1.0*delta,1.0+delta)   # shrink the truncation with the ramp
    x_samp_gauss = gauss.sample()

    # CONDITIONAL MASS
    mean_m1 = tf.boolean_mask(mean_r2,m1_mask,axis=1)
    mean_m2 = tf.boolean_mask(mean_r2,m2_mask,axis=1)
    scale_m1 = tf.boolean_mask(scale_r2,m1_mask,axis=1)
    scale_m2 = tf.boolean_mask(scale_r2,m2_mask,axis=1)
    delta = tf.convert_to_tensor(10.0*(1.0-ramp), dtype=tf.float32)
    joint = tfd.JointDistributionSequential([    # shrink the truncation with the ramp
               tfd.TruncatedNormal(mean_m1,scale_m1,-1.0*delta,1.0+delta,validate_args=True,allow_nan_stats=False), #reinterpreted_batch_ndims=None),  # m1
               lambda m1: tfd.TruncatedNormal(mean_m2,scale_m2,-1.0*delta,m1+delta,validate_args=True,allow_nan_stats=False)],    # m2
        validate_args=True)
    x_samp_masses = tf.transpose(tf.reshape(joint.sample(),[2,-1]))

    # VON-MISES
    delta = tf.convert_to_tensor(0.1 + 0.9*ramp, dtype=tf.float32)
    mean_vonmise = (tf.boolean_mask(mean_r2,vonmise_mask,axis=1) - 0.5)*delta + 0.5
    scale_vonmise = tf.boolean_mask(scale_r2,vonmise_mask,axis=1)*delta
    con = tf.reshape(tf.square(tf.math.reciprocal(scale_vonmise)),[-1,vonmise_len])   # modelling wrapped scale output as log variance
    vonmise = tfp.distributions.VonMises(loc=2.0*np.pi*mean_vonmise, concentration=con)
    x_samp_vonmise = tf.math.floormod(vonmise.sample(),(2.0*np.pi))/(2.0*np.pi)
    
    # SKY
    #gauss_sky = tfp.distributions.MultivariateNormalDiag(
    #                      loc=tf.boolean_mask(mean_r2,xyz_mask,axis=1),
    #                      scale_diag=tf.boolean_mask(scale_r2,xyz_mask,axis=1))
    #xyz = gauss_sky.sample()          # sample the distribution
    #samp_dist = tf.reshape(tf.math.reduce_euclidean_norm(xyz,axis=1),[-1,1])
    #normed_xyz = tf.math.l2_normalize(xyz,axis=1)
    #samp_ra = tf.math.floormod(tf.atan2(tf.slice(normed_xyz,[0,1],[-1,1]),tf.slice(normed_xyz,[0,0],[-1,1])),2.0*np.pi)/(2.0*np.pi)   # convert to the rescaled 0->1 RA from the unit vector
    #samp_dec = (tf.asin(tf.slice(normed_xyz,[0,2],[-1,1])) + 0.5*np.pi)/np.pi                       # convert to the rescaled 0->1 dec from the unit vector
    #x_samp_sky = tf.reshape(tf.concat([samp_dist,samp_ra,samp_dec],axis=1),[-1,3])             # group the sky samples
    #mean_sky = tf.boolean_mask(mean_r2,sky_mask,axis=1) - 0.5    # centre the domain around zero
    #scale_sky = tf.boolean_mask(scale_r2,ra_mask[:-1],axis=1)   # just use the ra mask
    #con = tf.reshape(tf.square(tf.math.reciprocal(scale_sky)),[-1])   # modelling wrapped scale output as log variance - only 1 concentration parameter for all sky
    #von_mises_fisher = tfp.distributions.VonMisesFisher(
    #                      mean_direction=tf.math.l2_normalize(tf.reshape(mean_sky,[-1,3]),axis=1),
    #                      concentration=con)   # define p_vm(2*pi*mu,con=1/sig^2)
    #xyz = tf.reshape(von_mises_fisher.sample(),[-1,3])          # sample the distribution
    #samp_ra = tf.math.floormod(tf.atan2(tf.slice(xyz,[0,1],[-1,1]),tf.slice(xyz,[0,0],[-1,1])),2.0*np.pi)/(2.0*np.pi)   # convert to the rescaled 0->1 RA from the unit vector
    #samp_dec = (tf.asin(tf.slice(xyz,[0,2],[-1,1])) + 0.5*np.pi)/np.pi                       # convert to the rescaled 0->1 dec from the unit vector
    #x_samp_sky = tf.reshape(tf.concat([samp_ra,samp_dec],axis=1),[-1,2])             # group the sky samples

    return tf.gather(tf.concat([x_samp_masses, x_samp_gauss, x_samp_vonmise],axis=1),tf.constant(idx_mask),axis=1)

def gen_z_samples(model, x, y, nsamples=1000, max_samples=32):
    """
    generate latent space samples from both the r1 and q encoders for a single test case
    """
    y = y/params['y_normscale']
    y = tf.tile(y,(max_samples,1,1))
    x = tf.tile(x,(max_samples,1))
    samp_iterations = int(nsamples/max_samples) + 1
    for i in range(samp_iterations):
        mean_r1_temp, logvar_r1, logweight_r1 = model.encode_r1(y=y)
        scale_r1 = EPS + tf.sqrt(tf.exp(logvar_r1))
        gm_r1 = tfd.MixtureSameFamily(mixture_distribution=tfd.Categorical(logits=logweight_r1),
            components_distribution=tfd.MultivariateNormalDiag(
            loc=mean_r1_temp,
            scale_diag=scale_r1))
        #z_samp_r1 = gm_r1.sample()
        mean_q_temp, logvar_q = model.encode_q(x=x,y=y)
        scale_q = EPS + tf.sqrt(tf.exp(logvar_q))
        mvn_q = tfp.distributions.MultivariateNormalDiag(
                          loc=mean_q_temp,
                          scale_diag=scale_q)
        if i==0:    
            z_samp_q = mvn_q.sample()
            z_samp_r1 = gm_r1.sample()
            mean_r1 = mean_r1_temp
            mean_q = mean_q_temp
        else:
            z_samp_q = tf.concat([z_samp_q,mvn_q.sample()],axis=0)
            z_samp_r1 = tf.concat([z_samp_r1,gm_r1.sample()],axis=0)
            mean_r1 = tf.concat([mean_r1,mean_r1_temp],axis=0)
            mean_q = tf.concat([mean_q,mean_q_temp],axis=0)
    return mean_r1, z_samp_r1, mean_q, z_samp_q

@tf.function
def train_step(model, x, y, optimizer, ramp=1.0):
    """Executes one training step and returns the loss.
    This function computes the loss and gradients, and uses the latter to
    update the model's parameters.
    """
    with tf.GradientTape() as tape:
        r_loss, kl_loss = compute_loss(model, x, y, ramp=ramp)
        loss = r_loss + ramp*kl_loss
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss_metric(loss)
    return r_loss, kl_loss

epochs = 25000
train_loss_metric = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
train_log_dir = '/data/www.astro/chrism/Heni_results/' + run + '/logs'
#os.mkdir(train_log_dir)
train_summary_writer = tf.summary.create_file_writer(train_log_dir)
model = CVAE(x_data_train.shape[1], params['ndata'], y_data_train.shape[2], params['z_dimension'], params['n_modes'])

# start the training loop
train_loss = np.zeros((epochs,3))
val_loss = np.zeros((epochs,3))
ramp_start = 350
ramp_stop = 600
ramp_grad = 1.0/(ramp_stop - ramp_start)
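# The KL ramp rises linearly from 0 at epoch ramp_start to 1 at epoch ramp_stop and
# is clipped at 1 thereafter; e.g. with ramp_start=350 and ramp_stop=600, epoch 475
# gives ramp = (475-350)/250 = 0.5.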
KL_samples = []
for epoch in range(1, epochs + 1):

    train_loss_kl_q = 0.0
    train_loss_kl_r1 = 0.0
    start_time_train = time.time()
    ramp = tf.convert_to_tensor(np.min([(epoch-ramp_start)*ramp_grad,1.0]).astype(np.single) if epoch>ramp_start else 0.0,dtype=tf.float32)
    for step, (x_batch_train, y_batch_train) in train_dataset.enumerate():
        temp_train_r_loss, temp_train_kl_loss = train_step(model, x_batch_train, y_batch_train, optimizer, ramp=ramp)
        train_loss[epoch-1,0] += temp_train_r_loss
        train_loss[epoch-1,1] += temp_train_kl_loss
    train_loss[epoch-1,2] = train_loss[epoch-1,0] + ramp*train_loss[epoch-1,1]
    train_loss[epoch-1,:] /= float(step+1)
    end_time_train = time.time()
    with train_summary_writer.as_default():
        tf.summary.scalar('loss', train_loss_metric.result(), step=epoch)
    train_loss_metric.reset_states()

    start_time_val = time.time()
    for step, (x_batch_val, y_batch_val) in val_dataset.enumerate():
        temp_val_r_loss, temp_val_kl_loss = compute_loss(model, x_batch_val, y_batch_val, ramp=ramp)
        val_loss[epoch-1,0] += temp_val_r_loss
        val_loss[epoch-1,1] += temp_val_kl_loss
    val_loss[epoch-1,2] = val_loss[epoch-1,0] + ramp*val_loss[epoch-1,1]
    val_loss[epoch-1,:] /= float(step+1)
    end_time_val = time.time()

    print('Epoch: {}, Training RECON: {}, KL: {}, TOTAL: {}, time elapsed: {}'
        .format(epoch, train_loss[epoch-1,0], train_loss[epoch-1,1], train_loss[epoch-1,2], end_time_train - start_time_train))
    print('Epoch: {}, Validation RECON: {}, KL: {}, TOTAL: {}, time elapsed {}'
        .format(epoch, val_loss[epoch-1,0], val_loss[epoch-1,1], val_loss[epoch-1,2], end_time_val - start_time_val))

    # update loss plot
    plot_losses(train_loss, val_loss, epoch, run=run)

    # generate and plot posterior samples for the latent space and the parameter space 
    if epoch % 250 == 0:
        for step, (x_batch_test, y_batch_test) in test_dataset.enumerate():             
            mu_r1, z_r1, mu_q, z_q = gen_z_samples(model, x_batch_test, y_batch_test, nsamples=8000)
            plot_latent(mu_r1,z_r1,mu_q,z_q,epoch,step,run=run)
            start_time_test = time.time()
            samples = gen_samples(model, y_batch_test, ramp=ramp, nsamples=8000)
            end_time_test = time.time()
            if np.any(np.isnan(samples)):
                print('Epoch: {}, found nans in samples. Not making plots'.format(epoch))
                for k,s in enumerate(samples):
                    if np.any(np.isnan(s)):
                        print(k,s)
                KL_est = -1.0
            else:
                print('Epoch: {}, Testing time elapsed for 8000 samples: {}'.format(epoch,end_time_test - start_time_test))
                if params['doPE']==True:
                    KL_est = plot_posterior(samples,x_batch_test[0,:],epoch,step,other_samples=bilby_samples[step,:],run=run)
                else:
                    KL_est = -1.0
                _ = plot_posterior(samples,x_batch_test[0,:],epoch,step,run=run)
            KL_samples.append(KL_est)

        # plot KL evolution
        #plot_KL(np.reshape(np.array(KL_samples),[-1,5]),250,run=run)
    #trying to plot pp
    #if epoch % 50 == 0:
    #    x_data_pp, y_data_pp_noisefree, y_data_pp, snrs_pp = load_data(params,bounds,fixed_vals,params['pp_test_set_dir'],params['inf_pars'],test_data=True)
    #    pp_dataset = (tf.data.Dataset.from_tensor_slices((x_data_pp,y_data_pp)).batch(1))
    #    results = []
    #    for step, (x_batch_test, y_batch_test) in pp_dataset.enumerate():
    #        start_time_test = time.time()
    #        samples=gen_samples(model,y_batch_test,ramp=ramp,nsamples=8000)
    #        end_time_test = time.time()
    #        result=bilby.result.Result(label='test',injection_parameters=x_data_pp,posterior=samples,search_parameter_keys=params['inf_pars'])
    #        results.append(result)
    #    bilby.result.make_pp_plot(results, filename=('/data/www.astro/chrism/Heni_results/%s/pp_plot_epoch_%d_event_%d.png' % (run,epoch,step)),
    #                          confidence_interval=0.9)

    # load more noisefree training data back in
    if epoch % 25 == 0:
        x_data_train, y_data_train, _, snrs_train = load_data(params,bounds,fixed_vals,params['train_set_dir'],params['inf_pars'],silent=True)
        train_dataset = (tf.data.Dataset.from_tensor_slices((x_data_train,y_data_train))
                 .shuffle(train_size).batch(batch_size))

