Source code for eulerpi.core.model_check

"""Check a custom :py:class:`BaseModel <eulerpi.core.models.BaseModel>` for implementation errors or test them in a quick inference run on an artificially created dataset."""

import jax.numpy as jnp
import numpy as np

from eulerpi.core.inference import InferenceType, inference
from eulerpi.core.models import BaseModel
from eulerpi.core.plotting import sample_violin_plot


def basic_model_check(model: "BaseModel") -> None:
    """Perform a simple sanity check on the model.

    It tests the following:

    - The model has a positive parameter dimension
    - The model has a positive data dimension
    - The model has a valid combination of parameter and data dimension
    - The central parameter has the correct shape
    - The parameter limits have the correct shape
    - The model forward pass can be calculated
    - The model jacobi matrix can be calculated
    - The return values of the forward pass and the jacobi matrix have the correct shape
    - The jacobi matrix has full rank
    - The results of ``forward_and_jacobian`` agree with the separate
      ``forward`` and ``jacobian`` calls in shape and value

    All evaluations are performed in the central parameter of the model.

    Args:
        model(BaseModel): The model describing the mapping from parameters to data.

    Raises:
        AssertionError: Raised if any of the tests fails.

    Returns:
        None

    Examples:

    .. code-block:: python

        from eulerpi.examples.corona import Corona
        from eulerpi.core.model_check import basic_model_check

        basic_model_check(Corona())

    """
    print(
        f"Checking model {model.name} at location \n{model} \nfor basic functionality.\n"
    )

    # test the shapes
    assert (
        model.param_dim > 0
    ), f"Model {model} has a non-positive parameter dimension"
    assert (
        model.data_dim > 0
    ), f"Model {model} has a non-positive data dimension"
    assert model.data_dim >= model.param_dim, (
        f"Model {model} has a data dimension smaller than the parameter dimension. "
        "This is not supported by the inference."
    )
    assert model.central_param.shape == (model.param_dim,), (
        f"Model {model} has a central parameter with the wrong shape. "
        f"Expected {(model.param_dim,)}, got {model.central_param.shape}"
    )
    assert model.param_limits.shape == (model.param_dim, 2), (
        f"Model {model} has parameter limits with the wrong shape. "
        f"Expected {(model.param_dim, 2)}, got {model.param_limits.shape}"
    )

    print("Successfully checked shapes and dimensions of model attributes.\n")
    print(
        f"Evaluate model {model.name} and its jacobian in its central parameter \n{model.central_param}."
    )

    model_forward = model.forward(model.central_param)
    assert (
        model_forward.shape == (1, model.data_dim)
        or model_forward.shape == (model.data_dim,)
        # A scalar can only represent one-dimensional data; accepting it for
        # data_dim > 1 would let a model that collapses its output slip through.
        or (model.data_dim == 1 and model_forward.shape == ())
    ), (
        f"Model {model} has a forward function with the wrong shape. "
        f"Expected {(1, model.data_dim)}, {(model.data_dim,)} or {()} "
        f"(scalar output only for data_dim == 1), got {model_forward.shape}"
    )

    model_jac = model.jacobian(model.central_param)
    # Flattened jacobians are tolerated when one of the two dimensions is 1.
    assert (
        model_jac.shape == (model.data_dim, model.param_dim)
        or (model.data_dim == 1 and model_jac.shape == (model.param_dim,))
        or (model.param_dim == 1 and model_jac.shape == (model.data_dim,))
    ), (
        f"Model {model} has a jacobian function with the wrong shape. "
        f"Expected {(model.data_dim, model.param_dim)}, {(model.param_dim,)} or {(model.data_dim,)}, got {model_jac.shape}"
    )

    # check rank of jacobian
    assert jnp.linalg.matrix_rank(model_jac) == model.param_dim, (
        f"The Jacobian of the model {model} does not have full rank. This is a requirement for the inference. "
        "Please check the model implementation."
    )

    # The combined evaluation has to agree with the two separate evaluations.
    fw, jc = model.forward_and_jacobian(model.central_param)
    assert fw.shape == model_forward.shape, (
        f"The shape {fw.shape} of the forward function extracted from the forward_and_jacobian function does not match the shape {model_forward.shape} of the forward function. "
        "Please check the model implementation."
    )
    assert jc.shape == model_jac.shape, (
        f"The shape {jc.shape} of the jacobian extracted from the forward_and_jacobian function does not match the shape {model_jac.shape} of the jacobian. "
        "Please check the model implementation."
    )
    assert jnp.allclose(fw, model_forward), (
        f"The forward function of the model {model} does not match the forward function extracted from the forward_and_jacobian function. "
        "Please check the model implementation."
    )
    assert jnp.allclose(jc, model_jac), (
        f"The jacobian of the model {model} does not match the jacobian extracted from the forward_and_jacobian function. "
        "Please check the model implementation."
    )

    print(
        "Successfully checked model forward simulation and corresponding jacobian.\n"
    )
def inference_model_check(
    model: "BaseModel",
    num_data_points: int = 1000,
    num_model_evaluations: int = 11000,
) -> None:
    """Check your model in a quick inference run on an artificially created dataset.

    Parameters are sampled around the central parameter of the model and
    pushed through ``model.forward_vectorized`` to obtain an artificial data
    set. An MCMC inference is then run on that data and two violin plots are
    produced, comparing the artificially created parameters and data to the
    respectively inferred samples.

    Args:
        model(BaseModel): The model describing the mapping from parameters to data.
        num_data_points (int, optional): The number of data points to artificially generate (Default value = 1000)
        num_model_evaluations (int, optional): The number of model evaluations to perform in the inference.
            It includes the evaluations spent on generating the data set. (Default value = 11000)

    Raises:
        ValueError: Raised if ``num_data_points`` is smaller than 1 or if fewer than
            10 model evaluations remain for the inference after generating the data.

    Returns:
        None

    Examples:

    .. code-block:: python

        from eulerpi.examples.corona import Corona
        from eulerpi.core.model_check import inference_model_check

        inference_model_check(Corona())

    """
    # Validate the evaluation budget before doing any work. Without this, a
    # too small budget leads to zero walkers (division by zero below) or to
    # the square root of a negative number (NaN cannot be converted to int).
    if num_data_points < 1:
        raise ValueError(
            f"num_data_points must be at least 1, got {num_data_points}."
        )
    num_inference_evaluations = num_model_evaluations - num_data_points
    if num_inference_evaluations < 10:
        raise ValueError(
            "num_model_evaluations must exceed num_data_points by at least 10 "
            "to obtain at least one walker, got "
            f"num_model_evaluations={num_model_evaluations} and "
            f"num_data_points={num_data_points}."
        )

    print(
        f"Checking model {model.name} at location \n{model} \nfor inference functionality on artificially created data.\n"
    )

    # create artificial parameters similar to how we create initial walker positions for emcee sampling
    central_param = model.central_param
    param_limits = model.param_limits

    # sample parameters from a uniform distribution around the central parameter and between the parameter limits:
    # d_min is the distance to the nearest limit per dimension, the samples deviate by at most d_min / 6
    d_min = np.minimum(
        central_param - param_limits[:, 0], param_limits[:, 1] - central_param
    )
    param_sample = central_param + d_min * (
        (np.random.rand(num_data_points, model.param_dim) - 0.5) / 3.0
    )

    data_sample = model.forward_vectorized(param_sample)
    print(
        f"Successfully created an artificial data set of size {num_data_points}.\n"
    )

    # choose sensible values for the sampling hyper-parameters and print them
    num_walkers = int(np.sqrt(num_inference_evaluations / 10))
    num_steps = int(num_inference_evaluations / num_walkers)
    num_burn_in_samples = num_walkers
    thinning_factor = int(np.ceil(num_walkers / 10))

    print("Attempting inference with hyperparameters chosen as follows:")
    print(f"num_data_points: {num_data_points}")
    print(f"num_walkers: {num_walkers}")
    print(f"num_steps: {num_steps}")
    print(f"num_burn_in_samples: {num_burn_in_samples}")
    print(f"thinning_factor: {thinning_factor}")

    run_name = "test_model_run"

    # perform the inference
    inference(
        model,
        data=data_sample,
        inference_type=InferenceType.MCMC,
        slices=[np.arange(model.param_dim)],
        run_name=run_name,
        num_runs=1,
        num_walkers=num_walkers,
        num_steps=num_steps,
        num_burn_in_samples=num_burn_in_samples,
        thinning_factor=thinning_factor,
    )
    print(
        f"Successfully finishes inference run with {num_walkers*num_steps} samples.\n"
    )

    # plot the results
    sample_violin_plot(
        model,
        reference_sample=param_sample,
        run_name=run_name,
        credibility_level=0.999,
        what_to_plot="param",
    )
    sample_violin_plot(
        model,
        reference_sample=data_sample,
        run_name=run_name,
        credibility_level=0.999,
        what_to_plot="data",
    )
def full_model_check(
    model: BaseModel,
    num_data_points: int = 1000,
    num_model_evaluations: int = 11000,
) -> None:
    """Run every available check on the model.

    First the model is checked for basic functionality, then it is tested in
    a quick inference run on an artificially created dataset. We recommend to
    run this function for every new model you create. The work is delegated to
    :py:func:`basic_model_check <basic_model_check>` and
    :py:func:`inference_model_check <inference_model_check>`, in that order.

    Args:
        model(BaseModel): The model describing the mapping from parameters to data.
        num_data_points (int, optional): The number of data points to artificially generate (Default value = 1000)
        num_model_evaluations (int, optional): The number of model evaluations to perform in the inference. (Default value = 11000)

    Raises:
        AssertionError: Raised if any of the tests fails.

    Returns:
        None

    Examples:

    .. code-block:: python

        from eulerpi.examples.corona import Corona
        from eulerpi.core.model_check import full_model_check

        full_model_check(Corona())

    """
    # The cheap structural checks run first, so the inference run is only
    # reached once they have passed.
    basic_model_check(model)
    inference_model_check(
        model,
        num_data_points=num_data_points,
        num_model_evaluations=num_model_evaluations,
    )