"""
Test the slices functionality for each of the inference methods.
"""
import numpy as np
import pytest
from eulerpi.core.inference import InferenceType, inference
from eulerpi.core.model import Model
from eulerpi.examples.corona import CoronaArtificial
# Parametrize the test to run once for each inference type.
# Iterating the enum (instead of reading the private ``_member_map_`` /
# ``_member_names_`` attributes) yields exactly the canonical members, so the
# parameter values and their ids always stay in lockstep.
@pytest.mark.parametrize(
    "inference_type",
    list(InferenceType),
    ids=lambda member: member.name,
)
def test_slices(inference_type):
    """Run ``inference`` with several parameter slices for one inference type.

    Artificial parameters and data are generated from ``CoronaArtificial``,
    then ``inference`` is called with three index-array slices. The test
    passes as long as ``inference`` completes without raising; its return
    value is not inspected.

    Args:
        inference_type: The ``InferenceType`` member under test, supplied by
            the ``parametrize`` decorator.

    Raises:
        Exception: If the model unexpectedly does not provide artificial data.
    """
    model: Model = CoronaArtificial()

    # Guard clause: this test relies on the model's artificial-data helpers.
    if not model.is_artificial():
        raise Exception("This test is only for artificial data")

    num_data_points = 100
    params = model.generate_artificial_params(num_data_points)
    data = model.generate_artificial_data(params)

    # Index arrays handed to ``inference`` via ``slices``: a single index, a
    # pair, and all three together. NOTE(review): presumably these index
    # parameter dimensions of the model — confirm against ``inference``.
    slices = [np.array([0]), np.array([1, 2]), np.array([0, 1, 2])]

    # Only MCMC receives an explicit step budget; 100 steps presumably keeps
    # the sampler cheap enough for a unit test.
    kwargs = {"num_steps": 100} if inference_type is InferenceType.MCMC else {}

    inference(
        model,
        data,
        inference_type,
        slices=slices,
        **kwargs,
    )