Source code for tests.test_slices

"""
Test the slices functionality for each of the inference methods.
"""

import numpy as np
import pytest

from eulerpi.core.inference import InferenceType, inference
from eulerpi.core.model import Model
from eulerpi.examples.corona import CoronaArtificial

# Parametrize the test to run for each inference type


@pytest.mark.parametrize(
    "inference_type",
    # Iterate the enum publicly: this yields each canonical member once
    # (no aliases), so values and ids always have matching lengths.
    list(InferenceType),
    ids=[member.name for member in InferenceType],
)
def test_slices(inference_type):
    """Run ``inference`` with several slices for one inference type.

    A ``CoronaArtificial`` model generates 100 artificial parameter samples
    and the corresponding artificial data. ``inference`` is then called on
    that data with three slices -- ``[0]``, ``[1, 2]`` and ``[0, 1, 2]``.
    The test makes no assertions on the result; it passes as long as
    ``inference`` completes without raising.

    Args:
        inference_type: The ``InferenceType`` member under test, supplied
            by the parametrization.
    """
    model: Model = CoronaArtificial()

    # The data below is produced by the model itself, so the model must
    # support artificial parameter/data generation.
    if not model.is_artificial():
        pytest.fail("This test is only for artificial data")

    num_data_points = 100
    params = model.generate_artificial_params(num_data_points)
    data = model.generate_artificial_data(params)

    # Each slice is an index array; presumably indices into the model's
    # parameter dimensions (verify against `inference`).
    slices = [np.array([0]), np.array([1, 2]), np.array([0, 1, 2])]

    # Cap MCMC at 100 steps, presumably to keep the test fast; the other
    # inference types run with their default settings.
    kwargs = {"num_steps": 100} if inference_type == InferenceType.MCMC else {}

    inference(
        model,
        data,
        inference_type,
        slices=slices,
        **kwargs,
    )