"""Test the slices functionality for each of the inference methods."""importnumpyasnpimportpytestfromeulerpi.core.dense_grid_typesimportDenseGridTypefromeulerpi.core.inferenceimportinferencefromeulerpi.core.inference_typesimportInferenceTypefromeulerpi.core.modelsimportArtificialModelInterface,BaseModelfromeulerpi.examples.coronaimportCoronaArtificial# Parametrize the test to run for each inference type
@pytest.mark.parametrize(
    "dense_grid_type",
    list(DenseGridType),
    ids=DenseGridType._member_names_,
)
def test_dense_grid(dense_grid_type):
    """Run dense-grid inference with several slices for one dense grid type.

    The test is parametrized over every member of ``DenseGridType``. It
    generates artificial parameters and data from ``CoronaArtificial`` and
    then calls ``inference`` with ``InferenceType.DENSE_GRID`` and three
    index arrays (``[0]``, ``[1, 2]`` and ``[0, 1, 2]``) as ``slices``.

    There are no explicit assertions: the test succeeds when ``inference``
    returns without raising.

    Args:
        dense_grid_type: The ``DenseGridType`` member forwarded to
            ``inference`` for this parametrized run.

    Raises:
        Exception: If the model does not implement
            ``ArtificialModelInterface``.
    """
    model: BaseModel = CoronaArtificial()

    # Guard clause: artificial data can only be generated by models that
    # implement the artificial-model interface.
    if not isinstance(model, ArtificialModelInterface):
        raise Exception("This test is only for artificial data")

    # Generate artificial data.
    num_data_points = 100
    params = model.generate_artificial_params(num_data_points)
    data = model.generate_artificial_data(params)

    # One single-index slice, one two-index slice and one three-index slice.
    # NOTE(review): presumably these are parameter indices - confirm against
    # the `inference` documentation.
    slices = [np.array(indices) for indices in ([0], [1, 2], [0, 1, 2])]

    inference(
        model,
        data,
        inference_type=InferenceType.DENSE_GRID,
        slices=slices,
        dense_grid_type=dense_grid_type,
        num_grid_points=4,
    )