Source code for tests.test_pickling

"""Test pickling of the various classes in the package."""

import importlib
import pickle

import numpy as np
import pytest

from eulerpi.core.models import BaseModel
from tests.test_examples import Examples, get_example_name


@pytest.mark.parametrize("example", Examples(), ids=get_example_name)
def test_pickling_model(example):
    """Check that a model keeps its attributes across a pickle round trip.

    The model class named by ``example`` is imported, instantiated without
    arguments, pickled and unpickled. Every entry of the original instance's
    ``__dict__`` is then compared with the corresponding entry of the
    unpickled copy, except for ``amici_model`` and ``amici_solver``.

    Args:
        example: Iterable of two or three items. The first two are the
            importable module path and the name of the model class inside
            that module; a third item, if present, is ignored here.
    """
    try:
        module_location, class_name, _ = example
    except ValueError:
        # Not a triple: the example must then be a (module, class) pair.
        module_location, class_name = example

    # Import the class dynamically so that an import error is attributed to
    # this specific test case instead of breaking collection of the module.
    module = importlib.import_module(module_location)
    model_class = getattr(module, class_name)
    model: BaseModel = model_class()

    # Round trip through pickle.
    loaded = pickle.loads(pickle.dumps(model))
    loaded_state = loaded.__dict__

    # The amici objects are SwigPyObject instances; no direct comparison is
    # known for them, so they are excluded from the check.
    skipped_keys = ("amici_model", "amici_solver")

    for key, value in model.__dict__.items():
        if key in skipped_keys:
            continue

        if isinstance(value, np.ndarray):
            # numpy arrays do not implement __eq__ like normal Python
            # objects, so compare them with np.array_equal instead.
            assert np.array_equal(
                value, loaded_state[key]
            ), f"attribute {key!r} differs after pickling"
        else:
            assert (
                value == loaded_state[key]
            ), f"attribute {key!r} differs after pickling"