Source code for tests.test_dense_grid

"""
Test dense grid inference with slices for each of the dense grid types.
"""

import numpy as np
import pytest

from eulerpi.core.dense_grid_types import DenseGridType
from eulerpi.core.inference import inference
from eulerpi.core.inference_types import InferenceType
from eulerpi.core.models import ArtificialModelInterface, BaseModel
from eulerpi.examples.corona import CoronaArtificial


# Parametrize the test to run for each dense grid type
@pytest.mark.parametrize(
    "dense_grid_type",
    list(DenseGridType),
    # Use the public ``name`` of each member instead of the private
    # ``_member_names_`` attribute; both give the names in definition order.
    ids=[grid_type.name for grid_type in DenseGridType],
)
def test_dense_grid(dense_grid_type):
    """Smoke-test dense grid inference with slices for one dense grid type.

    Builds a ``CoronaArtificial`` model, generates 100 artificial parameter
    samples and the corresponding artificial data, and runs ``inference`` with
    ``InferenceType.DENSE_GRID`` on three slices (``[0]``, ``[1, 2]`` and
    ``[0, 1, 2]``) using 4 grid points. The test makes no assertions on the
    result; it passes if the inference call completes without raising.

    Args:
        dense_grid_type: The ``DenseGridType`` member forwarded to
            ``inference`` (one test case per member).

    Raises:
        Exception: If the model does not implement
            ``ArtificialModelInterface`` and therefore cannot generate
            artificial data.
    """
    model: BaseModel = CoronaArtificial()

    # The data for this test is generated by the model itself, so the model
    # must provide the artificial-data interface.
    if not isinstance(model, ArtificialModelInterface):
        raise Exception("This test is only for artificial data")

    # generate artificial data
    num_data_points = 100
    params = model.generate_artificial_params(num_data_points)
    data = model.generate_artificial_data(params)

    # Index arrays passed as slices: a single index, a pair, and all three
    # (presumably parameter indices of the model - confirm against
    # ``inference``).
    slices = [
        np.array([0]),
        np.array([1, 2]),
        np.array([0, 1, 2]),
    ]

    inference(
        model,
        data,
        inference_type=InferenceType.DENSE_GRID,
        slices=slices,
        dense_grid_type=dense_grid_type,
        num_grid_points=4,  # small grid keeps the test fast
    )