#!/usr/bin/env python3
"""
newcos7-4.py

- Restores gwmemory-based memory addition exactly as in newcos4-2.py:
  osc(FD) + mem(TD->FD) on the bilby frequency grid.
- Avoids bilby waveform parameter KeyErrors by sampling in mass_1/mass_2
  (detector-frame masses), and keeping spin/sky as DeltaFunctions.
- Always computes both reweightings and always applies SNR selection.
- Output structure:
  /data/www.astro/chrism/newcos/outdir/<run-id>_<mem>/<run-id>_<mem>_<seed>/<outputs>
- Adds Time Domain plotting for Signal (Osc vs Mem) + Noise.
"""

import argparse
import glob
import os
from pathlib import Path
import shutil
import copy
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset
from scipy.optimize import brentq
from scipy.signal.windows import tukey

from scipy.signal import butter, filtfilt

import gwmemory
import bilby
from bilby.core.utils import nfft, infft
from astropy.cosmology import Planck18 as cosmo
from astropy.cosmology import z_at_value
import astropy.units as u

from ldc.lisa.noise.noise import AnalyticNoise

# LISA/BBHx constants used in LISA mode
try:
    from bbhx.utils.constants import PC_SI, YRSID_SI
except ImportError:
    # Fallback if bbhx not installed, though likely needed for LISA mode
    PC_SI = 3.0856775814913673e16
    YRSID_SI = 31558149.763545603

# Physical constants (SI) for simple sanity checks.
# G_SI, C_SI and MSUN_SI are read by f_guess_newtonian / get_fmin below.
G_SI = 6.67430e-11
C_SI = 299792458.0
MSUN_SI = 1.988409870698051e30


# ---- unit factors used by the scaling helpers ----

_MS = 1e-3  # one millisecond in seconds; divisor for geocent_time in default_scaling_spec
_MPC_PER_GPC = 1e3  # Mpc per Gpc (not referenced in the visible part of this file)

# =========================
# Scaling helpers
# =========================

import inspect
import warnings

# Silence two specific warnings, but only when they are raised from modules
# whose name matches gwmemory.waveforms.mwm; warnings from other modules are
# unaffected.

# 1) RuntimeWarning whose message matches "invalid value encountered in power".
warnings.filterwarnings(
    "ignore",
    message=r".*invalid value encountered in power.*",
    category=RuntimeWarning,
    module=r"gwmemory\.waveforms\.mwm",
)
# 2) Any warning category (no `category` given) whose message matches the
#    complex-to-real casting text below.
warnings.filterwarnings(
    "ignore",
    message=r".*Casting complex values to real discards the imaginary part.*",
    module=r"gwmemory\.waveforms\.mwm",
)

def save_fisher_results(out, outdir, label="fisher"):
    """
    Write selected entries of a Fisher results dict to a compressed ``.npz``.

    Parameters
    ----------
    out : dict
        Result dict as returned by
        fisher_uncertainties_finite_difference_scaled_marginalised. The keys
        read are: fisher_phi_full, cov_phi_full, fisher_phi_marg,
        cov_phi_marg, cov_theta_marg, kept_keys, nuisance_keys and
        jacobian_diag_kept_dtheta_dphi.
    outdir : str
        Target directory; created if it does not exist.
    label : str
        File name prefix; the archive is ``<outdir>/<label>_results.npz``.

    Notes
    -----
    The Jacobian diagonal is stored under the shorter archive name
    ``jacobian_diag_kept``. The key lists are stored as object arrays, so
    loading them back requires ``allow_pickle=True``.
    """
    os.makedirs(outdir, exist_ok=True)
    filename = os.path.join(outdir, f"{label}_results.npz")

    # Matrices that are stored under the same name they have in `out`.
    matrix_names = (
        "fisher_phi_full",
        "cov_phi_full",
        "fisher_phi_marg",
        "cov_phi_marg",
        "cov_theta_marg",
    )
    payload = {name: out[name] for name in matrix_names}
    payload["kept_keys"] = np.array(out["kept_keys"], dtype=object)
    payload["nuisance_keys"] = np.array(out["nuisance_keys"], dtype=object)
    payload["jacobian_diag_kept"] = out["jacobian_diag_kept_dtheta_dphi"]

    np.savez_compressed(filename, **payload)

    print(f"Saved Fisher results to {filename}")

def make_emcee_pos0_from_injection(priors, injection_parameters, nwalkers, u_sigma=1e-4, seed=0):
    """
    Build an (nwalkers, ndim) array of emcee start positions clustered
    around the injected parameter values.

    For each key in ``priors.non_fixed_keys``:

    * If the prior exposes callable ``cdf`` and ``rescale``, the injected
      value is mapped to the unit interval with ``cdf``, perturbed by Gaussian
      noise of width ``u_sigma``, clipped to (1e-12, 1 - 1e-12) and mapped
      back one sample at a time with ``rescale``.
    * If either method is missing, or anything in that route raises, the
      value is instead perturbed directly with a relative scale of 1e-4
      (absolute 1e-4 when the injected value is exactly 0) and clipped to
      just inside ``prior.minimum`` / ``prior.maximum`` when both are set.

    Parameters
    ----------
    priors : mapping
        Must provide ``non_fixed_keys`` and item access by key.
    injection_parameters : mapping
        Injected value for every sampled key.
    nwalkers : int
        Number of walkers (rows of the result).
    u_sigma : float
        Jitter width in unit-interval coordinates (CDF route only).
    seed : int
        Seed for ``numpy.random.default_rng``.

    Returns
    -------
    pos0 : numpy.ndarray
        Start positions, shape (nwalkers, ndim).
    keys : list
        Parameter names in column order.
    """
    rng = np.random.default_rng(seed)
    keys = list(priors.non_fixed_keys)
    edge = 1e-12

    pos0 = np.zeros((nwalkers, len(keys)), dtype=float)

    for col, name in enumerate(keys):
        prior = priors[name]
        truth = float(injection_parameters[name])

        cdf_route = callable(getattr(prior, "cdf", None)) and callable(getattr(prior, "rescale", None))
        if cdf_route:
            # Deliberately best-effort: any failure here drops through to the
            # physical-space jitter below.
            try:
                centre = float(prior.cdf(truth))
                if not np.isfinite(centre):
                    raise ValueError("Non-finite CDF at injection")

                unit_draws = centre + rng.normal(0.0, u_sigma, size=nwalkers)
                unit_draws = np.clip(unit_draws, edge, 1.0 - edge)

                pos0[:, col] = np.array([prior.rescale(ui) for ui in unit_draws], dtype=float)
                continue
            except Exception:
                pass

        # Physical-space fallback: small relative jitter around the truth.
        magnitude = abs(truth)
        width = 1e-4 * (magnitude if magnitude > 0 else 1.0)
        draws = truth + rng.normal(0.0, width, size=nwalkers)

        lower = getattr(prior, "minimum", None)
        upper = getattr(prior, "maximum", None)
        if lower is not None and upper is not None:
            draws = np.clip(draws, float(lower) + 1e-15, float(upper) - 1e-15)

        pos0[:, col] = draws

    return pos0, keys


def filter_kwargs_for_callable(func, kwargs, aliases=None, strict=False):
    """
    Return the subset of ``kwargs`` that ``func`` can accept, after renaming.

    Parameters
    ----------
    func : callable
        Inspected with ``inspect.signature``.
    kwargs : dict
        Candidate keyword arguments.
    aliases : dict[str, str], optional
        Renames applied first: our key -> the name ``func`` expects,
        e.g. {"luminosity_distance": "distance"}.
    strict : bool
        When True, and ``func`` has no ``**kwargs``, raise ``TypeError`` if
        an alias that was actually used maps to a name ``func`` does not take.

    Returns
    -------
    dict
        Renamed kwargs. All of them if ``func`` declares ``**kwargs``,
        otherwise only those matching a parameter name in the signature.
    """
    rename = dict(aliases) if aliases is not None else {}
    signature_params = inspect.signature(func).parameters

    renamed = {rename.get(name, name): value for name, value in kwargs.items()}

    # A **kwargs catch-all means nothing needs to be dropped.
    for param in signature_params.values():
        if param.kind == inspect.Parameter.VAR_KEYWORD:
            return renamed

    allowed = set(signature_params.keys())
    accepted = {name: value for name, value in renamed.items() if name in allowed}

    if strict:
        for src, tgt in rename.items():
            if src in kwargs and tgt not in allowed:
                raise TypeError(
                    f"Callable {func} does not accept alias target '{tgt}' "
                    f"(from '{src}'). Signature keys: {sorted(allowed)}"
                )

    return accepted


def default_scaling_spec():
    """
    Return the default per-parameter scaling rules used to improve the
    conditioning of the Fisher matrix.

    A rule is a tuple whose first element names the transform:
      ("log",)      : phi = log(theta)
      ("div", a)    : phi = theta / a
      ("identity",) : phi = theta

    Defaults:
      - chirp_mass          -> log
      - luminosity_distance -> log
      - geocent_time        -> divided by _MS (i.e. expressed in ms)
      - H0                  -> divided by 100

    Parameters that are not listed are expected to be treated as identity by
    the caller.
    """
    spec = {name: ("log",) for name in ("chirp_mass", "luminosity_distance")}
    spec["geocent_time"] = ("div", _MS)
    spec["H0"] = ("div", 100.0)
    return spec

def theta_to_phi(theta, rule):
    """
    Map a physical parameter value ``theta`` to its scaled coordinate ``phi``.

    ``rule`` is a scaling tuple: ("log",), ("div", a) or ("identity",).
    Raises ValueError for an unknown rule, or for a log rule with theta <= 0.
    """
    kind = rule[0]
    if kind == "identity":
        return theta
    if kind == "div":
        return theta / rule[1]
    if kind != "log":
        raise ValueError(f"Unknown scaling rule: {rule}")
    if theta <= 0:
        raise ValueError(f"log scaling requires theta>0, got {theta}")
    return np.log(theta)

def phi_to_theta(phi, rule):
    """
    Inverse of theta_to_phi: map a scaled coordinate ``phi`` back to the
    physical value ``theta`` under the same scaling ``rule``.

    Raises ValueError for an unknown rule.
    """
    kind = rule[0]
    if kind == "identity":
        return phi
    if kind == "div":
        return phi * rule[1]
    if kind == "log":
        return np.exp(phi)
    raise ValueError(f"Unknown scaling rule: {rule}")

def dtheta_dphi(theta0, rule):
    """
    Return the diagonal Jacobian element dtheta/dphi at ``theta0``.

      log      : theta = exp(phi)  ->  dtheta/dphi = theta0
      div by a : theta = a * phi   ->  dtheta/dphi = a
      identity :                       dtheta/dphi = 1

    Always returns a Python float; raises ValueError for an unknown rule.
    """
    kind = rule[0]
    if kind == "identity":
        return 1.0
    if kind == "div":
        return float(rule[1])
    if kind == "log":
        return float(theta0)
    raise ValueError(f"Unknown scaling rule: {rule}")


# =========================
# Schur complement marginalisation
# =========================

def schur_marginalise_fisher(
    fisher: np.ndarray,
    param_keys: list[str],
    nuisance_keys: list[str],
    pinv_rcond: float = 1e-12,
):
    """
    Marginalise nuisance parameters from a Fisher matrix using the Schur complement.

    Partition params into kept x and nuisance n:
        Γ = [[A, B],
             [Bᵀ, C]]

    After marginalising n:
        Γ_marg = A - B C^{-1} Bᵀ

    The input is symmetrised first, and both C and Γ_marg are inverted with
    ``np.linalg.pinv`` using ``pinv_rcond``.

    Parameters
    ----------
    fisher : array-like, shape (n, n)
        Full Fisher matrix, ordered like ``param_keys``.
    param_keys : list[str]
        Names of all parameters.
    nuisance_keys : list[str]
        Names to marginalise over; every entry must be in ``param_keys``.
    pinv_rcond : float
        Cut-off passed to the pseudo-inverse.

    Returns
    -------
    dict with keys
      - kept_keys   : names of the remaining parameters, in original order
      - keep_idx    : their indices into ``param_keys``
      - nuis_idx    : indices of the nuisance parameters
      - fisher_marg : marginalised Fisher matrix (kept x kept)
      - cov_kept    : pinv(fisher_marg)
      - sigma_kept  : dict name -> sqrt(variance), NaN if variance <= 0
      - corr_kept   : correlation matrix; rows/columns of a parameter whose
                      variance is not positive are NaN

    Raises
    ------
    KeyError
        If a nuisance key is not present in ``param_keys``.
    """
    fisher = np.asarray(fisher, float)
    fisher = 0.5 * (fisher + fisher.T)

    key_to_i = {k: i for i, k in enumerate(param_keys)}
    missing = [k for k in nuisance_keys if k not in key_to_i]
    if missing:
        raise KeyError(f"nuisance_keys not found in param_keys: {missing}")

    # Build the membership set once instead of once per parameter.
    nuisance_set = set(nuisance_keys)
    nuis_idx = np.array([key_to_i[k] for k in nuisance_keys], dtype=int)
    keep_idx = np.array(
        [i for i, k in enumerate(param_keys) if k not in nuisance_set], dtype=int
    )
    kept_keys = [param_keys[i] for i in keep_idx]

    A = fisher[np.ix_(keep_idx, keep_idx)]
    B = fisher[np.ix_(keep_idx, nuis_idx)]
    C = fisher[np.ix_(nuis_idx, nuis_idx)]

    Cinv = np.linalg.pinv(C, rcond=pinv_rcond)
    fisher_marg = A - B @ Cinv @ B.T
    fisher_marg = 0.5 * (fisher_marg + fisher_marg.T)

    cov_kept = np.linalg.pinv(fisher_marg, rcond=pinv_rcond)

    # A non-positive variance has no meaningful standard deviation; mark it
    # NaN so that sigma and every correlation involving it agree.
    variances = np.diag(cov_kept)
    std = np.sqrt(np.where(variances > 0, variances, np.nan))

    sigma_kept = {k: float(std[i]) for i, k in enumerate(kept_keys)}
    corr_kept = cov_kept / (std[:, None] * std[None, :])

    return {
        "kept_keys": kept_keys,
        "keep_idx": keep_idx,
        "nuis_idx": nuis_idx,
        "fisher_marg": fisher_marg,
        "cov_kept": cov_kept,
        "sigma_kept": sigma_kept,
        "corr_kept": corr_kept,
    }


# =========================
# Main Fisher computation in SCALED space + marginalisation + map back
# =========================

def fisher_uncertainties_finite_difference_scaled_marginalised(
    likelihood,
    params,
    param_keys,
    nuisance_keys=None,
    steps=None,
    relative_steps=None,
    scaling_spec=None,
    fmin=None,
    fmax=None,
    pinv_rcond=1e-12,
    default_rel=1e-4,
    default_abs=1e-4,
    verbose=True,
):
    """
    Fisher-matrix uncertainties computed in scaled coordinates, with optional
    marginalisation over nuisance parameters.

    1) Build Fisher in scaled coordinates phi (better conditioning).
    2) Marginalise nuisance parameters via Schur complement (optional).
    3) Invert to get covariance on kept params.
    4) Map covariance back to original theta coords for kept params via J dtheta/dphi.

    Parameters
    ----------
    likelihood : object
        Either a "ground" likelihood (has ``interferometers`` and
        ``waveform_generator``) or a "LISA" likelihood (has ``data_fd``,
        ``psd_fd`` and ``waveform_model``). The ground test is applied first.
    params : dict
        Expansion point theta0; must contain every key in ``param_keys``.
    param_keys : list[str]
        Parameters that span the Fisher matrix.
    nuisance_keys : list[str], optional
        Subset of ``param_keys`` to marginalise over.
    steps : dict[str, float], optional
        Absolute central-difference step per key, *in phi coordinates*.
        Takes precedence over ``relative_steps``.
    relative_steps : dict[str, float], optional
        Relative step per key: delta = rel * |phi0|.
    scaling_spec : dict, optional
        Per-key scaling rules; defaults to ``default_scaling_spec()``.
        Keys that are absent use ("identity",).
    fmin, fmax : float, optional
        Frequency limits. When None they are taken from the first
        interferometer's ``minimum_frequency`` / ``maximum_frequency``
        (ground) or the likelihood's ``fmin`` / ``fmax`` (LISA) if present.
    pinv_rcond : float
        Cut-off for every pseudo-inverse.
    default_rel, default_abs : float
        Fallback relative step, and absolute step used when phi0 == 0.
    verbose : bool
        NOTE: currently only triggers computation of the condition numbers;
        every print statement in that branch is commented out.

    Returns
    -------
    dict containing
      - theta0, phi0, scaling_spec, df, freq_mask
      - param_keys_full, fisher_phi_full, cov_phi_full, sigma_phi_full
        (not marginalised, phi space)
      - nuisance_keys, kept_keys
      - fisher_phi_marg, cov_phi_marg, sigma_phi_marg, corr_phi_marg
        (kept params, phi space; corr_phi_marg is None when no nuisance
        keys were given)
      - jacobian_diag_kept_dtheta_dphi
      - cov_theta_marg, sigma_theta_marg, corr_theta_marg
        (kept params, theta space)

    Raises
    ------
    TypeError
        If the likelihood matches neither supported shape.
    ValueError
        If the frequency array has fewer than two points, or a step is zero
        or non-finite.
    """
    # Copy/normalise optional arguments so the caller's objects are never mutated.
    steps = {} if steps is None else dict(steps)
    relative_steps = {} if relative_steps is None else dict(relative_steps)
    scaling_spec = default_scaling_spec() if scaling_spec is None else dict(scaling_spec)
    nuisance_keys = [] if nuisance_keys is None else list(nuisance_keys)

    # --- Detect likelihood type by duck typing and collect frequency grid / PSDs ---
    is_ground = hasattr(likelihood, "interferometers") and hasattr(likelihood, "waveform_generator")
    is_lisa = hasattr(likelihood, "data_fd") and hasattr(likelihood, "psd_fd") and hasattr(likelihood, "waveform_model")
    if not (is_ground or is_lisa):
        raise TypeError(
            "Unsupported likelihood type. Expected GravitationalWaveTransient (ground) "
            "or LisaAETFrequencyDomainLikelihood (LISA)."
        )

    if is_ground:
        ifos = likelihood.interferometers
        wfgen = likelihood.waveform_generator
        # The first interferometer's grid and frequency limits are used for all detectors.
        f = np.asarray(ifos[0].strain_data.frequency_array, float)
        psd_dict = {ifo.name: np.asarray(ifo.power_spectral_density_array, float) for ifo in ifos}
        if fmin is None:
            fmin = getattr(ifos[0], "minimum_frequency", None)
        if fmax is None:
            fmax = getattr(ifos[0], "maximum_frequency", None)

        def h_of(theta_params):
            # Per-detector frequency-domain response, keyed by ifo.name.
            return _ground_strain_fd_from_params(ifos, wfgen, theta_params)

    else:
        f = np.asarray(likelihood.frequency_array, float)
        psd_dict = {k: np.asarray(v, float) for k, v in likelihood.psd_fd.items()}
        if fmin is None:
            fmin = getattr(likelihood, "fmin", None)
        if fmax is None:
            fmax = getattr(likelihood, "fmax", None)

        def h_of(theta_params):
            # Per-channel frequency-domain strain from the LISA waveform model.
            return _lisa_strain_fd_from_params(likelihood, theta_params)

    if f.size < 2:
        raise ValueError("Frequency array too small for Fisher computation.")
    # Median spacing is used as the (assumed uniform) bin width.
    df = float(np.median(np.diff(f)))

    # Boolean mask selecting fmin <= f <= fmax with finite f.
    m = _mask_freqs(f, fmin=fmin, fmax=fmax)
    m &= np.isfinite(f)

    # Base point theta0
    theta0 = dict(params)

    # Build phi0 over param_keys
    phi0 = {}
    for k in param_keys:
        th = float(theta0[k])
        rule = scaling_spec.get(k, ("identity",))
        phi0[k] = float(theta_to_phi(th, rule))

    def theta_from_phi(phi_dict):
        # Full parameter dict: theta0 with the param_keys entries replaced by
        # the back-transformed phi values.
        out = dict(theta0)
        for k in param_keys:
            rule = scaling_spec.get(k, ("identity",))
            out[k] = float(phi_to_theta(phi_dict[k], rule))
        return out

    # Mask PSD
    psd_masked = {ch: np.asarray(psd, float)[m] for ch, psd in psd_dict.items()}

    # Central finite-difference derivatives dh/dphi (on the full, unmasked grid)
    derivs = {}  # key -> dict[ch] -> dh/dphi
    for k in param_keys:
        p0k = float(phi0[k])

        # Step priority: explicit absolute step, else relative to |phi0|,
        # else default_abs when phi0 is exactly zero.
        if k in steps:
            delta = float(steps[k])
        else:
            #print(relative_steps)
            rel = float(relative_steps.get(k, default_rel))
            delta = rel * abs(p0k) if abs(p0k) > 0 else float(default_abs)

        if delta == 0 or not np.isfinite(delta):
            raise ValueError(f"Bad finite-difference step for scaled {k}: {delta}")

        phi_plus = dict(phi0);  phi_plus[k] = p0k + delta
        phi_minus = dict(phi0); phi_minus[k] = p0k - delta

        h_plus = h_of(theta_from_phi(phi_plus))
        h_minus = h_of(theta_from_phi(phi_minus))

        dh = {}
        for ch in h_plus.keys():
            dh[ch] = (np.asarray(h_plus[ch]) - np.asarray(h_minus[ch])) / (2.0 * delta)
        derivs[k] = dh

        #if verbose:
        #    print(f"[Fisher] d/dphi({k}) step={delta:.3e} phi0={p0k:.6g} theta0={float(theta0[k]):.6g} rule={scaling_spec.get(k,('identity',))}")

    # Build full Fisher in phi: Gamma_ij = sum over channels of (dh_i | dh_j).
    # Only the lower triangle is evaluated; the upper triangle is mirrored.
    n = len(param_keys)
    Gamma_phi = np.zeros((n, n), dtype=float)

    for i, ki in enumerate(param_keys):
        for j, kj in enumerate(param_keys[: i + 1]):
            val = 0.0
            #print(ki,kj)
            for ch in psd_masked.keys():
                di = np.asarray(derivs[ki][ch])[m]
                dj = np.asarray(derivs[kj][ch])[m]
                #print(di,dj)
                val += _fd_inner_product(di, dj, psd_masked[ch], df)
            Gamma_phi[i, j] = val
            Gamma_phi[j, i] = val

    Gamma_phi = 0.5 * (Gamma_phi + Gamma_phi.T)

    # Full inverse (often ill-conditioned; kept for reference)
    Cov_phi_full = np.linalg.pinv(Gamma_phi, rcond=pinv_rcond)
    sigma_phi_full = {k: (float(np.sqrt(Cov_phi_full[i,i])) if Cov_phi_full[i,i] > 0 else float("nan"))
                      for i, k in enumerate(param_keys)}

    # --- Marginalise nuisance in phi space ---
    if nuisance_keys:
        marg = schur_marginalise_fisher(
            fisher=Gamma_phi,
            param_keys=param_keys,
            nuisance_keys=nuisance_keys,
            pinv_rcond=pinv_rcond,
        )
        kept_keys = marg["kept_keys"]
        Gamma_phi_marg = marg["fisher_marg"]
        Cov_phi_marg = marg["cov_kept"]
        sigma_phi_marg = marg["sigma_kept"]
        corr_phi_marg = marg["corr_kept"]
    else:
        # Nothing to marginalise: the "marginalised" results alias the full
        # ones, and no phi-space correlation matrix is produced.
        kept_keys = list(param_keys)
        Gamma_phi_marg = Gamma_phi
        Cov_phi_marg = Cov_phi_full
        sigma_phi_marg = sigma_phi_full
        corr_phi_marg = None

    # --- Map marginalised covariance back to theta space for kept params ---
    # J is diagonal for our simple per-parameter scalings
    J_diag = np.array(
        [dtheta_dphi(float(theta0[k]), scaling_spec.get(k, ("identity",))) for k in kept_keys],
        dtype=float
    )
    J = np.diag(J_diag)

    Cov_theta_marg = J @ Cov_phi_marg @ J.T

    # 1-sigma in original units; NaN where the variance is not positive.
    sigma_theta_marg = {}
    for i, k in enumerate(kept_keys):
        v = Cov_theta_marg[i, i]
        sigma_theta_marg[k] = float(np.sqrt(v)) if v > 0 else float("nan")

    # Correlation matrix; entries with a non-positive normalisation become NaN.
    corr_theta_marg = np.zeros_like(Cov_theta_marg)
    for i in range(len(kept_keys)):
        for j in range(len(kept_keys)):
            denom = np.sqrt(Cov_theta_marg[i, i] * Cov_theta_marg[j, j])
            corr_theta_marg[i, j] = Cov_theta_marg[i, j] / denom if denom > 0 else np.nan

    if verbose:
        # NOTE(review): cond_full / cond_marg are computed but never used,
        # because the diagnostic prints below are commented out.
        try:
            cond_full = float(np.linalg.cond(Gamma_phi))
        except Exception:
            cond_full = float("inf")
        try:
            cond_marg = float(np.linalg.cond(Gamma_phi_marg))
        except Exception:
            cond_marg = float("inf")

        #print(f"[Fisher] df={df:.6g}, fmin={fmin}, fmax={fmax}, masked bins={int(np.sum(m))}/{f.size}")
        #print(f"[Fisher] cond(full Gamma_phi)={cond_full:.3e}")
        #print(f"[Fisher] cond(marg Gamma_phi)={cond_marg:.3e}")
        #print("[Fisher] Marginalised 1-sigma (theta space):")
        #for k in kept_keys:
        #    print(f"  sigma({k}) = {sigma_theta_marg[k]:.6g}  (original units)")

    return {
        # base points
        "theta0": {k: float(theta0[k]) for k in param_keys},
        "phi0": {k: float(phi0[k]) for k in param_keys},
        "scaling_spec": scaling_spec,
        "df": df,
        "freq_mask": m,

        # full (phi space)
        "param_keys_full": list(param_keys),
        "fisher_phi_full": Gamma_phi,
        "cov_phi_full": Cov_phi_full,
        "sigma_phi_full": sigma_phi_full,

        # marginalised (phi space)
        "nuisance_keys": list(nuisance_keys),
        "kept_keys": list(kept_keys),
        "fisher_phi_marg": Gamma_phi_marg,
        "cov_phi_marg": Cov_phi_marg,
        "sigma_phi_marg": sigma_phi_marg,
        "corr_phi_marg": corr_phi_marg,

        # marginalised mapped to theta space (kept only)
        "jacobian_diag_kept_dtheta_dphi": J_diag,
        "cov_theta_marg": Cov_theta_marg,
        "sigma_theta_marg": sigma_theta_marg,
        "corr_theta_marg": corr_theta_marg,
    }


# =========================
# Frequency-domain inner-product and strain helpers (used by the Fisher routines)
# =========================


def _fd_inner_product(a, b, psd, df):
    """
    One-sided, frequency-domain noise-weighted inner product:
      (a|b) = 4 Re sum_f a*(f) b(f) / S_n(f) * df
    Assumes arrays already restricted to the desired frequency mask.
    """
    #print(f'inner prod = {4.0 * np.real(np.sum(np.conjugate(a) * b / psd) * df)}')
    return 4.0 * np.real(np.sum(np.conjugate(a) * b / psd) * df)


def _mask_freqs(f, fmin=None, fmax=None):
    m = np.ones_like(f, dtype=bool)
    if fmin is not None:
        m &= (f >= fmin)
    if fmax is not None:
        m &= (f <= fmax)
    return m


def _ground_strain_fd_from_params(ifos, waveform_generator, params):
    """
    Returns dict: {ifo.name: h_fd_complex[freq]}
    """
    pols = waveform_generator.frequency_domain_strain(params)
    out = {}
    for ifo in ifos:
        out[ifo.name] = ifo.get_detector_response(pols, params)
    return out

def _lisa_strain_fd_from_params(lisa_likelihood, params):
    """
    Call your LISA waveform model in the form:
        h = waveform_model(params, frequency_array)

    Expected return:
        dict like {'A': hA, 'E': hE, 'T': hT}
    """
    wm = lisa_likelihood.waveform_model
    f = lisa_likelihood.frequency_array

    h = wm(params, f)  # <-- your model wants (params, frequency_array)

    print("Keys passed to LISA model:", params.keys())

    return {k: np.asarray(v, np.complex128) for k, v in h.items()}

def _fd_inner_product_masked(a, b, psd, df):
    print(np.real(np.conjugate(a) * b / psd * df))
    return 4.0 * np.real(np.sum(np.conjugate(a) * b / psd) * df)

def fisher_ground_bilby_mask(
    likelihood,
    params,
    param_keys,
    steps=None,
    relative_steps=None,
    pinv_rcond=1e-12,
    verbose=True,
):
    """
    Fisher matrix for a ground-based likelihood using each interferometer's
    own ``frequency_mask`` (additionally restricted to f > 0).

    Parameters
    ----------
    likelihood : object
        Must expose ``interferometers`` and ``waveform_generator``.
    params : dict
        Point at which derivatives are taken; must contain all ``param_keys``.
    param_keys : list[str]
        Parameters spanning the Fisher matrix.
    steps : dict[str, float], optional
        Absolute central-difference step per key (overrides relative_steps).
    relative_steps : dict[str, float], optional
        Relative step per key, delta = rel * |theta|. Default 1e-4, with an
        absolute fallback of 1e-4 when theta == 0.
    pinv_rcond : float
        Cut-off for the pseudo-inverse of the Fisher matrix.
    verbose : bool
        Print the step sizes, condition number and 1-sigma values.

    Returns
    -------
    dict
        'fisher', 'cov' (pseudo-inverse), 'sigma' (dict, NaN where the
        variance is not positive), 'cond', 'df'.

    Raises
    ------
    ValueError
        If the frequency grid has fewer than two points, or a step is zero
        or non-finite.
    """
    steps = {} if steps is None else dict(steps)
    relative_steps = {} if relative_steps is None else dict(relative_steps)

    ifos = likelihood.interferometers
    wfgen = likelihood.waveform_generator

    # Bin width from the first interferometer's grid (assumed uniform and
    # shared by all interferometers).
    f = np.asarray(ifos[0].strain_data.frequency_array, float)
    if f.size < 2:
        raise ValueError("Frequency array too small for Fisher computation.")
    df = float(f[1] - f[0])

    # np.array(...) always copies, and `&` builds a new array, so the
    # interferometer's own mask is never modified in place.
    positive = f > 0.0
    mask_dict = {ifo.name: np.array(ifo.frequency_mask, dtype=bool) & positive for ifo in ifos}
    psd_dict = {ifo.name: np.asarray(ifo.power_spectral_density_array, float) for ifo in ifos}

    def h_ifos(p):
        pols = wfgen.frequency_domain_strain(p)
        return {ifo.name: ifo.get_detector_response(pols, p) for ifo in ifos}

    default_rel = 1e-4
    default_abs = 1e-4

    # Central finite-difference derivatives on the full grid.
    derivs = {}
    p0 = dict(params)
    for key in param_keys:
        theta0 = float(p0[key])
        if key in steps:
            delta = float(steps[key])
        else:
            rel = float(relative_steps.get(key, default_rel))
            delta = rel * abs(theta0) if abs(theta0) > 0 else default_abs

        # Same guard as the other Fisher routines: a zero or non-finite step
        # would otherwise produce inf/NaN derivatives without any error.
        if delta == 0 or not np.isfinite(delta):
            raise ValueError(f"Bad finite-difference step for {key}: {delta}")

        p_plus = dict(p0)
        p_plus[key] = theta0 + delta
        p_minus = dict(p0)
        p_minus[key] = theta0 - delta

        h_plus = h_ifos(p_plus)
        h_minus = h_ifos(p_minus)

        derivs[key] = {
            name: (np.asarray(h_plus[name]) - np.asarray(h_minus[name])) / (2.0 * delta)
            for name in h_plus
        }

        if verbose:
            print(f"[Fisher] d/d{key}: step={delta:.3e}")

    n = len(param_keys)
    Gamma = np.zeros((n, n), float)

    # Lower triangle is evaluated and mirrored; each element sums the
    # noise-weighted inner product over interferometers.
    for i, ki in enumerate(param_keys):
        for j, kj in enumerate(param_keys[: i + 1]):
            val = 0.0
            for ifo in ifos:
                name = ifo.name
                m = mask_dict[name]
                val += _fd_inner_product(
                    derivs[ki][name][m], derivs[kj][name][m], psd_dict[name][m], df
                )
            Gamma[i, j] = val
            Gamma[j, i] = val

    Cov = np.linalg.pinv(Gamma, rcond=pinv_rcond)

    sig = {k: float(np.sqrt(Cov[i, i])) if Cov[i, i] > 0 else np.nan for i, k in enumerate(param_keys)}
    try:
        cond = float(np.linalg.cond(Gamma))
    except Exception:
        cond = float("inf")

    if verbose:
        print(f"[Fisher] df={df:.6g}, cond={cond:.3e}")
        for k in param_keys:
            print(f"[Fisher] sigma({k}) ~ {sig[k]:.6g}")

    return {"fisher": Gamma, "cov": Cov, "sigma": sig, "cond": cond, "df": df}


def fisher_uncertainties_finite_difference(
    likelihood,
    params,
    param_keys,
    steps=None,
    relative_steps=None,
    fmin=None,
    fmax=None,
    pinv_rcond=1e-12,
    verbose=True,
):
    """
    Compute Fisher matrix and parameter uncertainties using finite differences.

    Parameters
    ----------
    likelihood : object
        Either a "ground" likelihood (has ``interferometers`` and
        ``waveform_generator``) or a "LISA" likelihood (has ``data_fd``,
        ``psd_fd`` and ``waveform_model``). The ground test is applied first.
    params : dict
        Parameter dict at which to compute the Fisher matrix (e.g., the injection).
    param_keys : list[str]
        Parameters for which to compute uncertainties.
    steps : dict[str, float], optional
        Absolute finite-difference step per parameter. If provided, overrides relative_steps for that key.
    relative_steps : dict[str, float], optional
        Relative step per parameter, delta = rel * |theta|. Default 1e-4,
        with an absolute fallback of 1e-4 when theta == 0.
    fmin, fmax : float, optional
        Frequency bounds for the inner product.
        If None, uses any fmin/fmax already stored on the likelihood (or the
        first interferometer) if present; otherwise no extra cut.
    pinv_rcond : float
        rcond for pseudo-inverse if Fisher is ill-conditioned.
    verbose : bool
        Print step sizes, mask statistics, condition number and 1-sigma values.

    Returns
    -------
    result : dict
        keys:
          - 'fisher' : (n,n) Fisher matrix, including off-diagonal elements
          - 'cov' : (n,n) covariance matrix (pseudo-inverse)
          - 'sigma' : dict of 1-sigma uncertainties for each key
          - 'corr' : (n,n) correlation matrix
          - 'cond' : condition number of Fisher (may be inf)
          - 'df' : frequency spacing used
          - 'freq_mask' : boolean mask used

    Raises
    ------
    TypeError
        If the likelihood matches neither supported shape.
    ValueError
        If the frequency array has fewer than two points, or a step is zero
        or non-finite.
    """
    steps = {} if steps is None else dict(steps)
    relative_steps = {} if relative_steps is None else dict(relative_steps)

    # Detect type / get waveform and PSD containers
    is_ground = hasattr(likelihood, "interferometers") and hasattr(likelihood, "waveform_generator")
    is_lisa = hasattr(likelihood, "data_fd") and hasattr(likelihood, "psd_fd") and hasattr(likelihood, "waveform_model")

    if not (is_ground or is_lisa):
        raise TypeError(
            "Unsupported likelihood type. Expected GravitationalWaveTransient (ground) "
            "or LisaAETFrequencyDomainLikelihood (LISA)."
        )

    if is_ground:
        ifos = likelihood.interferometers
        wfgen = likelihood.waveform_generator
        # frequency array (assume consistent across ifos)
        f = np.asarray(ifos[0].strain_data.frequency_array, float)
        psd_dict = {ifo.name: np.asarray(ifo.power_spectral_density_array, float) for ifo in ifos}

        # fallback fmin/fmax: taken from the first interferometer if present
        if fmin is None:
            fmin = getattr(ifos[0], "minimum_frequency", None)
        if fmax is None:
            fmax = getattr(ifos[0], "maximum_frequency", None)

        def h_of(p):
            return _ground_strain_fd_from_params(ifos, wfgen, p)

    else:
        f = np.asarray(likelihood.frequency_array, float)
        psd_dict = {k: np.asarray(v, float) for k, v in likelihood.psd_fd.items()}
        if fmin is None:
            fmin = getattr(likelihood, "fmin", None)
        if fmax is None:
            fmax = getattr(likelihood, "fmax", None)

        def h_of(p):
            return _lisa_strain_fd_from_params(likelihood, p)

    # Frequency spacing (assume uniform)
    if f.size < 2:
        raise ValueError("Frequency array too small for Fisher computation.")
    df = float(np.median(np.diff(f)))

    # Frequency mask: fmin <= f <= fmax and finite f. The f = 0 bin is only
    # excluded when fmin is above zero.
    m = _mask_freqs(f, fmin=fmin, fmax=fmax)
    m &= np.isfinite(f)

    # Default step heuristics; tune per parameter if numerical noise shows up.
    default_rel = 1e-4
    default_abs = 1e-4

    p0 = dict(params)

    # Central finite-difference derivatives for each parameter
    derivs = {}  # key -> dict[channel_or_ifo -> dh/dtheta array]
    for key in param_keys:
        theta0 = float(p0[key])

        if key in steps:
            delta = float(steps[key])
        else:
            rel = float(relative_steps.get(key, default_rel))
            # handle theta0 == 0
            delta = rel * abs(theta0) if abs(theta0) > 0 else default_abs

        if delta == 0 or not np.isfinite(delta):
            raise ValueError(f"Bad finite-difference step for {key}: {delta}")

        p_plus = dict(p0)
        p_plus[key] = theta0 + delta
        p_minus = dict(p0)
        p_minus[key] = theta0 - delta

        h_plus = h_of(p_plus)
        h_minus = h_of(p_minus)

        derivs[key] = {
            ch: (np.asarray(hp) - np.asarray(h_minus[ch])) / (2.0 * delta)
            for ch, hp in h_plus.items()
        }

        if verbose:
            print(f"[Fisher] d/d{key}: step={delta:.3e} (theta0={theta0:.6g})")

    # Mask the PSDs and derivatives once, outside the double loop.
    psd_masked = {ch: psd[m] for ch, psd in psd_dict.items()}
    derivs_masked = {
        key: {ch: np.asarray(derivs[key][ch])[m] for ch in psd_masked}
        for key in param_keys
    }

    # Build the Fisher matrix. Every (i, j) pair of the lower triangle is
    # evaluated and mirrored, so correlations between parameters are kept.
    n = len(param_keys)
    Gamma = np.zeros((n, n), dtype=float)

    for i, ki in enumerate(param_keys):
        for j, kj in enumerate(param_keys[: i + 1]):
            val = 0.0
            # sum over detectors/channels
            for ch, psd_ch in psd_masked.items():
                val += _fd_inner_product(derivs_masked[ki][ch], derivs_masked[kj][ch], psd_ch, df)
            Gamma[i, j] = val
            Gamma[j, i] = val

    # Invert Fisher -> covariance (use pseudo-inverse for robustness)
    try:
        cond = float(np.linalg.cond(Gamma))
    except Exception:
        cond = float("inf")

    Cov = np.linalg.pinv(Gamma, rcond=pinv_rcond)

    # 1-sigma uncertainties (NaN where the variance is not positive)
    sigma = {}
    for i, k in enumerate(param_keys):
        v = Cov[i, i]
        sigma[k] = float(np.sqrt(v)) if v > 0 else float("nan")

    # correlation matrix
    corr = np.zeros_like(Cov)
    for i in range(n):
        for j in range(n):
            denom = np.sqrt(Cov[i, i] * Cov[j, j])
            corr[i, j] = Cov[i, j] / denom if denom > 0 else np.nan

    if verbose:
        print(f"[Fisher] df={df:.6g}, fmin={fmin}, fmax={fmax}, masked bins={int(np.sum(m))}/{f.size}")
        print(f"[Fisher] cond(Fisher)={cond:.3e}")
        for k in param_keys:
            print(f"[Fisher] sigma({k}) ~ {sigma[k]:.6g}")

    return {
        "fisher": Gamma,
        "cov": Cov,
        "sigma": sigma,
        "corr": corr,
        "cond": cond,
        "df": df,
        "freq_mask": m,
    }


def f_guess_newtonian(M_SI, eta, dt):
    """
    Leading-order estimate of the frequency corresponding to a time ``dt``.

    Evaluates

        f0 = (1/pi) * (5 / (256 eta dt))**(3/8) * (c**3 / (G M))**(5/8)

    with ``M_SI`` the mass in kg (it is combined directly with G_SI and
    C_SI), ``eta`` dimensionless and ``dt`` in seconds. get_fmin uses the
    result as the centre of its root-finding bracket.
    """
    time_factor = (5.0 / (256.0 * eta * dt)) ** (3.0 / 8.0)
    mass_factor = (C_SI**3 / (G_SI * M_SI)) ** (5.0 / 8.0)
    return (1.0 / np.pi) * time_factor * mass_factor

def get_fmin(M, eta, dt, fmax=1024.0):
    """
    Solve for the frequency at which the chirp-time series below equals ``dt``.

    The series is a polynomial in 1/v with v = (pi G M f / c**3)**(1/3),
    using the ``eta``-dependent coefficients a2, a3, a4 defined in the body.

    Parameters
    ----------
    M : float
        Mass in solar masses (converted to kg with MSUN_SI).
    eta : float
        Dimensionless mass-ratio parameter entering the coefficients.
    dt : float
        Target time in seconds.
    fmax : float
        Upper limit for the search bracket.

    Returns
    -------
    float
        Root found by ``scipy.optimize.brentq`` (absolute tolerance 1e-6).

    Notes
    -----
    The bracket is first taken as two decades either side of
    f_guess_newtonian, limited to [1e-6, fmax]. If that interval shows no
    sign change, [1e-6, fmax] is used instead. brentq raises ValueError if
    that bracket has no sign change either.
    """
    mass_si = M * MSUN_SI
    geom = G_SI / C_SI**3

    v_scale = geom * mass_si * np.pi
    prefactor = (5.0 / (256.0 * eta)) * geom * mass_si
    coeff2 = (743.0/252.0) + (11.0*eta/3.0)
    coeff3 = (32.0*np.pi/5.0)
    coeff4 = (3058673.0/508032.0) + (5429.0*eta/504.0) + (617.0/72.0)*eta**2

    def residual(freq):
        # Chirp-time series minus the target dt; zero at the solution.
        v = (v_scale * freq)**(1.0/3.0)
        v_m2 = v**(-2.0)
        v_m4 = v_m2*v_m2
        v_m5 = v_m4/v
        v_m6 = v_m4*v_m2
        v_m8 = v_m4*v_m4
        series = v_m8 + coeff2*v_m6 - coeff3*v_m5 + coeff4*v_m4
        return prefactor*series - dt

    guess = f_guess_newtonian(mass_si, eta, dt)
    bracket = (max(1e-6, guess/100.0), min(fmax, guess*100.0))

    # No sign change in the narrow bracket: widen to the full allowed range.
    if residual(bracket[0]) * residual(bracket[1]) > 0:
        bracket = (1e-6, fmax)

    return brentq(residual, bracket[0], bracket[1], xtol=1e-6)

def filter_unused_kwargs(kwargs):
    """
    Return a copy of `kwargs` without the bookkeeping keys listed below.

    The dropped keys are ones produced by conversion functions or internal
    logic that should not be forwarded to the underlying waveform model.
    Spin parameters (a_1, a_2, tilt_1, tilt_2, phi_12, phi_jl) are
    deliberately NOT dropped. The input dict is not modified.
    """
    dropped = frozenset((
        'z', 'm1_src', 'm2_src', 'm1_det', 'm2_det',
        'mass_diff', '_fd_source', 'reweight',
        'source_mass_reweight', 'uniform_all_reweight',
        'waveform_approximant', 'minimum_frequency',
        'reference_frequency', 'maximum_frequency',
    ))
    kept = {}
    for key, value in kwargs.items():
        if key in dropped:
            continue
        kept[key] = value
    return kept

def frequency_taper(f, fmin, df_taper):
    """
    Low-frequency half-cosine taper.

    Parameters
    ----------
    f : array-like
        Frequencies (Hz).
    fmin : float
        Frequencies strictly below this get weight 0.
    df_taper : float
        Width of the raised-cosine ramp from 0 at `fmin` to 1 at
        `fmin + df_taper`. A non-positive width gives a hard step at `fmin`.

    Returns
    -------
    numpy.ndarray
        Floating-point weights in [0, 1] with the same shape as `f`.
    """
    f = np.asarray(f)
    if f.dtype.kind != "f":
        # Integer (or other) input would otherwise yield an integer weight
        # array and silently truncate the ramp values to 0 or 1.
        f = f.astype(float)

    W = np.ones_like(f)
    W[f < fmin] = 0.0

    # Skip the ramp for a zero/negative width: it would be empty, or 0/0 = NaN
    # at f == fmin when df_taper == 0.
    if df_taper > 0:
        ramp = (f >= fmin) & (f <= fmin + df_taper)
        W[ramp] = 0.5 * (1 - np.cos(np.pi * (f[ramp] - fmin) / df_taper))
    return W

def highpass_filter(
    x,
    fs,
    fcut,
    order=4,
):
    """
    Apply a Butterworth high-pass forwards and backwards (zero phase).

    Parameters
    ----------
    x : array
        Time series to filter.
    fs : float
        Sampling frequency (Hz).
    fcut : float
        Cutoff frequency (Hz); normalised by the Nyquist frequency fs/2
        before being handed to the filter design.
    order : int
        Order of the Butterworth design.

    Returns
    -------
    y : array
        Filtered series from scipy.signal.filtfilt.
    """
    nyquist = 0.5 * fs
    numer, denom = butter(order, fcut / nyquist, btype="highpass")
    return filtfilt(numer, denom, x)

def spins_from_lal_params(a1, tilt1, a2, tilt2, phi_12, degrees=False):
    """
    Build Cartesian spin vectors from magnitudes and angles.

    Spin 1 is placed in the x-z plane (y component 0); spin 2 is rotated
    about z by `phi_12` relative to spin 1.

    Parameters
    ----------
    a1, a2 : float
        Spin magnitudes.
    tilt1, tilt2 : float
        Polar angles of each spin measured from the z axis.
    phi_12 : float
        Azimuthal angle of spin 2 relative to spin 1.
    degrees : bool
        If True the three angles are given in degrees and are converted to
        radians first. (Previously this flag was accepted but ignored.)

    Returns
    -------
    chi1, chi2 : numpy.ndarray
        Length-3 arrays (x, y, z).
    """
    if degrees:
        tilt1, tilt2, phi_12 = np.deg2rad([tilt1, tilt2, phi_12])

    chi1 = np.array([
        a1 * np.sin(tilt1),
        0.0,
        a1 * np.cos(tilt1),
    ])

    chi2 = np.array([
        a2 * np.sin(tilt2) * np.cos(phi_12),
        a2 * np.sin(tilt2) * np.sin(phi_12),
        a2 * np.cos(tilt2),
    ])

    return chi1, chi2

################################################################################
# Cosmology helpers
################################################################################

# Global interpolator state.
# All three are None until init_redshift_interpolator() fills them in;
# compute_redshift_from_H0_dL() reads them and raises if they are still None.
_DLH0_XGRID = None   # x = dL(z; H0_ref) * H0_ref  (units: km/s); ascending, unique
_DLH0_ZGRID = None   # z grid corresponding to x grid
_DLH0_XMAX  = None   # max x covered (last element of the x grid)


def init_redshift_interpolator(zmax=20.0, ngrid=4000):
    """
    Tabulate the inverse map z(x), x = dL * H0, for the module-level `cosmo`.

    Fills the globals _DLH0_XGRID, _DLH0_ZGRID and _DLH0_XMAX. Intended to be
    called once before compute_redshift_from_H0_dL().

    Parameters
    ----------
    zmax : float
        Largest redshift in the table.
    ngrid : int
        Number of nodes requested before de-duplication.
    """
    global _DLH0_XGRID, _DLH0_ZGRID, _DLH0_XMAX

    hubble_ref = cosmo.H0.to_value(u.km / u.s / u.Mpc)

    # Log-spaced values minus 1, clipped into [0, zmax], plus an explicit 0.
    # NOTE(review): the logspace starts at 1e-8, so every node below 1 maps to
    # a negative value, is clipped to 0 and then merged by np.unique. Only the
    # nodes with exponent >= 0 survive, so the effective grid is much smaller
    # than `ngrid` -- confirm this is the intended spacing.
    shifted = np.logspace(-8, np.log10(zmax + 1.0), ngrid - 1) - 1.0
    z_nodes = np.unique(np.concatenate(([0.0], np.clip(shifted, 0.0, zmax))))

    # Vectorised dL(z) in Mpc; x = dL * H0 has units Mpc * km/s/Mpc = km/s.
    dist_mpc = cosmo.luminosity_distance(z_nodes).to_value(u.Mpc)
    x_nodes = dist_mpc * hubble_ref

    # np.interp needs a non-decreasing abscissa without repeats.
    x_nodes = np.maximum.accumulate(x_nodes)
    x_nodes, keep = np.unique(x_nodes, return_index=True)

    _DLH0_XGRID = x_nodes
    _DLH0_ZGRID = z_nodes[keep]
    _DLH0_XMAX  = float(x_nodes[-1])


def compute_redshift_from_H0_dL(H0, dL_Mpc):
    """
    Fast redshift estimate using the precomputed interpolator.

    Requires init_redshift_interpolator() to have been called once.

    Parameters
    ----------
    H0 : float or astropy Quantity
        Hubble constant; plain numbers are taken to be in km/s/Mpc.
    dL_Mpc : float or astropy Quantity
        Luminosity distance; plain numbers are taken to be in Mpc.

    Returns
    -------
    float
        Redshift linearly interpolated from the table at x = dL * H0.

    Raises
    ------
    RuntimeError
        If the interpolator has not been initialised.
    ValueError
        If H0 <= 0, dL < 0, or dL * H0 exceeds the tabulated range.
    """
    if _DLH0_XGRID is None:
        # The message names this function (it previously referred to a
        # non-existent z_from_H0_dL()).
        raise RuntimeError(
            "Call init_redshift_interpolator(zmax, ngrid) once before using "
            "compute_redshift_from_H0_dL()."
        )

    H0_val = H0.to_value(u.km / u.s / u.Mpc) if isinstance(H0, u.Quantity) else float(H0)
    dL_val = dL_Mpc.to_value(u.Mpc) if isinstance(dL_Mpc, u.Quantity) else float(dL_Mpc)

    if H0_val <= 0 or dL_val < 0:
        raise ValueError("Require H0 > 0 and dL_Mpc >= 0.")

    x = dL_val * H0_val  # km/s

    if x > _DLH0_XMAX:
        raise ValueError(
            f"dL*H0={x:.3g} km/s is outside the interpolator range (max {_DLH0_XMAX:.3g} km/s). "
            "Increase zmax in init_redshift_interpolator()."
        )

    return float(np.interp(x, _DLH0_XGRID, _DLH0_ZGRID))

################################################################################
# FFT helper (as in newcos4-2.py)
################################################################################

def _nfft_onesided(x_td, fs):
    """
    Run nfft on a time series and return (spectrum, frequencies) as flat arrays.

    Handles both return conventions of `nfft`: a (spectrum, frequencies) tuple,
    or a bare spectrum, in which case the frequencies are generated as an
    evenly spaced grid from 0 to fs/2 with one entry per spectrum bin.
    """
    result = nfft(x_td, sampling_frequency=fs)

    if not isinstance(result, tuple):
        spectrum = np.asarray(result, np.complex128).ravel()
        freqs = np.linspace(0.0, fs / 2.0, spectrum.size, dtype=float)
        return spectrum, freqs

    spectrum, freqs = result
    spectrum = np.asarray(spectrum, np.complex128).ravel()
    freqs = np.asarray(freqs, float).ravel()
    return spectrum, freqs


################################################################################
# Waveform builders (osc only, or osc+memory) – memory implementation from newcos4-2.py
################################################################################

def make_bbh_no_memory_fd(
    *,
    sampling_frequency,
    duration,
    minimum_frequency_default,
    df_taper=5.0,
    waveform_arguments=None,
    base_fd_source=None,
):
    """
    Build an oscillatory-only (no memory) frequency-domain source function.

    The returned function has an explicit signature (so bilby passes params).
    ra/dec/psi/geocent_time are accepted but NOT forwarded to the base source.

    Parameters
    ----------
    sampling_frequency : float
        Not used by this factory; kept so the call signature matches
        make_bbh_with_gwmemory_fd.
    duration : float
        Segment duration in seconds; 0.75 * duration is passed to get_fmin to
        derive the low-frequency cut-off.
    minimum_frequency_default : float
        Lower bound on the cut-off frequency.
    df_taper : float
        Width (Hz) of the low-frequency taper.
    waveform_arguments : dict or None
        Extra keyword arguments forwarded to the base source on every call.
    base_fd_source : callable or None
        Source model; defaults to bilby.gw.source.lal_binary_black_hole.

    Returns
    -------
    callable
        fd_source_nomem(frequency_array, ...) -> {"plus": ..., "cross": ...}
    """
    if base_fd_source is None:
        base_fd_source = bilby.gw.source.lal_binary_black_hole

    wf_args = dict(waveform_arguments or {})

    def fd_source_nomem(
        frequency_array,
        mass_1=None, mass_2=None,
        chirp_mass=None, q=None,
        H0=None,
        luminosity_distance=None, theta_jn=None,
        a_1=0.0, a_2=0.0,
        tilt_1=0.0, tilt_2=0.0,
        phi_12=0.0, phi_jl=0.0,
        phase=0.0,
        # extrinsics accepted (bilby/detector response needs them), but not used by LAL source
        psi=0.0, ra=0.0, dec=0.0, geocent_time=0.0,
        **kwargs
    ):
        # Masses are always derived from (chirp_mass, q); mass_1/mass_2 are
        # accepted only so they are not forwarded through **kwargs.
        if chirp_mass is None or q is None:
            raise TypeError(
                "fd_source_nomem requires both chirp_mass and q "
                f"(got chirp_mass={chirp_mass!r}, q={q!r})."
            )
        q = float(q)      # assuming q = m2/m1 <= 1
        mc = float(chirp_mass)
        m1_det = mc * q**(-3.0/5.0) * (1.0 + q)**(1.0/5.0)
        m2_det = q * m1_det

        p_lal = dict(
            mass_1=m1_det,
            mass_2=m2_det,
            luminosity_distance=float(luminosity_distance),
            a_1=float(a_1), a_2=float(a_2),
            tilt_1=float(tilt_1), tilt_2=float(tilt_2),
            phi_12=float(phi_12), phi_jl=float(phi_jl),
            theta_jn=float(theta_jn),
            phase=float(phase),
            **wf_args,
            **filter_unused_kwargs(kwargs),  # drop bookkeeping keys
        )
        osc = base_fd_source(frequency_array, **p_lal)

        # Low-frequency cut-off from the total detector-frame mass and the
        # observation duration, never below the configured default.
        M_det = m1_det + m2_det               # total detector-frame mass
        eta = m1_det * m2_det / M_det**2      # symmetric mass ratio
        minimum_frequency = max(get_fmin(M_det, eta, 0.75*duration), minimum_frequency_default)

        # Taper the spectrum to reduce the edge effect at the cut-off.
        wf = frequency_taper(frequency_array, fmin=minimum_frequency, df_taper=df_taper)
        osc_plus = np.asarray(osc["plus"]*wf, np.complex128).ravel()
        osc_cross = np.asarray(osc["cross"]*wf, np.complex128).ravel()

        return {"plus": osc_plus, "cross": osc_cross}

    return fd_source_nomem


def make_bbh_with_gwmemory_fd(
    *,
    gwmemory_model,
    l_max,
    sampling_frequency,
    duration,
    minimum_frequency_default,
    df_taper=5.0,
    waveform_arguments=None,
    base_fd_source=None,
):
    """
    Build a frequency-domain source function returning oscillatory + memory
    polarisations.

    Steps performed by the returned function on each call:
    - derive detector-frame masses from (chirp_mass, q) and evaluate the base
      FD source, then taper it at the low-frequency cut-off
    - compute z from (H0, dL) with compute_redshift_from_H0_dL
    - evaluate gwmemory.time_domain_memory with the source-frame total mass,
      mass ratio and z-aligned spins [0, 0, a]
    - resample the memory onto a detector time axis spanning [-T/2, T/2) using
      tau_src = tau_det/(1+z), holding the edge values outside the model range
    - apply the (1+z)^-2 scaling and a Tukey window
    - nfft to FD, multiply by a phase ramp exp(-2*pi*i*f*T/2), and add the
      result to the oscillatory polarisations

    NOTE(review): the memory spectrum is added directly to the oscillatory one
    and multiplied by a ramp built from `frequency_array`; there is no masking
    below fmin or interpolation of the memory term in this code, so
    `frequency_array` presumably has to coincide with the nfft grid for
    (sampling_frequency, duration) -- confirm.

    Parameters
    ----------
    gwmemory_model : passed as `model` to gwmemory.time_domain_memory
    l_max : int-like, passed as `l_max` to gwmemory.time_domain_memory
    sampling_frequency : float, sampling rate (Hz) of the TD memory segment
    duration : float, segment duration in seconds
    minimum_frequency_default : float, lower bound on the cut-off frequency;
        also sets the Tukey window shape
    df_taper : float, width (Hz) of the low-frequency taper on the osc part
    waveform_arguments : dict or None, forwarded to the base source
    base_fd_source : callable or None, defaults to
        bilby.gw.source.lal_binary_black_hole
    """
    if base_fd_source is None:
        base_fd_source = bilby.gw.source.lal_binary_black_hole

    fs = float(sampling_frequency)
    wf_args = dict(waveform_arguments or {})
    T = float(duration)
    df = 1.0/T   # not used below
    N = int(T*fs)   # number of TD samples in the memory segment

    # time vector at the detector with t=0 at the true merger time
    #t_det = np.arange(N)/T + start_time - float(geocent_time)

    def fd_source_with_memory(
        frequency_array,
        mass_1=None, mass_2=None,
        chirp_mass=None, q=None,
        H0=None,
        luminosity_distance=None, theta_jn=None,
        a_1=0.0, a_2=0.0,
        tilt_1=0.0, tilt_2=0.0,
        phi_12=0.0, phi_jl=0.0,
        phase=0.0,
        # extrinsics (used by detector response; needed in the param dict)
        psi=0.0, ra=0.0, dec=0.0, geocent_time=0.0,
        **kwargs
    ):
        # Masses are always derived from (chirp_mass, q); the mass_1/mass_2
        # arguments are accepted but ignored (that branch is disabled).
        q  = float(q)      # assuming q = m2/m1 <= 1
        mc = float(chirp_mass)
        m1_det = mc * q**(-3.0/5.0) * (1.0 + q)**(1.0/5.0)
        m2_det = q * m1_det

        p_lal = dict(
            mass_1=m1_det,
            mass_2=m2_det,
            luminosity_distance=float(luminosity_distance),
            a_1=float(a_1), a_2=float(a_2),
            tilt_1=float(tilt_1), tilt_2=float(tilt_2),
            phi_12=float(phi_12), phi_jl=float(phi_jl),
            theta_jn=float(theta_jn),
            phase=float(phase),
            **wf_args,
            **filter_unused_kwargs(kwargs),  # drop bookkeeping keys
        )
        osc = base_fd_source(frequency_array, **p_lal)

        # compute the low frequency cut-off based on the total detector frame mass and observation duration
        M_det = m1_det + m2_det   # the total detector frame mass
        eta = m1_det*m2_det/M_det**2     # the symmetric mass ratio
        minimum_frequency = max(get_fmin(M_det,eta,0.75*duration),minimum_frequency_default)

        # window the osc frequency series to reduce the edge effect at the low frequency cutoff
        wf = frequency_taper(frequency_array, fmin=minimum_frequency, df_taper=df_taper)
        osc_plus = np.asarray(osc["plus"]*wf, np.complex128).ravel()
        osc_cross = np.asarray(osc["cross"]*wf, np.complex128).ravel()

        # ----- z and source-frame masses -----
        z = compute_redshift_from_H0_dL(float(H0), float(luminosity_distance))
        m1_src = m1_det / (1.0 + z)
        m2_src = m2_det / (1.0 + z)

        # ----- memory TD part (source-frame masses) -----
        # mass ratio handed to gwmemory is always >= 1 (heavier / lighter)
        mem_q = (m1_det / m2_det) if m1_det >= m2_det else (m2_det / m1_det)
        M_src = m1_src + m2_src

        # convert spins - we use IMRPhenomD so only aligned spins
        # (the full magnitudes a_1/a_2 are placed on z; the tilts are not used here)
        S1 = [0.0,0.0,a_1]
        S2 = [0.0,0.0,a_2]

        # this returns a limited timeseries that may not be uniformly spaced
        # the t_mem vector = 0 at merger and is effectively a source frame time
        # we use the time domain version because we need to apply a redshift correction in time
        h_mem_td, t_mem = gwmemory.gwmemory.time_domain_memory(
            model=gwmemory_model,
            q=mem_q,
            total_mass=M_src,
            spin_1=S1,
            spin_2=S2,
            distance=float(luminosity_distance),
            inc=float(theta_jn),
            phase=float(phase),
            l_max=int(l_max),
        )

        # time vector at the detector spanning -T/2 to T/2
        tau_det = np.arange(N)/fs - float(duration)/2.0
        tau_src = tau_det / (1.0 + float(z))   # the equivalent time at the source (centred on merger)

        # NOTE(review): an earlier comment here called t_mem "the time at the
        # detector", but it is used below as the abscissa for tau_src, i.e. as
        # a source-frame time, matching the note above -- confirm.
        t_mem = np.asarray(t_mem, float)
        hp0 = np.asarray(h_mem_td["plus"], float)
        hc0 = np.asarray(h_mem_td["cross"], float)

        # edge-hold interpolation
        # this is for the redshift effect that time stretches the memory signal
        # this gives the full signal that will now have an artificial step at the start/end (t=-T/2 and T/2)
        hplus_td = np.interp(tau_src, t_mem, hp0, left=hp0[0], right=hp0[-1])
        hcross_td = np.interp(tau_src, t_mem, hc0, left=hc0[0], right=hc0[-1])

        # explicit (1+z)^-2 scaling
        # and we apply a Tukey window to avoid the discontinuity at the start and end
        # this is the NEW science
        # alpha is the tapered fraction of the window, set from the default
        # cut-off frequency rather than the per-call minimum_frequency
        alpha = 2.0/(duration*minimum_frequency_default)
        wt = tukey(N,alpha)   # corresponds to a safe level for a specific low frequency cut-off
        s = (1.0 + z) ** -2
        hplus_td *= s*wt
        hcross_td *= s*wt

        # TD -> FD (the frequency grid returned by the helper is discarded)
        Hp, _ = _nfft_onesided(hplus_td, fs)
        Hc, _ = _nfft_onesided(hcross_td, fs)

        # time shift to move the merger time to the start of the timeseries - this has to happen
        # (phase ramp exp(-2*pi*i*f*T/2) evaluated on frequency_array)
        ramp = np.exp(-2j*np.pi*np.asarray(frequency_array, float)*0.5*T)
        Hp *= ramp
        Hc *= ramp

        return {"plus": osc_plus + Hp, "cross": osc_cross + Hc}

    return fd_source_with_memory


################################################################################
# Reweighting
################################################################################

def _safe_prob(prior, x):
    """
    Evaluate a prior's probability density safely.

    Parameters
    ----------
    prior : object with a ``prob`` method, or None
        None is treated as a constant prior (density 1 everywhere).
    x : array-like
        Points at which to evaluate the density.

    Returns
    -------
    numpy.ndarray
        Non-negative, finite densities; invalid/non-finite values become 0.
        The array is always freshly allocated, so whatever ``prior.prob``
        returned (possibly a cached or read-only array) is never modified.
    """
    if prior is None:
        return np.ones_like(x, dtype=float)

    x = np.asarray(x, dtype=float)

    try:
        p = prior.prob(x)
    except Exception:
        # Deliberate best-effort: some priors may not like vector inputs, so
        # retry one value at a time. ravel/reshape keeps 0-d input working.
        p = np.array(
            [prior.prob(float(v)) for v in x.ravel()], dtype=float
        ).reshape(x.shape)

    p = np.asarray(p, dtype=float)
    p = np.where(np.isfinite(p), p, 0.0)
    return np.clip(p, 0.0, None)

#def _safe_prob(prior, x):
#    try:
#        p = prior.prob(x)
#        return np.asarray(p, dtype=float)
#    except Exception:
#        return np.full_like(np.asarray(x, dtype=float), np.nan, dtype=float)

def reweight_posterior_samples(
    posterior,
    original_prior=None,
    new_prior=None,
    existing_weight_column="weights",
    new_weight_column="weights_reweighted",
    jacobian=None,
    require_all_new_prior_keys=True,
    normalize=True,
    inplace=False,
):
    """
    Importance reweight posterior samples from `original_prior` to `new_prior`.

    - If `new_prior` is None, the target prior is treated as constant
      (improper uniform), so weights are proportional to 1/pi_old over the
      keys of `original_prior` (times jacobian / existing weights).
    - If `original_prior` is None, pi_old is treated as constant.
    - When `new_prior` is given, only its keys contribute a ratio
      pi_new/pi_old; a key absent from `original_prior` uses pi_old = 1.

    Parameters
    ----------
    posterior : pandas.DataFrame-like
        Posterior samples table.
    original_prior : bilby.core.prior.PriorDict or dict-like or None
        The prior used to generate the posterior ("old" prior).
    new_prior : bilby.core.prior.PriorDict or dict-like or None
        The target prior ("new" prior).
    existing_weight_column : str or None
        If this column exists in `posterior`, multiply the importance weights
        by it. Set to None to ignore any existing weights.
    new_weight_column : str
        Column name to store the (optionally normalized) reweighted weights.
    jacobian : callable or None
        Optional multiplicative factor, e.g. for reparameterisations.
        Signature: jacobian(posterior) -> array-like of shape (N,). The
        returned array is not modified.
    require_all_new_prior_keys : bool
        Only applies when `new_prior` is provided. If True, raises if any
        parameter in new_prior is missing from posterior; if False, missing
        parameters are skipped.
    normalize : bool
        Normalize weights to sum to 1. If every weight is zero the result
        falls back to uniform weights 1/N.
    inplace : bool
        If True, modify posterior in place; else return a deep copy.

    Returns
    -------
    posterior_out, weights
        posterior_out: DataFrame-like with `new_weight_column` added.
        weights: numpy.ndarray shape (N,)

    Raises
    ------
    KeyError
        If required new_prior keys are missing from the posterior.
    ValueError
        If the jacobian does not return shape (N,).
    """

    def _store(weights):
        # Attach the weights to the caller's table or to a deep copy of it.
        target = posterior if inplace else copy.deepcopy(posterior)
        target[new_weight_column] = weights
        return target, weights

    def _sanitise(values):
        # Fresh array with non-finite and negative entries replaced by 0.
        values = np.asarray(values, dtype=float)
        return np.clip(np.where(np.isfinite(values), values, 0.0), 0.0, None)

    n = len(posterior)
    if n == 0:
        return _store(np.array([], dtype=float))

    # Dict-like access
    old = dict(original_prior) if original_prior is not None else {}
    new = dict(new_prior) if new_prior is not None else None  # None => constant

    w = np.ones(n, dtype=float)

    if new is not None:
        missing = [k for k in new if k not in posterior.columns]
        if missing and require_all_new_prior_keys:
            raise KeyError(
                f"Posterior is missing columns needed for new_prior: {missing}. "
                f"Either add these columns or set require_all_new_prior_keys=False."
            )

        for k, newp in new.items():
            if k not in posterior.columns:
                continue

            x = posterior[k].to_numpy(dtype=float)
            p_new = _safe_prob(newp, x)
            p_old = _safe_prob(old.get(k, None), x)

            # Samples with zero old-prior density get zero weight.
            ratio = np.zeros_like(p_new, dtype=float)
            m = p_old > 0
            ratio[m] = p_new[m] / p_old[m]
            w *= ratio
    else:
        # new_prior is constant (improper uniform): w is proportional to
        # 1/pi_old over the keys of original_prior. With no original prior
        # `old` is empty and the weights stay at 1.
        for k, oldp in old.items():
            if k not in posterior.columns:
                continue
            p_old = _safe_prob(oldp, posterior[k].to_numpy(dtype=float))

            inv = np.zeros_like(p_old, dtype=float)
            m = p_old > 0
            inv[m] = 1.0 / p_old[m]
            w *= inv

    # Optional jacobian factor
    if jacobian is not None:
        j = np.asarray(jacobian(posterior), dtype=float)
        if j.shape != (n,):
            raise ValueError(f"jacobian(posterior) must return shape ({n},), got {j.shape}")
        # Sanitise into a new array: the callable may hand back a read-only
        # array or a view of a posterior column, which must not be written to.
        w *= _sanitise(j)

    # Multiply any existing weights if requested and present
    if existing_weight_column is not None and existing_weight_column in posterior.columns:
        w *= posterior[existing_weight_column].to_numpy(dtype=float)

    # Clean + normalize
    w = _sanitise(w)

    if normalize:
        s = np.sum(w)
        if s <= 0:
            w = np.ones(n, dtype=float) / n
        else:
            w = w / s

    return _store(w)

def standardise_mass_columns(
    posterior,
    z_col_candidates=("z", "redshift"),
    prefer_mass12=("mass_1", "mass_2"),
    inplace=False,
    add_q=False,
    q_col="q",
    q_definition="m2_over_m1_leq1",  # default: q in (0,1]
):
    """
    Give a posterior table a uniform set of redshift and mass columns.

    After the call the table has:
      - z
      - m1_det, m2_det (detector frame)
      - m1_src, m2_src (source frame)
    plus, optionally, a mass-ratio column.

    Parameters
    ----------
    posterior : pandas.DataFrame-like
    z_col_candidates : tuple[str]
        Redshift column names to try, in order; the first one present is used.
    prefer_mass12 : tuple[str,str]
        Names of the bilby-style detector-frame mass columns.
    inplace : bool
        Work on the given object when True, otherwise on a deep copy.
    add_q : bool
        Add `q_col` (computed from the source-frame masses) if it is absent.
    q_col : str
        Name of the mass-ratio column.
    q_definition : str
        - "m2_over_m1_leq1": q = min(m1,m2)/max(m1,m2) in (0,1]
        - "m1_over_m2_geq1": q = max(m1,m2)/min(m1,m2) in [1, inf)

    Returns
    -------
    post, info : (DataFrame-like, dict)

    Raises
    ------
    KeyError
        If no redshift column is found, or the masses cannot be constructed.
    ValueError
        If `q_definition` is not recognised (only checked when q is computed).
    """
    post = copy.deepcopy(posterior) if not inplace else posterior
    name_m1 = prefer_mass12[0]
    name_m2 = prefer_mass12[1]

    def present(*names):
        return all(name in post.columns for name in names)

    def column(name):
        return post[name].to_numpy(dtype=float)

    # ---- redshift: first candidate found is copied to "z" as float ----
    z_col = next((c for c in z_col_candidates if c in post.columns), None)
    if z_col is None:
        raise KeyError(f"Need redshift samples. None of {z_col_candidates} found in posterior.")
    post["z"] = column(z_col)
    onepz = 1.0 + column("z")

    # ---- detector frame: keep m1_det/m2_det if both exist, else copy the aliases ----
    if present(name_m1, name_m2) and not present("m1_det", "m2_det"):
        post["m1_det"] = column(name_m1)
        post["m2_det"] = column(name_m2)

    # Fill in whichever alias column is still absent.
    for det_name, alias in (("m1_det", name_m1), ("m2_det", name_m2)):
        if det_name in post.columns and alias not in post.columns:
            post[alias] = column(det_name)

    # ---- source frame <-> detector frame via (1 + z) ----
    have_msrc = present("m1_src", "m2_src")
    have_mdet = present("m1_det", "m2_det")

    if have_mdet and not have_msrc:
        post["m1_src"] = column("m1_det") / onepz
        post["m2_src"] = column("m2_det") / onepz
        have_msrc = True

    if have_msrc and not have_mdet:
        post["m1_det"] = column("m1_src") * onepz
        post["m2_det"] = column("m2_src") * onepz
        have_mdet = True
        for det_name, alias in (("m1_det", name_m1), ("m2_det", name_m2)):
            if alias not in post.columns:
                post[alias] = column(det_name)

    if not have_msrc:
        raise KeyError("Could not construct m1_src/m2_src. Need either (m1_src,m2_src) or (m1_det,m2_det)+z.")
    if not have_mdet:
        raise KeyError("Could not construct m1_det/m2_det. Need either (m1_det,m2_det) or (m1_src,m2_src)+z.")

    # ---- optional mass ratio (the same in either frame) ----
    if add_q and (q_col not in post.columns):
        m1 = np.asarray(column("m1_src"), float)
        m2 = np.asarray(column("m2_src"), float)
        heavier = np.maximum(m1, m2)
        lighter = np.minimum(m1, m2)

        # Rows with non-positive or non-finite masses keep q = 0.
        valid = (heavier > 0) & np.isfinite(heavier) & np.isfinite(lighter)
        q = np.zeros_like(heavier, dtype=float)

        if q_definition == "m2_over_m1_leq1":
            q[valid] = lighter[valid] / heavier[valid]
        elif q_definition == "m1_over_m2_geq1":
            # floor the denominator so a zero lighter mass cannot divide by zero
            q[valid] = heavier[valid] / np.maximum(lighter[valid], 1e-300)
        else:
            raise ValueError(f"Unknown q_definition='{q_definition}'")

        q[~np.isfinite(q)] = 0.0
        q = np.clip(q, 0.0, None)
        post[q_col] = q

    info = {
        "z_col_used": z_col,
        "has_m1_det": "m1_det" in post.columns,
        "has_m2_det": "m2_det" in post.columns,
        "has_m1_src": "m1_src" in post.columns,
        "has_m2_src": "m2_src" in post.columns,
        "has_mass_1": name_m1 in post.columns,
        "has_mass_2": name_m2 in post.columns,
        "added_q": bool(add_q),
        "q_col": q_col if add_q else None,
        "q_definition": q_definition if add_q else None,
    }
    return post, info

def jacobian_det_to_src(posterior):
    """Return 1/(1+z)^2 per sample, using the posterior's "z" column."""
    redshift = posterior["z"].to_numpy(dtype=float)
    scale = (1.0 + redshift) ** 2
    return 1.0 / scale

def jacobian_det_to_m1q_src(posterior):
    """
    Return 1/((1+z)^2 * m1_src) per sample.

    Reads the "z" and "m1_src" columns. Entries that come out non-finite
    (for example when m1_src is 0) are returned as 0.
    """
    redshift = posterior["z"].to_numpy(dtype=float)
    primary_src = posterior["m1_src"].to_numpy(dtype=float)
    factor = 1.0 / ((1.0 + redshift) ** 2 * primary_src)
    return np.where(np.isfinite(factor), factor, 0.0)

################################################################################
# SNR selection
################################################################################

def apply_snr_selection_bias(posterior, likelihood, snr_threshold, weight_cols, npool=1):
    """
    Keep only samples whose network optimal SNR exceeds a threshold.

    Works on a deep copy of `posterior`. SNR columns are added with
    bilby.gw.conversion.compute_snrs; the network SNR is taken from
    "network_optimal_snr", else "optimal_snr", else the quadrature sum of
    every column ending in "_optimal_snr".

    Parameters
    ----------
    posterior : pandas.DataFrame-like
    likelihood : passed through to compute_snrs
    snr_threshold : float-like
        Samples are kept when network SNR > threshold (strict).
    weight_cols : iterable of str
        Weight columns to renormalise (sum to 1) after the cut; columns not
        present are skipped. A column whose clipped weights sum to <= 0 is
        replaced by uniform weights.
    npool : int
        Passed through to compute_snrs.

    Returns
    -------
    DataFrame-like with the surviving rows and a reset index.

    Raises
    ------
    RuntimeError
        If no optimal-SNR column is available after compute_snrs.
    """
    samples = copy.deepcopy(posterior)
    bilby.gw.conversion.compute_snrs(samples, likelihood, npool=npool)

    if "network_optimal_snr" in samples.columns:
        network_snr = np.abs(samples["network_optimal_snr"].to_numpy())
    elif "optimal_snr" in samples.columns:
        network_snr = np.abs(samples["optimal_snr"].to_numpy())
    else:
        per_detector = [c for c in samples.columns if c.endswith("_optimal_snr")]
        if len(per_detector) == 0:
            raise RuntimeError("No optimal SNR columns after compute_snrs.")
        # quadrature sum over detectors
        snr_sq = np.zeros(len(samples), dtype=float)
        for name in per_detector:
            snr_sq += np.abs(samples[name].to_numpy()) ** 2
        network_snr = np.sqrt(snr_sq)

    keep = network_snr > float(snr_threshold)
    selected = samples.loc[keep].reset_index(drop=True)
    print(f"SNR selection: {len(samples)} -> {len(selected)} (thr={snr_threshold})")

    # Renormalise each requested weight column over the surviving samples.
    for name in weight_cols:
        if name not in selected.columns:
            continue
        weights = np.clip(selected[name].to_numpy(dtype=float), 0.0, None)
        total = np.sum(weights)
        if total <= 0:
            selected[name] = np.ones_like(weights) / len(weights)
        else:
            selected[name] = weights / total

    return selected


################################################################################
# Plotting Utility (Time Domain)
################################################################################

def plot_time_domain_data(
    outdir,
    label,
    detectors,
    parameters,
    sampling_frequency,
    duration,
    start_time,
    minimum_frequency,
    fd_source_full,
    fd_source_nomem,
    data_fd_dict,
    psd_fd_dict,
    is_lisa=False,
    lisa_wrapper=None
):
    """
    Plot time-domain data and signal (raw and whitened) for each detector.

    One row per detector/channel is drawn: the left panel shows the raw data,
    the oscillatory signal and the memory component (full minus oscillatory);
    the right panel shows the whitened versions. Each panel gets a zoomed
    inset. The figure is saved as ``<outdir>/<label>_td_data_plot.png``.

    Parameters
    ----------
    outdir, label : str
        Output directory and file-name prefix of the saved figure.
    detectors : sequence
        LISA mode: channel keys used to index `data_fd_dict` / `psd_fd_dict`
        and the "A"/"E"/"T" signal dictionaries.
        Ground mode: interferometer objects (their `.name`,
        `get_detector_response`, `strain_data`, PSD array and
        `time_delay_from_geocenter` are used).
    parameters : dict
        Source parameters; "geocent_time", "ra" and "dec" are read here, the
        rest is forwarded to the source models.
    sampling_frequency, duration : float
        Define the time grid ``arange(round(duration * fs)) / fs``.
    start_time : float
        Only used in LISA mode, as the reference for the phase-ramp time shift.
    minimum_frequency : float
        Not used by this function (the ground branch reads
        `det.minimum_frequency` instead); kept for interface compatibility.
    fd_source_full, fd_source_nomem : callable
        Frequency-domain source models with and without memory.
    data_fd_dict, psd_fd_dict : dict
        Frequency-domain data and PSD per channel; only read in LISA mode.
    is_lisa : bool
        Select the LISA branch instead of the ground-based branch.
    lisa_wrapper : callable
        Called as ``lisa_wrapper(lam, beta, params)`` in LISA mode; its first
        three outputs are used as the A, E and T time series.
    """
    print("Generating Time Domain Plots...")

    # squeeze=False keeps `axes` two-dimensional even for a single detector,
    # so axes[i][0] and axes[-1, 0] below are valid for any n_det >= 1.
    n_det = len(detectors)
    fig, axes = plt.subplots(
        n_det, 2, figsize=(20, 4 * n_det), sharex=True, squeeze=False
    )

    # Time axis relative to the start of the segment (0 .. duration).
    N = int(round(duration * sampling_frequency))
    dt = 1.0 / sampling_frequency
    time = np.arange(N) * dt

    # 1. Generate FULL model (Osc + Mem) and oscillatory-only model
    # -------------------------------------------------------------
    if is_lisa:

        # A/E/T model builder using wrapper + bilby-style params dict
        def get_lisa_td(params, fd_source):
            p = dict(params)

            # Convert equatorial ra/dec -> ecliptic lam/beta for fastlisaresponse
            lam, beta = equatorial_to_ecliptic_lambda_beta(float(p.get("ra", 0.0)), float(p.get("dec", 0.0)))

            # Shift the waveform in time by phase-ramping the FD polarizations
            # by (geocent_time - start_time).
            dt_shift = float(p.get("geocent_time", 0.0)) - float(start_time)

            def _fd_source_shifted(farr, **pp):
                print('_fd_source_shifted: within get_lisa_td')
                pol = fd_source(farr, **pp)
                ramp = np.exp(-2j*np.pi*np.asarray(farr, float)*dt_shift)
                return {"plus": np.asarray(pol["plus"], np.complex128)*ramp,
                    "cross": np.asarray(pol["cross"], np.complex128)*ramp}
            p["_fd_source"] = _fd_source_shifted

            # TD A/E/T - outputs in the time domain
            return lisa_wrapper(lam, beta, p)  # requires ecliptic coords as first and second args

        # Full signal - TIME domain
        aet_full = get_lisa_td(parameters, fd_source_full)  # returns [A_td, E_td, T_td] arrays
        signal_full = {"A": aet_full[0], "E": aet_full[1], "T": aet_full[2]}

        # Whiten the full signal
        signal_full_wt = {}
        for name in signal_full:
            temp, _ = _nfft_onesided(signal_full[name], sampling_frequency)
            signal_full_wt[name] = np.sqrt(2.0*dt)*infft(temp/np.sqrt(psd_fd_dict[name]), sampling_frequency)

        # Oscillatory-only signal and its whitened version
        aet_osc = get_lisa_td(parameters, fd_source_nomem)
        signal_osc = {"A": aet_osc[0], "E": aet_osc[1], "T": aet_osc[2]}
        signal_osc_wt = {}
        for name in signal_osc:
            temp, _ = _nfft_onesided(signal_osc[name], sampling_frequency)
            signal_osc_wt[name] = np.sqrt(2.0*dt)*infft(temp/np.sqrt(psd_fd_dict[name]), sampling_frequency)

        # Data (FD -> TD)
        # The data_fd passed is already (Signal + Noise) in FD.
        data_td = {}
        data_td_wt = {}
        for k in detectors:
            data_td[k] = infft(data_fd_dict[k], sampling_frequency)
            data_td_wt[k] = np.sqrt(2.0*dt)*infft(data_fd_dict[k]/np.sqrt(psd_fd_dict[k]), sampling_frequency)

    else:
        # Ground based: polarizations are projected onto each detector.
        wf_gen = bilby.gw.waveform_generator.WaveformGenerator(
            duration=duration,
            sampling_frequency=sampling_frequency,
            frequency_domain_source_model=fd_source_full,
            parameter_conversion=None,
            waveform_arguments=parameters.get("waveform_arguments", {})
        )

        # Full polarizations
        pols_full = wf_gen.frequency_domain_strain(parameters)

        # Oscillatory polarizations: temporarily swap the source model
        orig_source = wf_gen.frequency_domain_source_model

        def fd_source_osc_only(f, **k):
            k_filt = filter_unused_kwargs(k)
            return fd_source_nomem(f, **k_filt)

        wf_gen.frequency_domain_source_model = fd_source_osc_only
        pols_osc = wf_gen.frequency_domain_strain(parameters)
        wf_gen.frequency_domain_source_model = orig_source  # restore

        signal_full = {}
        signal_osc = {}
        signal_full_wt = {}
        signal_osc_wt = {}
        data_td = {}
        data_td_wt = {}

        for det in detectors:  # detectors is list of Interferometer objects
            # Full signal TD: get_detector_response returns FD strain, so IFFT it
            sig_fd_full = det.get_detector_response(pols_full, parameters)
            signal_full[det.name] = infft(sig_fd_full, sampling_frequency)
            signal_full_wt[det.name] = infft(sig_fd_full/np.sqrt(det.power_spectral_density_array), sampling_frequency)

            # Oscillatory signal TD
            sig_fd_osc = det.get_detector_response(pols_osc, parameters)
            signal_osc[det.name] = infft(sig_fd_osc, sampling_frequency)
            signal_osc_wt[det.name] = infft(sig_fd_osc/np.sqrt(det.power_spectral_density_array), sampling_frequency)

            # Diagnostics on the frequency band actually in use
            f = det.strain_data.frequency_array
            mask = det.frequency_mask          # boolean array, same length as f

            f_used = f[mask]
            print("Effective fmin used:", f_used[0])
            print("Effective fmax used:", f_used[-1])
            print("Number of bins used:", f_used.size, "of", f.size)

            print("ifo.minimum_frequency:", det.minimum_frequency)
            print("ifo.maximum_frequency:", det.maximum_frequency)

            # Sometimes also present depending on version / setup:
            print("strain_data min/max:",
                getattr(det.strain_data, "minimum_frequency", None),
                getattr(det.strain_data, "maximum_frequency", None))

            psd = det.power_spectral_density_array
            bad = ~np.isfinite(psd) | (psd <= 0)

            print("Lowest finite-PSD frequency:", f[~bad][0])
            print("Lowest finite-PSD & mask frequency:", f[mask & ~bad][0])

            # Data TD (data is already in the det object)
            h = det.time_domain_strain
            data_td[det.name] = highpass_filter(h, fs=sampling_frequency, fcut=det.minimum_frequency)
            temp, _ = _nfft_onesided(data_td[det.name], sampling_frequency)
            data_td_wt[det.name] = infft(temp/np.sqrt(det.power_spectral_density_array), sampling_frequency)

            # normalise whitened data to have unit variance in the noise
            data_td_wt[det.name] /= np.sqrt(0.5*sampling_frequency)
            signal_osc_wt[det.name] /= np.sqrt(0.5*sampling_frequency)
            signal_full_wt[det.name] /= np.sqrt(0.5*sampling_frequency)

            fs = det.strain_data.sampling_frequency
            start = det.strain_data.start_time

            # 1) expected index from timing
            t_arrive = float(parameters["geocent_time"]) + det.time_delay_from_geocenter(
                parameters["ra"], parameters["dec"], parameters["geocent_time"]
            )
            idx_expected = int(round((t_arrive - start) * fs))

            # 2) where the time-domain peak actually is (rough proxy for "merger")
            ht = det.strain_data.time_domain_strain
            idx_peak = int(np.argmax(np.abs(ht)))

            print("start_time:", start)
            print("t_arrive:", t_arrive)
            print("idx_expected:", idx_expected, " -> t =", start + idx_expected/fs)
            print("idx_peak:", idx_peak, " -> t =", start + idx_peak/fs)
            print("delta samples:", idx_peak - idx_expected, "delta sec:", (idx_peak-idx_expected)/fs)

    # Plotting loop
    for i, det_key in enumerate(detectors):
        ax = axes[i]

        if is_lisa:
            name = det_key
            s_full = signal_full[name]
            s_osc = signal_osc[name]
            d_data = data_td[name]
            s_full_wt = signal_full_wt[name]
            s_osc_wt = signal_osc_wt[name]
            d_data_wt = data_td_wt[name]
            dt_zoom = 250
            t_merger = (3.0/4.0)*duration
            t_merger_det = t_merger
        else:
            name = det_key.name
            s_full = signal_full[name]
            s_osc = signal_osc[name]
            s_full_wt = signal_full_wt[name]
            s_osc_wt = signal_osc_wt[name]
            d_data = data_td[name]
            d_data_wt = data_td_wt[name]
            # Use THIS panel's detector (det_key) for the arrival time; the
            # loop variable of the preparation loop above would always refer
            # to the last detector.
            t_merger_det = (
                parameters["geocent_time"]
                + det_key.time_delay_from_geocenter(
                    parameters["ra"], parameters["dec"], parameters["geocent_time"]
                )
            ) - det_key.strain_data.start_time
            t_merger = parameters["geocent_time"] - det_key.strain_data.start_time
            dt_zoom = 0.25

        # Memory component = full signal minus oscillatory signal
        s_mem = s_full - s_osc
        s_mem_wt = s_full_wt - s_osc_wt

        # ---- left panel: raw strain ----
        ax[0].plot(time, d_data, color='grey', alpha=0.2, label='Data (Noise+Signal)')
        ax[0].plot(time, s_osc, color='tab:blue', label='Oscillatory', linewidth=1.5)

        # Only add the memory curve to the main panel if it is non-zero
        if np.max(np.abs(s_mem)) > 0:
            ax[0].plot(time, s_mem, color='tab:orange', label='Memory', linewidth=1.5)

        ax[0].set_ylabel(f"{name} Strain")
        if i == 0:
            ax[0].legend(loc='upper right', framealpha=0.9)

        # inset
        axins = inset_axes(ax[0], width="35%", height="70%", loc="center", borderpad=2.)
        axins.plot(time, d_data, color='grey', linewidth=1.5, alpha=0.2)
        axins.plot(time, s_osc, color='tab:blue', linewidth=1.5)
        axins.plot(time, s_mem, color='tab:orange', linewidth=1.5)

        # zoom window, vertically scaled to the memory amplitude
        axins.set_xlim(t_merger - dt_zoom, t_merger + dt_zoom)
        axins.set_ylim(-1.2*np.max(np.abs(s_mem)), 1.2*np.max(np.abs(s_mem)))

        axins.tick_params(labelsize=7)

        mark_inset(ax[0], axins, loc1=2, loc2=4, fc="none", ec="0.5")

        # ---- right panel: whitened strain ----
        ax[1].plot(time, d_data_wt, color='grey', alpha=0.2, label='Data (Noise+Signal)')
        ax[1].plot(time, s_osc_wt, color='tab:blue', label='Oscillatory', linewidth=1.5)

        if np.max(np.abs(s_mem_wt)) > 0:
            ax[1].plot(time, s_mem_wt, color='tab:orange', label='Memory', linewidth=1.5)

        # inset
        axins = inset_axes(ax[1], width="35%", height="70%", loc="center")
        axins.plot(time, d_data_wt, color='grey', linewidth=1.5, alpha=0.2)
        axins.plot(time, s_osc_wt, color='tab:blue', linewidth=1.5)
        axins.plot(time, s_mem_wt, color='tab:orange', linewidth=1.5)
        print(f'****** {np.std(d_data_wt)}')

        # zoom window centred on the arrival time at this detector
        axins.set_xlim(t_merger_det - dt_zoom, t_merger_det + dt_zoom)
        axins.set_ylim(-1.2*np.max(np.abs(s_mem_wt)), 1.2*np.max(np.abs(s_mem_wt)))

        axins.tick_params(labelsize=7)

        mark_inset(ax[1], axins, loc1=2, loc2=4, fc="none", ec="0.5")

        ax[1].set_ylabel(f"{name} Whitened Strain")
        if i == 0:
            ax[1].legend(loc='upper right', framealpha=0.9)

        # Main panels always show the whole segment
        ax[0].set_xlim(0, duration)
        ax[1].set_xlim(0, duration)
        ax[1].set_ylim(-5, 5)

    axes[-1, 0].set_xlabel("Time (s)")
    axes[-1, 1].set_xlabel("Time (s)")
    fig.tight_layout()
    plot_path = os.path.join(outdir, f"{label}_td_data_plot.png")
    fig.savefig(plot_path, dpi=200)
    plt.close(fig)
    print(f"Time domain plot saved to {plot_path}")


################################################################################
# LISA utilities (fastlisaresponse + simple A/E/T likelihood)
################################################################################

# One megaparsec expressed in metres (1e6 times the parsec value used for PC_SI).
MPC_SI = 3.0856775814913673e22  # m

def equatorial_to_ecliptic_lambda_beta(ra, dec):
    """
    Map equatorial (ICRS) sky angles to ecliptic longitude and latitude.

    Inputs and outputs are in radians. The Astropy transformation to
    `BarycentricTrueEcliptic` is tried first; if anything in that path raises,
    a fixed rotation about the x-axis by an obliquity of 23.4392911 degrees
    is used instead, with the longitude wrapped into [0, 2*pi).

    Returns
    -------
    (lam, beta) : tuple of float
    """
    try:
        from astropy.coordinates import SkyCoord, BarycentricTrueEcliptic
        import astropy.units as u
        icrs = SkyCoord(ra=float(ra)*u.rad, dec=float(dec)*u.rad, frame="icrs")
        ecl = icrs.transform_to(BarycentricTrueEcliptic())
        lon_rad = float(ecl.lon.to(u.rad).value)
        lat_rad = float(ecl.lat.to(u.rad).value)
        return lon_rad, lat_rad
    except Exception:
        obliquity = np.deg2rad(23.4392911)
        cos_eps = np.cos(obliquity)
        sin_eps = np.sin(obliquity)

        # Unit vector in the equatorial frame.
        x = np.cos(dec) * np.cos(ra)
        y = np.cos(dec) * np.sin(ra)
        z = np.sin(dec)

        # Rotate about the x-axis into the ecliptic frame.
        y_ecl = y*cos_eps + z*sin_eps
        z_ecl = z*cos_eps - y*sin_eps

        norm = np.sqrt(x*x + y_ecl*y_ecl + z_ecl*z_ecl)
        lon_rad = np.arctan2(y_ecl, x) % (2*np.pi)
        lat_rad = np.arcsin(z_ecl / norm)
        return float(lon_rad), float(lat_rad)

def generate_colored_noise_fd(psd, df, rng):
    """
    Draw one complex frequency-domain noise realisation shaped by `psd`.

    Each bin gets independent Gaussian real and imaginary parts with standard
    deviation sqrt(psd / (4 * df)). `rng` must provide `normal(size=...)`;
    the real parts are drawn first, then the imaginary parts.

    The first bin is set to zero, and when the number of bins is even the
    imaginary part of the last bin is dropped.
    """
    spectrum = np.asarray(psd, float)
    n_bins = spectrum.size
    amplitude = np.sqrt(0.25 * spectrum / df)

    # Order matters for reproducibility: real draw, then imaginary draw.
    real_part = rng.normal(size=n_bins)
    imag_part = rng.normal(size=n_bins)
    noise = amplitude * (real_part + 1j * imag_part)

    # Enforce real-valued time series
    noise[0] = 0.0
    if n_bins % 2 == 0:
        noise[-1] = noise[-1].real + 0.0j
    return noise


class LisaAETFrequencyDomainLikelihood(bilby.core.likelihood.Likelihood):
    """
    Frequency-domain Gaussian likelihood summed over the channels A, E and T.

    The channels are treated as independent: the log-likelihood is the sum of
    per-channel noise-weighted residual powers restricted to the band
    [fmin, fmax].
    """

    def __init__(self, frequency_array, data_fd, psd_fd, waveform_model, fmin=1e-4, fmax=None):
        """
        Store the data, PSDs and model and precompute the frequency mask.

        `data_fd` / `psd_fd` are dicts keyed by channel name; `waveform_model`
        is called as ``waveform_model(parameters, frequency_array)`` and must
        return a dict with "A", "E" and "T" entries. When `fmax` is None the
        last entry of `frequency_array` is used as the upper band edge.
        """
        super().__init__(parameters={})
        freqs = np.asarray(frequency_array, float)
        self.frequency_array = freqs
        self.data_fd = {chan: np.asarray(series, np.complex128) for chan, series in data_fd.items()}
        self.psd_fd = {chan: np.asarray(series, float) for chan, series in psd_fd.items()}
        self.waveform_model = waveform_model
        # Bin width taken from the first two frequency samples.
        self.df = float(freqs[1] - freqs[0])
        self.fmin = float(fmin)
        if fmax is None:
            self.fmax = float(freqs[-1])
        else:
            self.fmax = float(fmax)
        self._mask = (freqs >= self.fmin) & (freqs <= self.fmax)

    def log_likelihood(self):
        """Return -2 * df * sum(|d - h|^2 / Sn) accumulated over A, E and T."""
        model = self.waveform_model(self.parameters, self.frequency_array)
        band = self._mask
        total = 0.0
        for chan in ("A", "E", "T"):
            data = self.data_fd[chan][band]
            template = np.asarray(model[chan], np.complex128)[band]
            noise_psd = self.psd_fd[chan][band]
            total += -2.0 * self.df * np.sum((np.abs(data - template)**2) / noise_psd)
        return float(total)


def make_fastlisaresponse_wrapper(duration, sampling_frequency, t0_abs, tdi_chan="AET", waveform_arguments=None):
    """
    Construct a ResponseWrapper that maps (h+ + i hx) -> TDI channels in time domain.

    Parameters
    ----------
    duration : float
        Segment length; together with `sampling_frequency` it fixes the
        number of time samples N = round(duration * fs).
    sampling_frequency : float
        Sampling rate; dt = 1 / sampling_frequency.
    t0_abs : float
        Passed to ResponseWrapper as `t0`.
    tdi_chan : str
        Passed to ResponseWrapper as `tdi_chan`.
    waveform_arguments : dict or None
        Extra keyword arguments merged (via `setdefault`) into the source
        parameters before the FD source model is called. None means no extras.

    Returns
    -------
    (wrapper, N, dt)
        The ResponseWrapper instance, the number of time samples and the
        sampling interval.
    """
    from fastlisaresponse import ResponseWrapper
    from lisatools.detector import EqualArmlengthOrbits
    from fastlisaresponse.utils.parallelbase import ParallelModuleBase
    fs = float(sampling_frequency)   # the sampling frequency
    dt = 1.0/fs                      # the sampling time
    N = int(round(duration * fs))    # the number of time samples
    T_yrs = (N * dt) / YRSID_SI      # the observation time in years
    frequency_array = np.fft.rfftfreq(N, dt)    # the frequency array

    class _PolarizationTD(ParallelModuleBase):
        """Time-domain h+ + i*hx generator built from an FD source model."""

        def __init__(self, waveform_arguments=None):
            super().__init__(force_backend=None)
            # Copy so later mutation of the caller's dict has no effect here;
            # None is treated as "no extra arguments".
            self.waveform_arguments = dict(waveform_arguments or {})
            self.frequency_array = frequency_array

        @classmethod
        def supported_backends(cls):
            return ["fastlisaresponse_cpu"]

        def __call__(self, *args, **kwargs):
            # Should only be sent a single dictionary of source parameters
            # (the wrapper is configured with remove_sky_coords=True).
            if len(args) != 1:
                raise TypeError(
                    f"_PolarizationTD expected exactly 1 positional arg "
                    f"(the parameter dict), got {len(args)}"
                )

            params = args[-1]

            # Keep only args relevant to the waveform call (this removes "_fd_source" etc.)
            source_params = filter_unused_kwargs(params)

            # Add waveform configuration (fmin/fmax/fref/approximant/etc.) if needed
            if self.waveform_arguments:
                for k, v in self.waveform_arguments.items():
                    source_params.setdefault(k, v)

            # Call the selected FD source (shifted or not) WITHOUT passing "_fd_source" as a kwarg
            pols = params["_fd_source"](self.frequency_array, **source_params)

            # convert both polarizations to the time domain
            hp = infft(np.asarray(pols["plus"], np.complex128), sampling_frequency=fs)
            hc = infft(np.asarray(pols["cross"], np.complex128), sampling_frequency=fs)

            return hp + 1j*hc   # return complex time domain signal

    # define the call to the wrapper that will return the AET variables
    wrapper = ResponseWrapper(
        # Generates a time domain signal from a frequency domain waveform that we iFFT.
        # Passing waveform_arguments through unchanged keeps the None default valid.
        _PolarizationTD(waveform_arguments=waveform_arguments),
        T_yrs,   # observation time in years
        dt,      # sampling time in seconds
        index_lambda=0,  # the index of the lambda sky parameter in the args supplied to the wrapper
        index_beta=1,    # the index of the beta sky parameter in the args supplied to the wrapper
        t0=float(t0_abs),  # time at which signal starts (chops off data at start of waveform where information is not correct)
        flip_hx=False,
        force_backend=None,
        remove_sky_coords=True,
        is_ecliptic_latitude=True,
        remove_garbage="zero",
        orbits=EqualArmlengthOrbits(),
        order=25,
        tdi="2nd generation",
        tdi_chan=tdi_chan,
    )
    return wrapper, N, dt


def run_lisa_analysis(
    args,
    start_time,   # GPS time of observation start 
    sampling_frequency,
    duration,
    minimum_frequency,
    maximum_frequency,
    samp_priors,
    inf_priors,
    SNR_THRESHOLD,
    fd_source,
    waveform_generator,
    waveform_arguments,
    detstring=None,
):
    """
    Set up the LISA (A/E/T) injection, data and likelihood.

    Steps:
    - build the fastlisaresponse wrapper and the rfft frequency grid
    - evaluate the analytic A/E/T noise PSDs (DC bin set to infinity)
    - draw injections from `samp_priors` until the network optimal SNR
      reaches `SNR_THRESHOLD`
    - form data = signal (+ coloured noise unless `args.zero_noise`)
    - wrap everything in a `LisaAETFrequencyDomainLikelihood`

    `inf_priors`, `waveform_generator` and `detstring` are accepted but not
    used in this branch.

    Returns
    -------
    (likelihood, inj, ["A", "E", "T"], data_fd, psd_fd, wrapper)
    """
    channels = ("A", "E", "T")
    noise_rng = np.random.default_rng(int(args.seed))

    # The wrapper takes a generator of the complex TIME domain waveform
    # h+ + i hx and returns the TIME domain A, E and T responses.
    response, n_samples, delta_t = make_fastlisaresponse_wrapper(
        duration,
        sampling_frequency,
        t0_abs=float(duration/8.0),  # passed to the wrapper as its t0
        tdi_chan="AET",
        waveform_arguments=waveform_arguments
    )
    freqs = np.fft.rfftfreq(n_samples, delta_t)
    delta_f = freqs[1] - freqs[0]

    # Analytic LISA noise model; wd=0 is passed to AnalyticNoise.
    analytic_noise = AnalyticNoise(frq=freqs, wd=0)
    psd_fd = {chan: analytic_noise.psd(freq=freqs, option=chan) for chan in channels}
    # Infinite PSD at DC so the zero-frequency bin never contributes.
    for sn in psd_fd.values():
        sn[0] = np.inf

    def lisa_model_fd(params, frequency_array):
        """Return the FD A/E/T response for `params` on the rfft grid."""
        p = dict(params)

        # Convert equatorial ra/dec -> ecliptic lam/beta for fastlisaresponse
        lam, beta = equatorial_to_ecliptic_lambda_beta(float(p.get("ra", 0.0)), float(p.get("dec", 0.0)))

        # Time shift applied as a phase ramp on the FD polarizations:
        # the offset is geocent_time measured from start_time.
        dt_shift = float(p.get("geocent_time", 0.0)) - float(start_time)

        def _fd_source_shifted(farr, **pp):
            pol = fd_source(farr, **pp)
            ramp = np.exp(-2j*np.pi*np.asarray(farr, float)*dt_shift)
            return {"plus": np.asarray(pol["plus"], np.complex128)*ramp,
                    "cross": np.asarray(pol["cross"], np.complex128)*ramp}

        p["_fd_source"] = _fd_source_shifted

        # TD A/E/T from the response wrapper (ecliptic coords go first)
        aet_td = response(lam, beta, p)

        # Convert each channel to FD on our rfft grid
        model = {}
        for idx, chan in enumerate(channels):
            model[chan], _ = _nfft_onesided(aet_td[idx], sampling_frequency)
        return model

    # Band used for the SNR estimate (independent of the drawn injection).
    snr_band = (freqs >= minimum_frequency) & (freqs > 0)

    # Draw injections until one is loud enough
    while True:
        inj = samp_priors.sample(1)
        print(f'inside generation loop inj {inj}')

        # derived parameterisations
        inj["z"] = compute_redshift_from_H0_dL(float(inj["H0"]), float(inj["luminosity_distance"]))
        inj["m2_src"] = float(inj["q"]) * float(inj["m1_src"])
        # mass_1 and mass_2 refer to detector frame masses
        inj["mass_1"] = float(inj["m1_src"]) * (1.0 + float(inj["z"]))
        inj["mass_2"] = float(inj["m2_src"]) * (1.0 + float(inj["z"]))
        inj["m1_det"] = inj["mass_1"]
        inj["m2_det"] = inj["mass_2"]
        inj["q"] = float(inj["mass_2"])/float(inj["mass_1"])
        inj["chirp_mass"] = (float(inj["mass_1"])*float(inj["mass_2"]))**(3.0/5.0) / (float(inj["mass_1"])+float(inj["mass_2"]))**(1.0/5.0)

        # --- sanity: ensure the waveform merges within the band ---
        Mtot_det = float(inj["m1_det"] + inj["m2_det"])
        M_sec = (G_SI * MSUN_SI * Mtot_det) / (C_SI**3)
        f_isco = 1.0 / (6.0**1.5 * np.pi * M_sec)
        in_band = np.isfinite(f_isco) and (minimum_frequency < f_isco < maximum_frequency)
        if not in_band:
            print(f'ERROR: f_isco not in band. f_isco={f_isco} min_freq={minimum_frequency} max_freq={maximum_frequency}')
            exit(0)

        # signal in the frequency domain for A, E and T
        hfd = lisa_model_fd(inj, freqs)

        # network optimal SNR over the three channels
        snr2 = 0.0
        for chan in channels:
            snr2 += 4.0 * np.sum((np.abs(hfd[chan][snr_band])**2) / psd_fd[chan][snr_band]) * delta_f
        net_snr = float(np.sqrt(snr2))
        print(f"[LISA] drew injection net SNR ~ {net_snr:.6f}")
        if net_snr >= SNR_THRESHOLD:
            inj["net_opt_snr"] = net_snr
            break

    # build data = signal + noise (FD)
    data_fd = {}
    if args.zero_noise:
        print('***ZERO NOISE***')
        for chan in channels:
            data_fd[chan] = np.asarray(hfd[chan], np.complex128)
    else:
        # Noise is drawn channel by channel in the order A, E, T.
        for chan in channels:
            noise_fd = generate_colored_noise_fd(psd_fd[chan], delta_f, noise_rng)
            data_fd[chan] = np.asarray(hfd[chan], np.complex128) + noise_fd

    likelihood = LisaAETFrequencyDomainLikelihood(
        frequency_array=freqs,
        data_fd=data_fd,
        psd_fd=psd_fd,
        waveform_model=lisa_model_fd,
        fmin=minimum_frequency,
        fmax=float(freqs[-1]),
    )

    return likelihood, inj, list(channels), data_fd, psd_fd, response

def run_ground_analysis(
    args,
    start_time,   # GPS time of observation start
    sampling_frequency,
    duration,
    minimum_frequency,
    maximum_frequency,
    samp_priors,
    inf_priors,
    SNR_THRESHOLD,
    fd_source,
    waveform_generator,
    waveform_arguments,
    detstring=None,
):
    """
    Set up the ground-based injection, data and likelihood.

    Injections are drawn from `samp_priors` until the network optimal SNR
    reaches `SNR_THRESHOLD`. On every draw the interferometer strain data is
    regenerated (Gaussian noise from the PSDs, or exactly zero noise when
    `args.zero_noise` is set) and the signal is injected.

    `inf_priors`, `fd_source` and `waveform_arguments` are accepted but not
    used in this branch; `maximum_frequency` is only used in the f_isco
    sanity check.

    Returns
    -------
    (likelihood, inj, ifos, None, None, None)
        The three trailing `None`s fill the slots that the LISA branch uses
        for its FD data, PSDs and response wrapper.
    """
    ifos = bilby.gw.detector.InterferometerList(detstring)

    # Draw until SNR threshold
    while True:

        # generate random signal params from sample prior
        inj = samp_priors.sample(1)

        # compute some alternative parameterisations
        inj["z"] = compute_redshift_from_H0_dL(float(inj["H0"]), float(inj["luminosity_distance"]))
        inj["m2_src"] = float(inj["q"]) * float(inj["m1_src"])
        inj["mass_1"] = float(inj["m1_src"]) * (1.0 + float(inj["z"]))   # mass_1 and mass_2 refer to detector frame masses
        inj["mass_2"] = float(inj["m2_src"]) * (1.0 + float(inj["z"]))
        inj["m1_det"] = inj["mass_1"]
        inj["m2_det"] = inj["mass_2"]
        inj["q"] = float(inj["mass_2"])/float(inj["mass_1"])
        inj["chirp_mass"] = (float(inj["mass_1"])*float(inj["mass_2"]))**(3.0/5.0) / (float(inj["mass_1"])+float(inj["mass_2"]))**(1.0/5.0)

        # --- sanity: ensure the waveform merges within the band ---
        Mtot_det = float(inj["m1_det"] + inj["m2_det"])
        M_sec = (G_SI * MSUN_SI * Mtot_det) / (C_SI**3)
        f_isco = 1.0 / (6.0**1.5 * np.pi * M_sec)
        if not np.isfinite(f_isco) or (f_isco <= minimum_frequency) or (f_isco >= maximum_frequency):
            print(f'ERROR: f_isco not in band. f_isco={f_isco} min_freq={minimum_frequency} max_freq={maximum_frequency}')
            # Same exit status as before, without relying on the site-provided exit().
            raise SystemExit(0)

        # Set the strain data. Zero noise has to be requested through the
        # dedicated setter: zeroing the time-domain array afterwards does not
        # clear the stored frequency-domain noise.
        if args.zero_noise:
            ifos.set_strain_data_from_zero_noise(
                        sampling_frequency=sampling_frequency,
                        duration=duration,
                        start_time=start_time,
            )
        else:
            ifos.set_strain_data_from_power_spectral_densities(
                        sampling_frequency=sampling_frequency,
                        duration=duration,
                        start_time=start_time,
            )

        for ifo in ifos:
            ifo.minimum_frequency = minimum_frequency

        ifos.inject_signal(waveform_generator=waveform_generator, parameters=inj, raise_error=False)

        # compute network optimal SNR (quadrature sum over detectors)
        pols = waveform_generator.frequency_domain_strain(inj)
        rhos = []
        for ifo in ifos:
            sig_ifo = ifo.get_detector_response(pols, inj)
            rho_ifo = np.sqrt(np.real(ifo.optimal_snr_squared(signal=sig_ifo)))
            rhos.append(rho_ifo)
        rho_net = float(np.hypot.reduce(rhos))

        if rho_net >= SNR_THRESHOLD:
            print(f"Accepted injection with rho_net = {rho_net:.2f}")
            inj["net_opt_snr"] = rho_net
            break

    # define ground based likelihood (no marginalisation)
    likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
                interferometers=ifos,
                waveform_generator=waveform_generator,
                phase_marginalization=False,
                distance_marginalization=False,
                time_marginalization=False,
    )

    return likelihood, inj, ifos, None, None, None

################################################################################
# Main
################################################################################

def main():
    """Run a single injection-and-recovery study from the command line.

    The steps, in order, are:

    1. Parse the command-line options, seed numpy and bilby, create the
       output directory and copy this script into it.
    2. Choose detector-specific settings (sampling, frequency range, mass and
       distance ranges) from ``--detector``.
    3. Build the frequency-domain source model (with or without memory,
       according to ``--mem``), the sampling priors used to draw the
       injection and the inference priors used for the recovery.
    4. Call the ground-based or LISA analysis routine, which returns the
       likelihood and the accepted injection, and plot the time-domain data.
    5. Either load an existing result (``--only-reweight``) or estimate
       Fisher uncertainties, narrow the inference priors around the
       injection and run the sampler.
    6. Post-process the posterior: save/plot the raw samples, apply the SNR
       selection, reweight from the inference priors to the sampling priors,
       and save the final result.

    Raises
    ------
    ValueError
        If ``--detector`` or ``--mem`` is not one of the supported values.
    FileNotFoundError
        If ``--only-reweight`` is set but no ``*_result.json`` file exists in
        the output directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=88170235)
    parser.add_argument("--run-id", type=str, required=True)
    parser.add_argument(
        "--detector",
        type=int,
        default=0,
        choices=[0, 1, 2, 3, 4],
        help=(
            "Detector network: 0=HLV, 1=ET, 2=CE, 3=ET+CE, 4=LISA (A/E/T via fastlisaresponse). "
            "(Replaces the old --gen/--instrument flags.)"
        ),
    )
    parser.add_argument("--mem", type=int, default=0, help="0=no memory, 1=MWM, 2=NRSur7dq2")
    parser.add_argument("--only-reweight", type=int, default=0)
    parser.add_argument("--zero-noise", type=int, default=0)
    args = parser.parse_args()

    np.random.seed(args.seed)
    bilby.core.utils.random.seed(args.seed)

    # Output structure
    BASE_OUTDIR = "/data/www.astro/chrism/newcos/outdir"
    run_root = f"{args.run_id}_{args.mem}_{args.detector}"
    label = f"{args.run_id}_{args.mem}_{args.detector}_{args.seed}"
    outdir = os.path.join(BASE_OUTDIR, run_root, label)
    os.makedirs(outdir, exist_ok=True)
    patterns = [os.path.join(outdir, "*_result.json")]

    # Copy the running script next to the results for provenance. __file__ is
    # used (rather than a hard-coded relative path) so this works regardless
    # of the working directory or the script's file name.
    this_file = Path(__file__).resolve()
    shutil.copy2(this_file, Path(outdir) / this_file.name)

    # initialise the redshift interpolation
    init_redshift_interpolator(zmax=100.0, ngrid=4000)
    H0_min = 10
    H0_max = 200

    # Detectors + priors range + sampling settings
    is_lisa = args.detector == 4
    ground_networks = {
        0: ["H1", "L1", "V1"],
        1: ["ET"],
        2: ["CE"],
        3: ["ET", "CE"],
    }
    if args.detector in ground_networks:
        # All ground-based networks share the same settings except for the
        # detector list and the maximum luminosity distance of the prior.
        detstring = ground_networks[args.detector]
        sampling_frequency = 2048.0
        duration = 8.0
        minimum_frequency = 5.0
        maximum_frequency = sampling_frequency/2.0
        reference_frequency = 5.0
        df_taper = 5.0
        approximant = "IMRPhenomXPHM"
        dt_geocent = 0.01
        min_m_src = 5.0
        max_m_src = 80.0
        dL_max_prior = 500.0 if args.detector == 0 else 100.0
        SNR_THRESHOLD = 8.0
        run_analysis = run_ground_analysis
    elif args.detector == 4:
        detstring = ["A", "E", "T"]
        sampling_frequency = 0.2
        duration = (2**13) / sampling_frequency
        minimum_frequency = 2.0 * (1.0 / duration)  # placeholder - real fmin based on masses
        maximum_frequency = 0.5 * sampling_frequency
        reference_frequency = minimum_frequency
        df_taper = 10.0/duration   # 10 frequency bins
        approximant = "IMRPhenomD"
        dt_geocent = 100.0
        min_m_src = 1e5
        max_m_src = 1e7
        dL_max_prior = 6.7e3
        SNR_THRESHOLD = 8.0
        run_analysis = run_lisa_analysis
    else:
        raise ValueError("--detector must be in {0,1,2,3,4}.")

    # using a fixed reference trigger time
    trigger_time = 0.0 #1126259642.413
    print(trigger_time, duration)
    start_time = trigger_time - (3.0/4.0)*duration  # the observation always starts 3/4*T before merger
    print(start_time)

    # minimum/maximum chirp mass
    min_chirp_mass_src = min_m_src/2**(1.0/5.0)
    max_chirp_mass_src = max_m_src/2**(1.0/5.0)
    max_z = compute_redshift_from_H0_dL(H0_max, dL_max_prior)  # the max redshift in inference

    # define the list of fixed parameters - varied in prior but not in inference
    # The LISA analysis currently only supports aligned spins (I think - IMRPhenomD)
    # we should be able to include a1 and a2 in the inference at some point
    # sky position is unlikely to correlate with the memory
    fixed_pars = ["a_1", "a_2", "tilt_1", "tilt_2", "phi_12", "phi_jl", "ra", "dec"]

    # the waveform arguments
    waveform_arguments = dict(
        waveform_approximant=approximant,
        reference_frequency=reference_frequency,
        minimum_frequency=minimum_frequency,
        maximum_frequency=maximum_frequency,
    )

    # Memory map
    MEM_MAP = {
        0: None,
        1: "MWM",
        2: "NRSur7dq2",
    }
    if args.mem not in MEM_MAP:
        raise ValueError(f"--mem {args.mem} not recognised. Extend MEM_MAP.")
    mem_model = MEM_MAP[args.mem]

    # Build FD source model. The no-memory model is always built because the
    # time-domain plot compares the full signal against it.
    source_kwargs = dict(
        sampling_frequency=sampling_frequency,
        duration=duration,
        minimum_frequency_default=minimum_frequency,
        df_taper=df_taper,
        waveform_arguments=waveform_arguments,
    )
    fd_source_nomem = make_bbh_no_memory_fd(**source_kwargs)
    if args.mem == 0:
        fd_source = fd_source_nomem
    else:
        fd_source = make_bbh_with_gwmemory_fd(
            gwmemory_model=mem_model,
            l_max=4,
            **source_kwargs,
        )

    # define sampling priors - in source frame
    # NOTE: the insertion order below is kept fixed on purpose; changing it
    # may change the order of random draws and hence the seeded injection.
    samp_priors = bilby.core.prior.PriorDict()
    if args.detector==4:
        samp_priors["m1_src"] = bilby.core.prior.LogUniform(min_m_src, max_m_src, name="m1_src")
        samp_priors["q"] = bilby.core.prior.Uniform(0.1, 1.0, name="q")
        samp_priors["phi_12"] = bilby.core.prior.DeltaFunction(0.0, name="phi_12")
        samp_priors["phi_jl"] = bilby.core.prior.DeltaFunction(0.0, name="phi_jl")
        samp_priors["tilt_1"] = bilby.core.prior.DeltaFunction(0.0, name="tilt_1")
        samp_priors["tilt_2"] = bilby.core.prior.DeltaFunction(0.0, name="tilt_2")
    else:
        samp_priors["m1_src"] = bilby.core.prior.PowerLaw(alpha=-2.5, minimum=min_m_src, maximum=max_m_src, name="m1_src")
        samp_priors["q"] = bilby.core.prior.PowerLaw(alpha=1.0, minimum=0.1, maximum=1.0, name="q")
        samp_priors["phi_12"] = bilby.core.prior.Uniform(0.0, 2*np.pi, name="phi_12",boundary="periodic")
        samp_priors["phi_jl"] = bilby.core.prior.Uniform(0.0, 2*np.pi, name="phi_jl",boundary="periodic")
        samp_priors["tilt_1"] = bilby.core.prior.Cosine(name="tilt_1")
        samp_priors["tilt_2"] = bilby.core.prior.Cosine(name="tilt_2")
    samp_priors["m2_src"] = bilby.core.prior.Constraint(minimum=min_m_src,maximum=max_m_src)
    samp_priors["luminosity_distance"] = bilby.gw.prior.UniformSourceFrame(0.0, dL_max_prior, name="luminosity_distance")
    samp_priors["H0"] = bilby.core.prior.DeltaFunction(peak=float(cosmo.H0.value), name="H0")   # always use the same H0 value
    samp_priors["theta_jn"] = bilby.core.prior.Sine(name="theta_jn")
    samp_priors["geocent_time"] = bilby.core.prior.DeltaFunction(peak=float(trigger_time), name="geocent_time")  # always use the same merger time
    samp_priors["phase"] = bilby.core.prior.Uniform(0.0, 2*np.pi, name="phase", boundary="periodic")
    samp_priors["a_1"] = bilby.core.prior.Uniform(0.0, 0.8, name="a_1")
    samp_priors["a_2"] = bilby.core.prior.Uniform(0.0, 0.8, name="a_2")
    samp_priors["psi"] = bilby.core.prior.Uniform(0.0, np.pi, name="psi")
    samp_priors["ra"] = bilby.core.prior.Uniform(0.0, 2*np.pi, name="ra")
    samp_priors["dec"] = bilby.core.prior.Cosine(name="dec")

    def _samp_conv(p):
        """Add ``m2_src = q * m1_src`` so the ``m2_src`` constraint can be evaluated."""
        p["m2_src"] = p["q"] * p["m1_src"]
        return p
    samp_priors.conversion_function = _samp_conv

    # define inference priors - detector frame
    inf_priors = bilby.core.prior.PriorDict()
    inf_priors["chirp_mass"] = bilby.core.prior.Uniform(minimum=min_m_src*2**(-0.2),maximum=(1.0+max_z)*max_m_src*2**(-0.2),name="chirp_mass")
    inf_priors["q"] = bilby.core.prior.Uniform(minimum=0.1,maximum=1.0,name="q")
    inf_priors["luminosity_distance"] = bilby.gw.prior.UniformSourceFrame(0.0, dL_max_prior, name="luminosity_distance")
    inf_priors["H0"] = bilby.core.prior.Uniform(H0_min, H0_max, name="H0")
    inf_priors["theta_jn"] = bilby.core.prior.Sine(name="theta_jn")
    inf_priors["geocent_time"] = bilby.core.prior.Uniform(trigger_time - dt_geocent, trigger_time + dt_geocent, name="geocent_time")
    inf_priors["phase"] = bilby.core.prior.Uniform(0.0, 2*np.pi, name="phase", boundary="periodic")
    inf_priors["psi"] = bilby.core.prior.Uniform(0.0, np.pi, name="psi")
    inf_priors["a_1"] = bilby.core.prior.Uniform(0.0, 0.8, name="a_1")
    inf_priors["a_2"] = bilby.core.prior.Uniform(0.0, 0.8, name="a_2")
    inf_priors["ra"] = bilby.core.prior.Uniform(0.0, 2*np.pi, name="ra")
    inf_priors["dec"] = bilby.core.prior.Cosine(name="dec")
    inf_priors["phi_12"] = bilby.core.prior.Uniform(0.0, 2*np.pi, name="phi_12",boundary="periodic")
    inf_priors["phi_jl"] = bilby.core.prior.Uniform(0.0, 2*np.pi, name="phi_jl",boundary="periodic")
    inf_priors["tilt_1"] = bilby.core.prior.Cosine(name="tilt_1")
    inf_priors["tilt_2"] = bilby.core.prior.Cosine(name="tilt_2")

    def _inf_conv(p):
        """Derive z, detector-frame and source-frame masses from the sampled keys.

        Returns a ``(parameters, added)`` pair, where ``added`` holds only the
        derived entries. The input dict is copied, not modified.
        """
        # IMPORTANT: don't mutate the original dict in-place if bilby reuses it
        p = dict(p)
        added = {}

        h0 = np.atleast_1d(p["H0"])
        dl = np.atleast_1d(p["luminosity_distance"])
        mc = np.atleast_1d(p["chirp_mass"]).astype(float)
        q  = np.atleast_1d(p["q"]).astype(float)

        # Bilby/nessai can call conversion_function on an empty subsample.
        # Ensure derived keys exist and return safely.
        if h0.size == 0 or dl.size == 0 or mc.size == 0 or q.size == 0:
            n = max(h0.size, dl.size, mc.size, q.size)  # usually 0 here
            for key in ("z", "mass_1", "mass_2", "m1_src", "m2_src"):
                p[key] = np.empty(n, dtype=float)
                added[key] = p[key]
            return p, added

        # Compute redshift (assume lengths match in normal operation)
        if h0.size > 1:
            z = np.array([compute_redshift_from_H0_dL(h0[i], dl[i]) for i in range(h0.size)], dtype=float)
        else:
            z = float(compute_redshift_from_H0_dL(h0[0], dl[0]))
        p["z"] = z
        added["z"] = z

        # chirp_mass + q -> component masses (assume q = m2/m1 <= 1)
        m1 = mc * q**(-3.0/5.0) * (1.0 + q)**(1.0/5.0)
        m2 = q * m1
        p["mass_1"] = m1
        p["mass_2"] = m2
        p["m1_src"] = m1/(1.0+z)
        p["m2_src"] = m2/(1.0+z)
        added["mass_1"] = m1
        added["mass_2"] = m2
        added["m1_src"] = p["m1_src"]
        added["m2_src"] = p["m2_src"]
        return p, added

    inf_priors.conversion_function = _inf_conv

    # if not LISA then we define the waveform generator ready to use with Bilby
    waveform_generator = None
    if args.detector < 4:
        waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
            duration=duration,
            sampling_frequency=sampling_frequency,
            frequency_domain_source_model=fd_source,
            parameter_conversion=inf_priors.conversion_function,
            waveform_arguments=waveform_arguments,
        )

    likelihood, inj, ifos, data_fd, psd_fd, lisa_wrapper = run_analysis(
            args=args,
            start_time=start_time,
            sampling_frequency=sampling_frequency,
            duration=duration,
            minimum_frequency=minimum_frequency,
            maximum_frequency=maximum_frequency,
            samp_priors=samp_priors,
            inf_priors=inf_priors,
            SNR_THRESHOLD=SNR_THRESHOLD,
            fd_source=fd_source,
            waveform_generator=waveform_generator,
            waveform_arguments=waveform_arguments,
            detstring=detstring,
    )

    # --- PLOT TIME DOMAIN DATA ---
    plot_time_domain_data(
        outdir=outdir,
        label=label,
        detectors=ifos,
        parameters=inj,
        sampling_frequency=sampling_frequency,
        duration=duration,
        start_time=start_time,
        minimum_frequency=minimum_frequency,
        fd_source_full=fd_source,
        fd_source_nomem=fd_source_nomem,
        data_fd_dict=data_fd,
        psd_fd_dict=psd_fd,
        is_lisa=is_lisa,
        lisa_wrapper=lisa_wrapper
    )

    # Locate existing result
    result_files = []
    for pat in patterns:
        result_files.extend(glob.glob(pat, recursive=True))
    result_files = sorted(result_files)
    print(f"Found {len(result_files)} result files in {outdir}")

    # if we are using existing results and doing reweighting
    if args.only_reweight:
        if len(result_files) == 0:
            raise FileNotFoundError(f"--only-reweight set but no existing result matched {patterns}")
        print(f"[only-reweight] Loading: {result_files[0]}")
        result = bilby.result.read_in_result(result_files[0])
        try:
            if getattr(result, "injection_parameters", None) is not None:
                inj = dict(result.injection_parameters)
        except Exception as e:
            # Best effort: carry on without injection values rather than abort.
            print(f"[only-reweight] Could not read stored injection parameters ({e}); continuing without them")
            inj = {}

    else:

        # Choose parameters to include in Fisher
        param_keys = [
            "chirp_mass",
            "q",
            "luminosity_distance",
            "geocent_time",
            "phase",
        ]
        if args.mem>0:
            param_keys.append("H0")

        # Marginalise nuisance parameters (typical)
        nuisance_keys = ["phase"] #, "geocent_time"]

        # Optional: per-parameter relative steps in SCALED coordinates
        # (start here; you can tune)
        relative_steps = {
            "chirp_mass": 1e-5,            # step in log(Mc)
            "luminosity_distance": 1e-6,   # step in log(dL)
            "q": 1e-3,
            "geocent_time": 1e-7,          # step in ms coordinate
            "phase": 1e-6,
            "psi": 1e-6,
            "theta_jn": 1e-6,
        }
        if args.mem>0:
            relative_steps["H0"] = 1e-3  # step in H0/100
        if args.detector==4:
            relative_steps["geocent_time"] = 0.1          # step in seconds coordinate

        # Use requested scaling (includes log(dL))
        scaling_spec = default_scaling_spec()
        if args.detector==4:
            scaling_spec["geocent_time"] = ("div", 1.0)    # keep seconds as the geocent scale for LISA

        out = fisher_uncertainties_finite_difference_scaled_marginalised(
            likelihood=likelihood,
            params=inj,
            param_keys=param_keys,
            nuisance_keys=nuisance_keys,
            relative_steps=relative_steps,
            scaling_spec=scaling_spec,
            fmin=None,
            fmax=None,
            pinv_rcond=1e-12,
            verbose=True,
        )

        save_fisher_results(out, outdir=outdir, label=label)

        print("\nKept parameters:", out["kept_keys"])
        print("Marginalised Fisher (phi):\n", out["fisher_phi_marg"])
        print("Marginalised Cov (theta):\n", out["cov_theta_marg"])
        print("Marginalised 1-sigma (theta):")
        for k, v in out["sigma_theta_marg"].items():
            print(f"  {k}: {v:.6g}")

        # Constrain priors on mass, distance and time for inference efficiency:
        # each prior becomes injection +/- nsig*sigma, clipped to hard limits.
        nsig = 20
        sigmas = out["sigma_theta_marg"]
        prior_windows = [
            ("chirp_mass", min_chirp_mass_src, max_chirp_mass_src*(1.0+max_z)),
            ("q", 0.1, 1.0),
            ("luminosity_distance", 0.0, dL_max_prior),
            ("geocent_time", -np.inf, np.inf),
        ]
        for key, lower, upper in prior_windows:
            half_width = nsig * float(sigmas[key])
            if not np.isfinite(half_width) or half_width <= 0.0:
                # A NaN/inf/zero Fisher sigma would give NaN or zero-width
                # prior bounds; keep the existing (wide) prior instead.
                print(f"Fisher sigma for {key} is not usable ({sigmas[key]}); keeping original prior range")
                continue
            centre = float(inj[key])
            inf_priors[key].minimum = max(lower, centre - half_width)
            inf_priors[key].maximum = min(upper, centre + half_width)
        print(inj)
        print(f'chirp mass range [{inf_priors["chirp_mass"].minimum}-{inf_priors["chirp_mass"].maximum}]')
        print(f'q range [{inf_priors["q"].minimum}-{inf_priors["q"].maximum}]')
        print(f'luminosity_distance range [{inf_priors["luminosity_distance"].minimum}-{inf_priors["luminosity_distance"].maximum}]')
        print(f'geocent_time range [{inf_priors["geocent_time"].minimum}-{inf_priors["geocent_time"].maximum}]')

        # IMPORTANT - fix some parameters by using delta function priors at the injection values
        for k in fixed_pars:
            inf_priors[k] = bilby.core.prior.DeltaFunction(peak=float(inj[k]), name=k)

        # Sampler choice: flip this flag to use nessai instead of emcee.
        use_nessai = False
        if use_nessai:

            # Run the sampler
            result = bilby.run_sampler(
                likelihood=likelihood,
                priors=inf_priors,
                sampler="nessai",
                nlive=2000,
                dlogz=0.1,
                outdir=outdir,
                label=label,
                resume=False,
                injection_parameters=inj,
            )
        else:

            # Choose emcee hyperparams
            nwalkers = 4 * len(inf_priors.non_fixed_keys)   # rule of thumb: >= 2*ndim; 4*ndim is often safer
            nsteps   = 20000
            nburn    = 5000

            pos0, _ = make_emcee_pos0_from_injection(
                priors=inf_priors,
                injection_parameters=inj,
                nwalkers=nwalkers,
                u_sigma=1e-4,   # tighten/loosen initial cloud
                seed=0
            )

            result = bilby.run_sampler(
                likelihood=likelihood,
                priors=inf_priors,
                sampler="emcee",
                nwalkers=nwalkers,
                nsteps=nsteps,
                nburn=nburn,
                pos0=pos0,              # start walkers near the injection values
                outdir=outdir,
                label=label,
                resume=False,
                injection_parameters=inj,
            )

    # Calculate and add z AND source masses BEFORE reweighting
    if "z" not in result.posterior.columns or "m1_src" not in result.posterior.columns:
        post = result.posterior
        q_samples = post["q"].to_numpy(dtype=float)      # assuming q = m2/m1 <= 1
        mc_samples = post["chirp_mass"].to_numpy(dtype=float)
        m1_det = mc_samples * q_samples**(-3.0/5.0) * (1.0 + q_samples)**(1.0/5.0)
        m2_det = q_samples * m1_det
        # The redshift helper is called one sample at a time, as elsewhere in this function.
        z_samples = np.array(
            [
                compute_redshift_from_H0_dL(float(h0), float(dl))
                for h0, dl in zip(post["H0"], post["luminosity_distance"])
            ],
            dtype=float,
        )

        result.posterior["z"] = z_samples
        result.posterior["mass_1"] = m1_det
        result.posterior["mass_2"] = m2_det
        result.posterior["m1_src"] = m1_det / (1.0 + z_samples)
        result.posterior["m2_src"] = m2_det / (1.0 + z_samples)

    # Attach injection for downstream plotting / reweighting
    result.injection_parameters = inj

    # ------------------------------------------------------------------
    # Post-processing pipeline (no marginalisations; keep this structured)
    #   1) Save raw posterior + corner
    #   2) Apply SNR selection (if possible) + save + corner
    #   3) Reweight from inf_priors -> samp_priors (with Jacobian if needed)
    #      + save + weighted corner
    # ------------------------------------------------------------------

    def _save_csv(posterior, path):
        """Write a posterior DataFrame to ``path`` as CSV (no index column)."""
        # bilby Result.posterior is a pandas.DataFrame; no explicit pandas dependency needed
        posterior.to_csv(path, index=False)
        print(f"Saved posterior: {path}")

    def _is_scalar_number(x):
        """Return True for python/numpy scalar numbers and 0-d arrays only."""
        if x is None:
            return False
        if isinstance(x, (float, int, np.floating, np.integer)):
            return True
        return isinstance(x, np.ndarray) and x.shape == ()

    def _scalarise_value(x):
        """Turn a numpy scalar / 0-d array into a python float; pass anything else through."""
        if isinstance(x, np.ndarray) and x.shape == ():
            return float(x)
        if isinstance(x, (np.floating, np.integer)):
            return float(x)
        return x

    def _select_plottable_parameters(posterior, preferred):
        """Return the entries of ``preferred`` that are scalar, float-convertible columns."""
        plottable = []
        for k in preferred:
            if k not in posterior.columns:
                continue
            try:
                v0 = posterior[k].iloc[0]
            except Exception:
                continue
            if not _is_scalar_number(_scalarise_value(v0)):
                continue
            # also ensure the full column can be converted to float
            try:
                posterior[k].to_numpy(dtype=float)
            except Exception:
                continue
            plottable.append(k)
        return plottable

    def _make_plotting_result(base_result, posterior, trigger_time, plot_params, truths, weights=None, filename=None):
        """Make a corner plot of ``posterior`` without touching ``base_result``.

        ``truths`` must be ordered like ``plot_params``; the ``geocent_time``
        entry is expected to already be relative to ``trigger_time``.
        Parameters that cannot be plotted are dropped together with their
        truth value so the two lists stay the same length.
        """
        # Work on copies so neither the result nor the posterior is mutated.
        r = copy.deepcopy(base_result)
        post = copy.deepcopy(posterior)

        # Add a plotting-only relative time if geocent_time is present
        if "geocent_time" in post.columns:
            post["geocent_time_rel"] = post["geocent_time"].to_numpy(dtype=float) - float(trigger_time)

        r.posterior = post

        # If plotting relative time, swap parameter name
        has_rel_time = "geocent_time_rel" in post.columns
        names = [
            "geocent_time_rel" if (p == "geocent_time" and has_rel_time) else p
            for p in plot_params
        ]
        truth_by_name = dict(zip(names, truths))

        # Filter to scalar/numeric columns to avoid "inhomogeneous shape" corner failures
        names = _select_plottable_parameters(post, names)

        if len(names) < 2:
            print(f"Corner skipped (need >=2 scalar params), stage file={filename}")
            return

        # Keep truths aligned with the parameters that survived the filter;
        # a list of a different length is rejected by plot_corner.
        truths_eff = [truth_by_name.get(name) for name in names]

        try:
            r.plot_corner(
                parameters=names,
                truths=truths_eff,
                priors=False,
                weights=weights,
                filename=filename,
            )
        except Exception as e:
            wmsg = "weighted" if weights is not None else "unweighted"
            print(f"Corner failed ({wmsg}): {e}")

    os.makedirs(outdir, exist_ok=True)

    # Choose a sensible, stable plotting parameter set
    preferred_plot_params = [
        "chirp_mass", "q", "mass_1", "mass_2",
        "m1_src", "m2_src", "z",
        "geocent_time", "phase",
        "luminosity_distance", "H0",
        "theta_jn",
        "psi",
    ]

    def truth_get(keys, default=np.nan):
        """Return the first of ``keys`` found in the injection as a float, else ``default``."""
        for k in keys:
            if k in inj:
                return float(inj[k])
        return float(default)

    # Truths for plotting; geocent_time is plotted relative to the trigger time
    plot_params = list(preferred_plot_params)
    truths = [
        truth_get([p]) - float(trigger_time) if p == "geocent_time" else truth_get([p])
        for p in plot_params
    ]

    def _save_stage(posterior, stage, weights=None):
        """Save one pipeline stage as ``<label>_<stage>_posterior.csv`` plus its corner plot."""
        _save_csv(posterior, os.path.join(outdir, f"{label}_{stage}_posterior.csv"))
        _make_plotting_result(
            base_result=result,
            posterior=posterior,
            trigger_time=trigger_time,
            plot_params=plot_params,
            truths=truths,
            weights=weights,
            filename=os.path.join(outdir, f"{label}_{stage}_corner.png"),
        )

    # ----------------------------
    # Stage 1: raw posterior
    # ----------------------------
    posterior_raw = copy.deepcopy(result.posterior).reset_index(drop=True)
    _save_stage(posterior_raw, "raw")

    # ----------------------------
    # Stage 2: SNR selection
    # ----------------------------
    posterior_snr = None
    try:
        posterior_snr = apply_snr_selection_bias(
            posterior_raw,
            likelihood=likelihood,
            snr_threshold=SNR_THRESHOLD,
            weight_cols=[],
            npool=1,
        )
    except Exception as e:
        # For LISA or custom likelihoods, bilby.compute_snrs may not work.
        # If we already have an SNR column, use it; otherwise skip selection.
        print(f"SNR selection step skipped/failed: {e}")

        # Try manual selection if an SNR-like column exists
        snr_cols = [c for c in posterior_raw.columns if ("snr" in c.lower() and "optimal" in c.lower())]
        if "network_optimal_snr" in posterior_raw.columns:
            snr_cols = ["network_optimal_snr"]
        if "optimal_snr" in posterior_raw.columns and len(snr_cols) == 0:
            snr_cols = ["optimal_snr"]

        if len(snr_cols) > 0:
            if len(snr_cols) == 1:
                net = np.abs(posterior_raw[snr_cols[0]].to_numpy(dtype=float))
            else:
                # Combine per-detector columns in quadrature
                s2 = np.zeros(len(posterior_raw), dtype=float)
                for c in snr_cols:
                    s2 += np.abs(posterior_raw[c].to_numpy(dtype=float)) ** 2
                net = np.sqrt(s2)

            mask = net > float(SNR_THRESHOLD)
            posterior_snr = posterior_raw.loc[mask].reset_index(drop=True)
            print(f"Manual SNR selection: {len(posterior_raw)} -> {len(posterior_snr)} (thr={SNR_THRESHOLD})")

    if posterior_snr is None:
        # No selection was possible: carry the raw posterior forward unsaved.
        posterior_snr = posterior_raw
    else:
        _save_stage(posterior_snr, "snr_selected")

    # ----------------------------
    # Stage 3: reweight inf_priors -> samp_priors
    # ----------------------------
    # Choose a Jacobian based on what the NEW prior is defined on
    jac = None
    new_keys = list(samp_priors.keys())
    if ("m1_src" in new_keys) and ("q" in new_keys):
        jac = jacobian_det_to_m1q_src
    elif ("m1_src" in new_keys) and ("m2_src" in new_keys):
        jac = jacobian_det_to_src

    posterior_rw, w_inf_to_samp = reweight_posterior_samples(
        posterior_snr,
        original_prior=inf_priors,
        new_prior=samp_priors,
        existing_weight_column=None,
        new_weight_column="w_inf_to_samp",
        jacobian=jac,
        require_all_new_prior_keys=True,
        normalize=True,
        inplace=False,
    )

    _save_stage(
        posterior_rw,
        "reweighted",
        weights=posterior_rw["w_inf_to_samp"].to_numpy(dtype=float),
    )

    # Save a final bilby result object using the reweighted posterior
    # NOTE(review): presumably this is written as "<label>_result.json", which
    # matches the "*_result.json" pattern searched by --only-reweight; a later
    # --only-reweight run may then load this already-reweighted posterior.
    # Confirm that this is intended.
    result_final = copy.deepcopy(result)
    result_final.posterior = posterior_rw
    result_final.save_to_file(outdir=outdir)
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()










