"""Check a custom :py:class:`Model <eulerpi.core.model.Model>` for implementation errors or test them in a quick inference run on an artificially created dataset."""
import jax.numpy as jnp
import numpy as np
from jax import vmap
from eulerpi.core.inference import InferenceType, inference
from eulerpi.core.model import JaxModel, Model
from eulerpi.core.plotting import sample_violin_plot
[docs]
def basic_model_check(model: Model) -> None:
"""Perform a simple sanity check on the model.
It tests the following:
- The model has a positive parameter dimension
- The model has a positive data dimension
- The model has a valid combination of parameter and data dimension
- The central parameter has the correct shape
- The parameter limits have the correct shape
- The model can be instantiated
- The model forward pass can be calculated
- The model jacobi matrix can be calculated
- The return values of the forward pass and the jacobi matrix have the correct shape
- The jacobi matrix has full rank
Args:
model(Model): The model describing the mapping from parameters to data.
Raises:
AssertionError: Raised if any of the tests fails.
Returns:
None
Examples:
.. code-block:: python
from eulerpi.examples.corona import Corona
from eulerpi.core.model_check import basic_model_check
basic_model_check(Corona())
"""
print(
f"Checking model {model.name} at location \n{model} \nfor basic functionality.\n"
)
# test the shapes
assert (
model.param_dim > 0
), f"Model {model} has a non-positive parameter dimension"
assert (
model.data_dim > 0
), f"Model {model} has a non-positive data dimension"
assert model.data_dim >= model.param_dim, (
f"Model {model} has a data dimension smaller than the parameter dimension. "
"This is not supported by the inference."
)
assert model.central_param.shape == (model.param_dim,), (
f"Model {model} has a central parameter with the wrong shape. "
f"Expected {(model.param_dim,)}, got {model.central_param.shape}"
)
assert model.param_limits.shape == (model.param_dim, 2), (
f"Model {model} has parameter limits with the wrong shape. "
f"Expected {(model.param_dim, 2)}, got {model.param_limits.shape}"
)
print("Successfully checked shapes and dimensions of model attributes.\n")
print(
f"Evaluate model {model.name} and its jacobian in its central parameter \n{model.central_param}."
)
model_forward = model.forward(model.central_param)
assert (
model_forward.shape == (1, model.data_dim)
or model_forward.shape == (model.data_dim,)
or model_forward.shape == ()
), (
f"Model {model} has a forward function with the wrong shape. "
f"Expected {(1, model.data_dim)}, {(model.data_dim,)} or {()}, got {model_forward.shape}"
)
model_jac = model.jacobian(model.central_param)
assert (
model_jac.shape == (model.data_dim, model.param_dim)
or (model.data_dim == 1 and model_jac.shape == (model.param_dim,))
or (model.param_dim == 1 and model_jac.shape == (model.data_dim,))
), (
f"Model {model} has a jacobian function with the wrong shape. "
f"Expected {(model.data_dim, model.param_dim)}, {(model.param_dim,)} or {(model.data_dim,)}, got {model_jac.shape}"
)
# check rank of jacobian
assert jnp.linalg.matrix_rank(model_jac) == model.param_dim, (
f"The Jacobian of the model {model} does not have full rank. This is a requirement for the inference. "
"Please check the model implementation."
)
fw, jc = model.forward_and_jacobian(model.central_param)
assert fw.shape == model_forward.shape, (
f"The shape {fw.shape} of the forward function extracted from the forward_and_jacobian function does not match the shape {model_forward.shape} of the forward function. "
"Please check the model implementation."
)
assert jc.shape == model_jac.shape, (
f"The shape {jc.shape} of the jacobian extracted from the forward_and_jacobian function does not match the shape {model_jac.shape} of the jacobian. "
"Please check the model implementation."
)
assert jnp.allclose(fw, model_forward), (
f"The forward function of the model {model} does not match the forward function extracted from the forward_and_jacobian function. "
"Please check the model implementation."
)
assert jnp.allclose(jc, model_jac), (
f"The jacobian of the model {model} does not match the jacobian extracted from the forward_and_jacobian function. "
"Please check the model implementation."
)
print(
"Successfully checked model forward simulation and corresponding jacobian.\n"
)
[docs]
def inference_model_check(
model: Model,
num_data_points: int = 1000,
num_model_evaluations: int = 11000,
) -> None:
"""Check your model in a quick inference run on an artificially created dataset.
It produces a violin plot comparing the artificially created parameters and data to the respectively inferred samples.
Args:
model(Model): The model describing the mapping from parameters to data.
num_data_points (int, optional): The number of data data points to artificially generate (Default value = 1000)
num_model_evaluations (int, optional): The number of model evaluations to perform in the inference. (Default value = 11000)
Returns:
None
Examples:
.. code-block:: python
from eulerpi.examples.corona import Corona
from eulerpi.core.model_check import inference_model_check
inference_model_check(Corona())
"""
print(
f"Checking model {model.name} at location \n{model} \nfor inference functionality on artificially created data.\n"
)
# create artificial parametrs similar to how we create initial walker positions for emcee sampling
central_param = model.central_param
param_limits = model.param_limits
# sample parameters from a uniform distribution around the central parameter and between the parameter limits
d_min = np.minimum(
central_param - param_limits[:, 0], param_limits[:, 1] - central_param
)
param_sample = central_param + d_min * (
(np.random.rand(num_data_points, model.param_dim) - 0.5) / 3.0
)
# try to use jax vmap to perform the forward pass on multiple parameters at once
if isinstance(model, JaxModel):
data_sample = vmap(model.forward, in_axes=0)(param_sample)
else:
data_sample = np.vectorize(model.forward, signature="(n)->(m)")(
param_sample
)
print(
f"Successfully created an artificial data set of size {num_data_points}.\n"
)
# choose sensible values for the sampling hyper-parameters and print them
num_inference_evaluations = num_model_evaluations - num_data_points
num_walkers = int(np.sqrt(num_inference_evaluations / 10))
num_steps = int(num_inference_evaluations / num_walkers)
num_burn_in_samples = num_walkers
thinning_factor = int(np.ceil(num_walkers / 10))
print("Attempting inference with hyperparameters chosen as follows:")
print(f"num_data_points: {num_data_points}")
print(f"num_walkers: {num_walkers}")
print(f"num_steps: {num_steps}")
print(f"num_burn_in_samples: {num_burn_in_samples}")
print(f"thinning_factor: {thinning_factor}")
run_name = "test_model_run"
# perform the inference
inference(
model,
data=data_sample,
inference_type=InferenceType.MCMC,
slices=[np.arange(model.param_dim)],
run_name=run_name,
num_runs=1,
num_walkers=num_walkers,
num_steps=num_steps,
num_burn_in_samples=num_burn_in_samples,
thinning_factor=thinning_factor,
)
print(
f"Successfully finishes inference run with {num_walkers*num_steps} samples.\n"
)
# plot the results
sample_violin_plot(
model,
reference_sample=param_sample,
run_name=run_name,
credibility_level=0.999,
what_to_plot="param",
)
sample_violin_plot(
model,
reference_sample=data_sample,
run_name=run_name,
credibility_level=0.999,
what_to_plot="data",
)
[docs]
def full_model_check(
model: Model,
num_data_points: int = 1000,
num_model_evaluations: int = 11000,
) -> None:
"""Perform all available checks on the model.
Check your model for basic functionality and in a quick inference run on an artificially created dataset.
We recommend to run this function for every new model you create.
It runs the functions :py:func:`basic_model_check <basic_model_check>` and :py:func:`inference_model_check <inference_model_check>` to perform the checks.
Args:
model(Model): The model describing the mapping from parameters to data.
num_data_points (int, optional): The number of data data points to artificially generate (Default value = 1000)
num_model_evaluations (int, optional): The number of model evaluations to perform in the inference. (Default value = 11000)
Raises:
AssertionError: Raised if any of the tests fails.
Returns:
None
Examples:
.. code-block:: python
from eulerpi.examples.corona import Corona
from eulerpi.core.model_check import full_model_check
full_model_check(Corona())
"""
basic_model_check(model)
inference_model_check(model, num_data_points, num_model_evaluations)