import numpy as np
import matplotlib.pyplot as plt

from beast.physicsmodel.prior_weights_dust import PriorWeightsDust

# Compare the A(V) prior weights produced by several dust prior models.
# Every model is evaluated on the same A(V) grid and drawn on one set of axes.

# A(V) grid: 200 linearly spaced values from 0.0 to 10.0 (np.linspace
# includes both endpoints by default).
avs = np.linspace(0.0, 10.0, num=200)

# One configuration dict per A(V) prior model. Each dict is passed unchanged
# to PriorWeightsDust, and its "name" entry is reused as the legend label.
dust_prior_models = [
    {"name": "flat"},
    {"name": "lognormal", "max_pos": 2.0, "sigma": 1.0},
    {
        "name": "two_lognormal",
        "max_pos1": 0.2,
        "max_pos2": 2.0,
        "sigma1": 1.0,
        "sigma2": 0.5,
        "N1_to_N2": 1.0 / 5.0,
    },
    {"name": "exponential", "a": 1.0},
]

fig, ax = plt.subplots()

for prior_spec in dust_prior_models:
    # NOTE(review): the trailing positional arguments look like two further
    # grids/models (presumably R(V) and f_A), each pinned to the single value
    # 1.0 with a flat prior so that only the A(V) model varies between curves.
    # Confirm against the PriorWeightsDust signature.
    weights = PriorWeightsDust(
        avs,
        prior_spec,
        [1.0],
        {"name": "flat"},
        [1.0],
        {"name": "flat"},
    )
    ax.plot(avs, weights.av_priors, label=prior_spec["name"])

ax.set_xlabel("A(V)")
ax.set_ylabel("probability")
ax.legend(loc="best")
plt.tight_layout()
plt.show()