import numpy as np
import matplotlib.pyplot as plt

from beast.physicsmodel.prior_weights_dust import PriorWeightsDust

# Overlay the f_A prior weights produced by each dust prior model that is
# passed to PriorWeightsDust, so the model shapes can be compared directly.
fig, ax = plt.subplots()

# Linearly spaced f_A grid: 200 points spanning [0, 1], endpoints included.
fAs = np.linspace(0.0, 1.0, num=200)

# One specification per curve. The "name" entry selects the model and is
# also reused as the legend label below.
dust_prior_models = [
    dict(name="flat"),
    dict(name="lognormal", max_pos=0.8, sigma=0.1),
    dict(
        name="two_lognormal",
        max_pos1=0.2,
        max_pos2=0.8,
        sigma1=0.1,
        sigma2=0.2,
        N1_to_N2=2.0 / 5.0,
    ),
]

for dmod in dust_prior_models:
    # The first four arguments are held fixed at single-point grids with
    # flat models, so only the f_A model (last argument) changes between
    # curves. NOTE(review): these are presumably the A(V) and R(V) grids and
    # models; confirm against the PriorWeightsDust signature.
    dmodel = PriorWeightsDust(
        [1.0],
        {"name": "flat"},
        [1.0],
        {"name": "flat"},
        fAs,
        dmod,
    )
    ax.plot(fAs, dmodel.fA_priors, label=dmod["name"])

ax.set(ylabel="probability", xlabel=r"$f_A$")
ax.legend(loc="best")
plt.tight_layout()
plt.show()