Skip to content

Bayesian Inference over SGAM PFT Parameters

This notebook demonstrates Bayesian inference over two plant functional type (PFT) parameters of the SGAM vegetation model:

  • lue_max: maximum light use efficiency (gC MJ⁻¹), controlling photosynthetic carbon gain
  • leaf_turnover_rate: weekly leaf turnover rate (week⁻¹), controlling leaf lifespan

We use the Metropolis-Hastings algorithm to recover these parameters from synthetic observations of the weekly leaf carbon pool, generated by running the model with its default grassland parameter values.

import atexit
import shutil
import tempfile
import tomllib
from pathlib import Path

import marimo as mo
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import minimize

from satterc import build_driver
from satterc.config import Config
from satterc.setup_utils.data_gen import generate_synthetic_data

Pipeline configuration

We run the full SPLASH → P-model → SGAM chain. SPLASH provides soil moisture to P-model, whose LUE and GPP outputs feed into SGAM. No outputs are written to disk — all results are kept in memory.

# Pipeline description in TOML: the three model modules (splash, pmodel, sgam),
# the daily / weekly / static input groups with their variable lists, and the
# daily-to-weekly resampling of selected variables.
_config_toml = """
modules = [
  "models.splash",
  "models.pmodel",
  "models.sgam",
  "inputs.daily",
  "inputs.weekly",
  "inputs.static",
  "resample",
]

[models.pmodel]
method_kphio = "sandoval"
method_optchi = "lavergne20_c3"

[inputs.daily]
path = "daily.nc"
vars = [
  "precipitation_mm",
  "sunshine_fraction",
  "temperature_celcius",
  "lai",
  "gpp",
]

[inputs.weekly]
path = "weekly.nc"
vars = [
  "co2_ppm",
  "fapar",
  "ppfd_umol_m2_s1",
  "pressure_pa",
  "vpd_pa",
]

[inputs.static]
path = "static.nc"
vars = [
  "elevation",
  "plant_type",
  "max_soil_moisture",
  "leaf_pool_init",
  "stem_pool_init",
  "root_pool_init",
]

[resample]
daily_to_weekly = [
  "temperature_celcius",
  "soil_moisture",
  "aridity_index",
]
daily_to_monthly = []
weekly_to_monthly = []
"""

# Decode the TOML text into a plain dict first, then hand it to Config for
# parsing; the bare name on the last line displays the parsed result.
_raw_config = tomllib.loads(_config_toml)
parsed_config = Config(_raw_config).parse()
parsed_config
# Scratch directory that holds the synthetic NetCDF inputs. mkdtemp() never
# removes what it creates, so register a best-effort cleanup. It is deferred
# to interpreter exit because the driver below is configured with paths inside
# this directory and may read them at any later point.
_tmpdir = Path(tempfile.mkdtemp())
atexit.register(shutil.rmtree, _tmpdir, ignore_errors=True)

# Point the daily / weekly / static input paths at files in the scratch dir.
for _cadence in ("daily", "weekly", "static"):
    parsed_config["driver_config"][f"{_cadence}_inputs_path"] = str(
        _tmpdir / f"{_cadence}.nc"
    )

# Single grid cell, 730 days, fixed seed for reproducibility.
# NOTE(review): presumably writes the three input files at the paths set
# above — confirm against generate_synthetic_data.
generate_synthetic_data(config=parsed_config, grid=(1, 1), n_days=730, seed=42)

# Driver used for the full pipeline run and for every later SGAM-only re-run.
dr = build_driver(
    modules=parsed_config["modules"],
    config=parsed_config["driver_config"],
)

Generating synthetic observations

We run the full pipeline once to get the "true" leaf pool time series, then add Gaussian noise to create synthetic observations. We also cache all direct inputs to the SGAM node — so that each MCMC step only needs to re-execute SGAM itself, rather than the entire SPLASH → P-model chain.

# Variables requested from the driver in a single pass. The final entry,
# "leaf_pool_weekly", is the quantity we observe; everything before it is
# cached and later fed back to the driver as overrides.
_SGAM_INPUTS = [
    "gpp_weekly",
    "lue_weekly",
    "iwue_weekly",
    "temperature_celcius_weekly",
    "soil_moisture_weekly",
    "vpd_pa_weekly",
    "disturbances_weekly",
    "dates_weekly",
    "leaf_pool_init",
    "stem_pool_init",
    "root_pool_init",
    "plant_type",
    "pft_params",
    "leaf_pool_weekly",
]
_all_outputs = dr.execute(_SGAM_INPUTS)

# Cache of every requested variable except the observed one. Supplying these
# as overrides means later runs only need to recompute the sgam node, not the
# upstream SPLASH/P-model chain.
upstream = {}
for _name in _SGAM_INPUTS:
    if _name == "leaf_pool_weekly":
        continue
    upstream[_name] = _all_outputs[_name]

# Parameter values the pipeline actually ran with: the "truth" to recover.
_pft = _all_outputs["pft_params"]
true_lue_max = float(_pft["lue_max"].values[0])
true_leaf_turnover = float(_pft["leaf_turnover_rate"].values[0])

# Observations = first column of the modelled leaf pool plus N(0, 2.0) noise.
# The global NumPy RNG is seeded here so the draw is reproducible.
np.random.seed(42)
_true_leaf_pool = _all_outputs["leaf_pool_weekly"].values[:, 0]
_noise = np.random.normal(0, 2.0, _true_leaf_pool.shape)
synthetic_obs = _true_leaf_pool + _noise

MLE warm-start

We run a quick Nelder-Mead optimisation to find the maximum-likelihood parameter values before starting the MCMC chain. Initialising at the MLE means the chain starts in the high-probability region from the very first step, removing the need for a long burn-in period and making fixed step sizes much easier to calibrate.

Two features of the optimisation results are worth noting:

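  • The MSE does not converge to zero. Its floor is set by σ² = 4, the variance of the noise added to the synthetic observations: even with the exact true parameters, the model cannot fit the noise away, so the minimised MSE settles close to 4 (it can dip slightly below, because the two fitted parameters absorb a little of the noise).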
  • The MLE does not recover the exact true values. With a finite, noisy sample the optimizer finds the parameters that best explain the noisy signal, not the true one. This is the motivation for the MCMC step: rather than a point estimate, we want a posterior distribution that honestly reflects the residual uncertainty.
# Deliberately begin away from the grassland defaults (≈ 3.0, 0.035) so the
# optimiser has real work to do instead of confirming its starting point.
_start = [4.0, 0.02]
_start_mse = float(objective_function(_start, dr, synthetic_obs, upstream))

# One row per recorded iterate: [lue_max, leaf_turnover_rate, MSE].
opt_history = [[*_start, _start_mse]]


def _record_iterate(xk):
    """Append the optimiser's current point and its objective value."""
    _mse = float(objective_function(list(xk), dr, synthetic_obs, upstream))
    opt_history.append([float(xk[0]), float(xk[1]), _mse])


_nm_options = {"xatol": 1e-8, "fatol": 1e-8, "maxiter": 2000}
_fit = minimize(
    fun=objective_function,
    x0=_start,
    args=(dr, synthetic_obs, upstream),
    method="Nelder-Mead",
    callback=_record_iterate,
    options=_nm_options,
)

# Maximum-likelihood estimate, kept as a plain list; displayed as cell output.
mle_params = list(_fit.x)
mle_params
# Optimiser trajectory as an array; columns follow the row layout of
# opt_history: lue_max, leaf_turnover_rate, objective value.
_traj = np.array(opt_history)

_fig, _axes = plt.subplots(1, 3, figsize=(16, 4))

# Per-panel settings, in the same order as the columns of _traj:
# (keyword arguments for the horizontal reference line, y-label, title).
_panel_specs = [
    (
        {
            "y": true_lue_max,
            "color": "r",
            "linestyle": "--",
            "label": f"True value ({true_lue_max:.3f})",
        },
        "lue_max (gC MJ⁻¹)",
        "lue_max Convergence",
    ),
    (
        {
            "y": true_leaf_turnover,
            "color": "r",
            "linestyle": "--",
            "label": f"True value ({true_leaf_turnover:.4f})",
        },
        "leaf_turnover_rate (week⁻¹)",
        "leaf_turnover_rate Convergence",
    ),
    (
        {
            "y": 4,
            "color": "k",
            "linestyle": "--",
            "alpha": 0.6,
            "label": "Expected minimum (σ²=4)",
        },
        "MSE",
        "Objective Function Convergence",
    ),
]

for _col, (_ax, (_ref_line, _ylabel, _title)) in enumerate(
    zip(_axes, _panel_specs)
):
    _ax.plot(_traj[:, _col], marker="o", markersize=3)
    _ax.axhline(**_ref_line)
    _ax.set_xlabel("Iteration")
    _ax.set_ylabel(_ylabel)
    _ax.set_title(_title)
    _ax.legend()
    _ax.grid(True, alpha=0.3)

_fig.tight_layout()
_fig

Bayesian Inference

We use the Metropolis-Hastings algorithm to sample the joint posterior of lue_max and leaf_turnover_rate. Each proposal modifies the cached pft_params dataset and re-executes only the SGAM node, making each step fast.

The likelihood is Gaussian with σ = 2 gC m⁻² and the priors are uniform:

  • lue_max ~ Uniform(2.0, 4.5)
  • leaf_turnover_rate ~ Uniform(0.01, 0.07)

Parameter normalisation and step-size selection

These two parameters differ enormously in their effect on the model: a perturbation of ±0.001 in leaf_turnover_rate (≈ 3 % of the true value) changes the log-likelihood by ~300, whereas the same step in lue_max changes it by less than 1. Sensitivity analysis shows the posterior widths are roughly:

  • lue_max: σ ≈ 0.05 in physical units → 0.02 in normalised units
  • leaf_turnover_rate: σ ≈ 5 × 10⁻⁵ in physical units → 8 × 10⁻⁴ in normalised units

To handle this scale mismatch, each parameter is mapped to [0, 1] by

\[u_i = \frac{\theta_i - \theta_i^{\min}}{\theta_i^{\max} - \theta_i^{\min}}\]

Normalisation by prior range maps prior widths to [0, 1], but the posterior widths in normalised space are still very different (~25×). A single isotropic step size cannot serve both directions well, so we use per-parameter step sizes matched to each posterior width (≈ 1σ per direction):

| Parameter | Posterior σ (normalised) | Step size |
|-----------|--------------------------|-----------|
| lue_max | ≈ 0.02 | 0.02 |
| leaf_turnover_rate | ≈ 8 × 10⁻⁴ | 0.001 |

Each step is drawn as uniform(−step, +step) independently per parameter. Samples are denormalised back to physical units before storage, so the plots below show posterior distributions in their original units.

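Note that with uniform priors the posterior is proportional to the likelihood alone inside the prior bounds — p(θ|y) ∝ p(y|θ) — so the posterior mode (the MAP estimate) coincides with the MLE. The posterior mean is a different summary: it will lie close to the MLE only when the posterior is approximately symmetric and not truncated by the prior bounds.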

def objective_function(params, dr, observations, upstream):
    """Return the mean squared error between modelled and observed leaf pool.

    Parameters
    ----------
    params
        Two-element sequence ``(lue_max, leaf_turnover_rate)``.
    dr
        Driver exposing ``execute(final_vars=..., overrides=...)``.
    observations
        Observed values compared against column 0 of the modelled
        ``"leaf_pool_weekly"`` output.
    upstream
        Mapping of cached driver inputs, passed back as overrides. Must
        contain ``"pft_params"``; it is left unmodified.

    Returns
    -------
    The mean of the squared residuals (a NumPy scalar).
    """
    lue_max, leaf_turnover = params

    # Deep copy on purpose. A default copy() can be shallow and share its
    # arrays with the cached dataset, in which case the in-place writes below
    # would silently overwrite the cached parameters in ``upstream``.
    # NOTE(review): assumes pft_params.copy() accepts ``deep=`` (xarray and
    # pandas both do) — confirm the concrete type.
    modified_pft = upstream["pft_params"].copy(deep=True)
    modified_pft["lue_max"].values[:] = lue_max
    modified_pft["leaf_turnover_rate"].values[:] = leaf_turnover

    # Reuse every cached input unchanged, swapping in only the edited params.
    overrides = {**upstream, "pft_params": modified_pft}
    outputs = dr.execute(final_vars=["leaf_pool_weekly"], overrides=overrides)
    modelled = outputs["leaf_pool_weekly"].values[:, 0]
    return np.mean((modelled - observations) ** 2)
def make_log_posterior(dr, synthetic_obs, upstream, prior_bounds, likelihood_sigma):
    """Return ``(log_posterior, normalize, denormalize)``.

    ``log_posterior`` takes parameters in normalised ``[0, 1]^d`` space and
    returns the log of a uniform prior on the unit hypercube plus an i.i.d.
    Gaussian log-likelihood with standard deviation ``likelihood_sigma``.
    Outside the hypercube it returns ``-inf`` without running the model.
    ``normalize`` / ``denormalize`` convert between physical and normalised
    space using the prior bounds.

    Normalising by the prior range puts every parameter's *prior* on the same
    ``[0, 1]`` scale. It does not equalise *posterior* widths, so proposal
    step sizes may still need to be chosen per parameter.

    Parameters
    ----------
    dr, synthetic_obs, upstream
        Forwarded unchanged to ``objective_function``; ``len(synthetic_obs)``
        is used as the number of observations.
    prior_bounds
        Sequence of ``(low, high)`` pairs, one per parameter.
    likelihood_sigma
        Standard deviation of the Gaussian observation noise.

    Raises
    ------
    ValueError
        If ``likelihood_sigma`` is not positive or any bound has
        ``high <= low`` (which would make the normalisation divide by zero).
    """
    if likelihood_sigma <= 0:
        raise ValueError("likelihood_sigma must be positive")

    lower = np.array([b[0] for b in prior_bounds])
    upper = np.array([b[1] for b in prior_bounds])
    span = upper - lower
    if np.any(span <= 0):
        raise ValueError("each prior bound must satisfy low < high")

    sigma = likelihood_sigma
    n_obs = len(synthetic_obs)
    # Parameter-independent normalising term of the Gaussian log-density.
    log_norm_const = -(n_obs / 2) * np.log(2 * np.pi * sigma**2)

    def normalize(params):
        """Map physical params θ → u ∈ [0, 1]^d."""
        return (np.array(params) - lower) / span

    def denormalize(u):
        """Map normalised u ∈ [0, 1]^d → physical params θ."""
        return lower + np.array(u) * span

    def log_likelihood(u):
        # The model works in physical units, so denormalise first.
        # Sum of squared residuals = n_obs * MSE, hence the n_obs factor.
        mse = objective_function(denormalize(u), dr, synthetic_obs, upstream)
        return -(n_obs / (2 * sigma**2)) * mse + log_norm_const

    def log_prior(u):
        # A uniform prior in physical space is Uniform[0, 1] after
        # normalisation. NaN entries fail both comparisons and give -inf.
        u = np.array(u)
        if np.all((u >= 0) & (u <= 1)):
            return 0.0
        return -np.inf

    def log_posterior(u):
        prior = log_prior(u)
        if prior == -np.inf:
            # Zero prior mass: skip the model run, which would be wasted and
            # could turn the sum into -inf + nan for out-of-range parameters.
            return -np.inf
        return prior + log_likelihood(u)

    return log_posterior, normalize, denormalize
# Uniform prior ranges, in order (lue_max, leaf_turnover_rate).
_prior_bounds = [(2.0, 4.5), (0.01, 0.07)]
n_iterations = 500
burn_in = 200

log_posterior, _normalize, _denormalize = make_log_posterior(
    dr,
    synthetic_obs,
    upstream,
    prior_bounds=_prior_bounds,
    likelihood_sigma=2.0,
)

# Per-parameter half-widths of the uniform proposal, in normalised [0, 1]^2
# space, each matched to the corresponding posterior width (≈ 1σ):
#   lue_max:            σ ≈ 0.02 normalised  →  step 0.02
#   leaf_turnover_rate: σ ≈ 8e-4 normalised  →  step 0.001
_step_sizes = np.array([0.02, 0.001])

# Warm-start from the MLE. The optimiser ran without bounds, so clip into the
# prior box: a start with zero prior mass has log-posterior -inf, which makes
# the acceptance ratio meaningless. Clipping is a no-op for an interior MLE.
current = list(np.clip(_normalize(mle_params), 0.0, 1.0))

# Log-posterior of the current state. It only changes on acceptance, so it is
# cached rather than recomputed each iteration; every evaluation is a model
# run, so this halves the cost of the loop.
_current_log_post = log_posterior(current)

# mcmc_history stores physical (denormalised) params for readable plots.
# Entry 0 is the starting state; entry k is the state after k proposals.
mcmc_history = [list(_denormalize(current))]
accepted = 0

for _i in range(burn_in + n_iterations):
    _proposed = list(
        np.array(current) + np.random.uniform(-_step_sizes, _step_sizes, 2)
    )
    _proposed_log_post = log_posterior(_proposed)

    # The proposal is symmetric, so the Metropolis-Hastings log acceptance
    # ratio reduces to the difference of log-posteriors.
    _log_alpha = _proposed_log_post - _current_log_post

    if np.log(np.random.uniform()) < _log_alpha:
        current = _proposed
        _current_log_post = _proposed_log_post
        # Only acceptances after burn-in count towards the reported rate.
        if _i >= burn_in:
            accepted += 1

    mcmc_history.append(list(_denormalize(current)))

acceptance_rate = accepted / n_iterations
acceptance_rate
# Full chain, burn-in included; columns are lue_max and leaf_turnover_rate.
_samples = np.array(mcmc_history)

_fig, _axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)

# One trace panel per parameter:
# (true value, legend label for the true value, y-axis label).
_trace_specs = [
    (
        true_lue_max,
        f"True value ({true_lue_max:.2f})",
        "lue_max (gC MJ⁻¹)",
    ),
    (
        true_leaf_turnover,
        f"True value ({true_leaf_turnover:.4f})",
        "leaf_turnover_rate (week⁻¹)",
    ),
]

for _col, (_ax, (_truth, _truth_label, _ylabel)) in enumerate(
    zip(_axes, _trace_specs)
):
    _ax.plot(_samples[:, _col], lw=0.8)
    # Vertical marker at the end of the burn-in phase.
    _ax.axvline(x=burn_in, color="r", linestyle="--", label=f"Burn-in ({burn_in})")
    _ax.axhline(y=_truth, color="g", linestyle=":", label=_truth_label)
    _ax.set_ylabel(_ylabel)
    _ax.legend()
    _ax.grid(True, alpha=0.3)

# The x-axis is shared: title on the top panel only, x-label on the bottom.
_axes[0].set_title("MCMC Trace")
_axes[1].set_xlabel("Iteration")

_fig.tight_layout()
_fig
# Chain entries from index burn_in onward are treated as posterior draws.
_posterior = np.array(mcmc_history)[burn_in:]

_fig, _axes = plt.subplots(1, 3, figsize=(14, 4))
_joint_ax, *_marginal_axes = _axes

# Left panel: scatter of the joint draws with both true values marked.
_joint_ax.scatter(_posterior[:, 0], _posterior[:, 1], alpha=0.4, s=8)
_joint_ax.axvline(
    x=true_lue_max,
    color="g",
    linestyle=":",
    label=f"True lue_max ({true_lue_max:.2f})",
)
_joint_ax.axhline(
    y=true_leaf_turnover,
    color="r",
    linestyle=":",
    label=f"True leaf_turnover ({true_leaf_turnover:.4f})",
)
_joint_ax.set_xlabel("lue_max (gC MJ⁻¹)")
_joint_ax.set_ylabel("leaf_turnover_rate (week⁻¹)")
_joint_ax.set_title("Joint Posterior")
_joint_ax.legend(fontsize=8)
_joint_ax.grid(True, alpha=0.3)

# Remaining panels: one marginal histogram per parameter.
# (name used in the title, x-axis label, true value, number format for labels)
_marginal_specs = [
    ("lue_max", "lue_max (gC MJ⁻¹)", true_lue_max, ".2f"),
    (
        "leaf_turnover_rate",
        "leaf_turnover_rate (week⁻¹)",
        true_leaf_turnover,
        ".4f",
    ),
]

for _col, (_ax, (_name, _xlabel, _truth, _fmt)) in enumerate(
    zip(_marginal_axes, _marginal_specs)
):
    _draws = _posterior[:, _col]
    _draw_mean = np.mean(_draws)
    _ax.hist(_draws, bins=25, density=True, alpha=0.7)
    _ax.axvline(
        x=_truth,
        color="g",
        linestyle=":",
        label=f"True ({_truth:{_fmt}})",
    )
    _ax.axvline(
        x=_draw_mean,
        color="b",
        linestyle="--",
        label=f"Mean ({_draw_mean:{_fmt}})",
    )
    _ax.set_xlabel(_xlabel)
    _ax.set_ylabel("Density")
    _ax.set_title(f"Marginal: {_name}")
    _ax.legend(fontsize=8)
    _ax.grid(True, alpha=0.3)

_fig.tight_layout()
_fig
# Same slice of the chain as the posterior plots: entries from burn_in onward.
_posterior_summary = np.array(mcmc_history)[burn_in:]
_lue_draws = _posterior_summary[:, 0]
_turnover_draws = _posterior_summary[:, 1]

# Markdown table comparing the true values with posterior mean and std.
mo.md(f"""
### Inference Results

| Parameter | True | Posterior Mean | Posterior Std |
|-----------|------|----------------|---------------|
| lue_max (gC MJ⁻¹) | {true_lue_max:.3f} | {np.mean(_lue_draws):.3f} | {np.std(_lue_draws):.3f} |
| leaf_turnover_rate (week⁻¹) | {true_leaf_turnover:.4f} | {np.mean(_turnover_draws):.4f} | {np.std(_turnover_draws):.4f} |
""")