iftpwa - Tips
Analysis Tips
Hyperparameter scans
Hyperparameter scans are performed under the hood with the optuna Python package. PyAmpTools streamlines this interaction; see the configuration file tutorial. We can hijack this interface to scan over any hyperparameter or key defined in the config file. Since the entire model is defined in this config file, you have full control.
Waveset Scans: Currently only pa run_ift is aware of hyperparameter scanning, but one can imagine doing waveset scans with run_mle fits (to be added). Use optuna's BruteForceSampler to perform scans over a list of waveset strings.
Randomized Fits: If we wanted to do random fits with pa run_ift, one can use optuna's BruteForceSampler to scan over the GENERAL.seeds YAML key.
Systematic Uncertainties: We can additionally hijack this system to extract our systematic uncertainties. NIFTy researchers estimate systematic uncertainties by randomly shifting the prior mean (of all prior distributions) on the scale of the prior std (leaving the std unchanged), performing a fit, and then aggregating the shifts.
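The systematic-uncertainty recipe can be sketched as follows. This is one possible reading of the procedure, not the NIFTy or PyAmpTools implementation. It assumes each prior mean is redrawn from a normal distribution centred on the nominal mean with the prior std as its width, and that the shifts are aggregated as an RMS. The PRIORS dictionary and toy_fit function are hypothetical stand-ins for the real config and a real fit.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Hypothetical priors: name -> (mean, std). Real priors live in the config file.
PRIORS = {"scale": (1.0, 0.5), "slope": (-2.0, 0.3)}


def shift_prior_means(priors, rng):
    """Redraw every prior mean from N(mean, std); the std itself is unchanged."""
    return {name: (rng.normal(mean, std), std) for name, (mean, std) in priors.items()}


def toy_fit(priors):
    """Placeholder for a full fit: a fake observable that depends on the prior means."""
    return 10.0 + 0.2 * priors["scale"][0] - 0.1 * priors["slope"][0]


nominal = toy_fit(PRIORS)

# Repeat the "fit" with shifted prior means and record the change in the result
shifted = np.array([toy_fit(shift_prior_means(PRIORS, rng)) for _ in range(200)])
shifts = shifted - nominal

# Aggregate the shifts; the RMS is an assumption, other summaries are possible
systematic = np.sqrt(np.mean(shifts**2))
print(f"nominal = {nominal:.3f}, systematic = {systematic:.3f}")
```

In a real analysis, each call to toy_fit would be a separate fit launched through the scan interface, with the shifted means written into the config.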
Developing Tips
All random variables start out normally distributed. When constructing terms or factors in the model, it is often necessary to transform them to other distributions, which can be done using cumulative distribution functions and quantile functions.
ift.from_random is often used to probe what the forward model is doing.
Example: code for drawing from a half-normal distribution
from jax.scipy.special import erfinv
from jax.scipy.stats.norm import cdf as normal_cdf
import nifty8 as ift
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp

scale = 10  # arbitrary standard deviation (of the underlying normal)
domain = ift.RGSpace(shape=(1,), distances=(1.0,), harmonic=False)

def fun(x):
    # Map the standard-normal input to a uniform variable on (0, 1)
    uniform_rv = normal_cdf(x)
    # see https://en.wikipedia.org/wiki/Normal_distribution for Quantile function,
    # which uses erfinv(2p-1) where p is uniform. Using erfinv(p) instead
    # makes the result half-normal instead of normal
    return scale * jnp.sqrt(2.0) * erfinv(uniform_rv)

# Wrap the transform as a NIFTy operator and histogram draws pushed through it
fun = ift.JaxOperator(domain=domain, target=domain, func=fun)
draws = np.array([fun(ift.from_random(domain)).val for _ in range(10000)]).flatten()
plt.hist(draws, bins=100)
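The same two-step recipe (normal CDF to reach a uniform variable, then the target's quantile function) applies to other distributions. The sketch below uses only the Python standard library so it can be checked without NIFTy or JAX. The exponential target and the helper names are illustrative assumptions, not part of iftpwa. The half-normal helper uses the identity sqrt(2) * erfinv(p) = inv_cdf((1 + p) / 2) for the standard normal, which is equivalent to the erfinv form above.

```python
import math
import random
from statistics import NormalDist, fmean

STANDARD_NORMAL = NormalDist(mu=0.0, sigma=1.0)


def normal_to_exponential(x, rate):
    """Map a standard-normal draw x to an Exponential(rate) draw."""
    u = STANDARD_NORMAL.cdf(x)     # step 1: normal -> uniform on (0, 1)
    return -math.log1p(-u) / rate  # step 2: exponential quantile function


def normal_to_half_normal(x, scale):
    """Map a standard-normal draw x to a half-normal draw with the given scale."""
    u = STANDARD_NORMAL.cdf(x)
    # Equivalent to scale * sqrt(2) * erfinv(u)
    return scale * STANDARD_NORMAL.inv_cdf((1.0 + u) / 2.0)


random.seed(0)
xs = [random.gauss(0.0, 1.0) for _ in range(20000)]

exp_draws = [normal_to_exponential(x, rate=2.0) for x in xs]
half_draws = [normal_to_half_normal(x, scale=10.0) for x in xs]

# Compare sample means with the analytic means 1/rate and scale * sqrt(2/pi)
print(fmean(exp_draws), 1.0 / 2.0)
print(fmean(half_draws), 10.0 * math.sqrt(2.0 / math.pi))
```

Inside a forward model the same steps would be written with JAX functions so that they remain differentiable.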