import numpy as np
import matplotlib.pyplot as plt

from beast.physicsmodel.prior_weights_stars import compute_age_prior_weights

# Compare every supported age prior model by evaluating each one on a common
# log(age) grid and overlaying the resulting weights on a single axis.


def _binned_model(kind):
    """Return a new bin-based age prior specification named ``kind``.

    The histogram and interpolated variants use the same bin log(ages) and
    bin values; a brand-new dict (with brand-new lists) is built on every call
    so the two specifications never share mutable state.
    """
    return {
        "name": kind,
        "logages": [6.0, 7.0, 8.0, 9.0, 10.0],
        "values": [1.0, 2.0, 1.0, 5.0, 3.0],
    }


# Each entry is the model dict handed to compute_age_prior_weights; its
# "name" key doubles as the legend label.
age_prior_models = [
    {"name": "flat"},
    {"name": "flat_log"},
    _binned_model("bins_histo"),
    _binned_model("bins_interp"),
    {"name": "exp", "tau": 0.1},
]

# log(age) grid from 1 Myr (10**6 yr) to 10 Gyr (10**10 yr); 50 samples is
# the np.linspace default, spelled out here for clarity.
logages = np.linspace(6.0, 10.0, num=50)

fig, ax = plt.subplots()

for model in age_prior_models:
    weights = compute_age_prior_weights(logages, model)
    ax.plot(logages, weights, label=model["name"])

ax.set_xlabel("log(age)")
ax.set_ylabel("probability")
ax.legend(loc="best")
plt.tight_layout()
plt.show()