from typing import Optional
import diffrax as dx
import jax.numpy as jnp
import numpy as np
from eulerpi import logger
from eulerpi.core.models import ArtificialModelInterface, JaxModel
[docs]
class Corona(JaxModel):
    """Describes the dynamics of the corona virus.

    The model is a system of four coupled ODEs in the states
    :math:`[S], [E], [I], [R]` with three rate parameters
    :math:`q_1, q_2, q_3`:

    .. math::
        :nowrap:

        \\begin{eqnarray}
            \\frac{d[S]}{dt} = & -q_1[S][I] \\\\
            \\frac{d[E]}{dt} = & q_1[S][I] - q_2[E] \\\\
            \\frac{d[I]}{dt}= & q_2[E] - q_3[I] \\\\
            \\frac{d[R]}{dt}= & q_3[I]
        \\end{eqnarray}

    subject to

    .. math::

        {\\left([S](t=0), \\ [E](t=0), \\ [I](t=0), \\ [R](t=0)\\right)}^\\intercal= \\left(999, \\ 0, \\ 1, \\ 0\\right)^\\intercal.

    .. note::

        * ODE Solver: To solve the ODE problem the jax based ode solver library :code:`diffrax` is used: https://github.com/patrick-kidger/diffrax.
        * Automatic Differentiation: The derivatives are calculated automatically with jax by deriving from the class :py:class:`~eulerpi.core.models.JaxModel`,
          which automatically calculates and sets :py:meth:`~eulerpi.core.models.BaseModel.jacobian`.
        * JIT compilation: Inheriting from :py:class:`~eulerpi.core.models.JaxModel` also enables jit compilation / optimization for the forward and jacobian method.
          This usually results in a significant execution speedup. It also allows running your model on the GPU.

    """

    # Number of model parameters (q_1, q_2, q_3).
    param_dim = 3
    # Number of values returned by ``forward``.
    data_dim = 4

    # One [lower, upper] row per parameter. ``forward`` exponentiates its
    # input with base 10, so these are presumably log10-space bounds --
    # TODO confirm against the base class.
    PARAM_LIMITS = np.array([[-3.0, 0.0], [-2.0, 2.0], [-2.0, 2.0]])
    CENTRAL_PARAM = np.array([-1.8, 0.0, 0.7])

    def __init__(
        self,
        central_param: np.ndarray = CENTRAL_PARAM,
        param_limits: np.ndarray = PARAM_LIMITS,
        name: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Initialize the model by delegating to the parent constructor.

        Args:
            central_param: Central parameter vector; defaults to ``CENTRAL_PARAM``.
            param_limits: Parameter limits; defaults to ``PARAM_LIMITS``.
            name: Optional model name, passed through unchanged.
            **kwargs: Additional keyword arguments passed to the parent constructor.
        """
        super().__init__(central_param, param_limits, name=name, **kwargs)

    @classmethod
    def forward(cls, log_param):
        """Solve the ODE system and return the values of :math:`[I]` at fixed times.

        Args:
            log_param: Base-10 logarithms of the three rate parameters
                :math:`(q_1, q_2, q_3)`.

        Returns:
            The state component with index 2 (:math:`[I]`) of the solution at
            the saved time points with indices 1 to 4, i.e. at
            t = 1, 2, 5 and 15. If an exception is raised while solving, a
            warning is logged and an array of four ``-inf`` values is
            returned instead.
        """
        # The solver works with the actual rates, not their logarithms.
        rates = jnp.power(10, log_param)
        x_init = jnp.array([999.0, 0.0, 1.0, 0.0])

        def rhs(t, x, q):
            # x = ([S], [E], [I], [R]); each flow leaves one state and
            # enters the next one.
            s_to_e = q[0] * x[0] * x[2]
            e_to_i = q[1] * x[1]
            i_to_r = q[2] * x[2]
            return jnp.array(
                [
                    -s_to_e,
                    s_to_e - e_to_i,
                    e_to_i - i_to_r,
                    i_to_r,
                ]
            )

        term = dx.ODETerm(rhs)
        solver = dx.Kvaerno5()
        # t = 0 is saved as well but dropped again by the slice below.
        saveat = dx.SaveAt(ts=[0.0, 1.0, 2.0, 5.0, 15.0])
        stepsize_controller = dx.PIDController(rtol=1e-7, atol=1e-9)

        # NOTE(review): if this method is traced by jax.jit, solver failures
        # may only surface when the compiled function runs, i.e. outside this
        # try block -- confirm that the fallback below is actually reachable.
        try:
            ode_sol = dx.diffeqsolve(
                term,
                solver,
                t0=0.0,
                t1=15.0,
                dt0=0.01,
                y0=x_init,
                args=rates,
                saveat=saveat,
                stepsize_controller=stepsize_controller,
            )
            # Rows 1..4 are the save times 1, 2, 5, 15; column 2 is [I].
            return ode_sol.ys[1:5, 2]
        except Exception as e:
            logger.warning("ODE solution not possible!", exc_info=e)
            return np.array([-np.inf, -np.inf, -np.inf, -np.inf])
[docs]
class CoronaArtificial(Corona, ArtificialModelInterface):
    """:py:class:`Corona` variant with tighter default parameter limits that can draw artificial parameter samples."""

    # Overrides ``Corona.PARAM_LIMITS``; every interval lies inside the
    # corresponding interval of the parent class.
    PARAM_LIMITS = np.array([[-2.5, -1.0], [-0.75, 0.75], [0.0, 1.5]])

    def __init__(
        self,
        central_param: np.ndarray = Corona.CENTRAL_PARAM,
        param_limits: np.ndarray = PARAM_LIMITS,
        name: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Initialize the model by delegating to the parent constructor.

        Args:
            central_param: Central parameter vector; defaults to ``Corona.CENTRAL_PARAM``.
            param_limits: Parameter limits; defaults to this class's ``PARAM_LIMITS``.
            name: Optional model name, passed through unchanged.
            **kwargs: Additional keyword arguments passed to the parent constructor.
        """
        super().__init__(central_param, param_limits, name=name, **kwargs)

    def generate_artificial_params(self, num_samples):
        """Draw parameter samples uniformly from a fixed box.

        The box spans ``[-1.9, -1.7] x [-0.1, 0.1] x [0.6, 0.8]``. Samples are
        drawn with ``np.random.rand``, so they depend on NumPy's global
        random state.

        Args:
            num_samples: Number of parameter vectors to draw.

        Returns:
            Array of shape ``(num_samples, 3)`` with one parameter vector per row.
        """
        lower_bound = np.array([-1.9, -0.1, 0.6])
        upper_bound = np.array([-1.7, 0.1, 0.8])
        # Take the sample width from the bounds so the two cannot drift apart.
        unit_sample = np.random.rand(num_samples, lower_bound.size)
        return lower_bound + (upper_bound - lower_bound) * unit_sample