Corona SEIR Model

class Corona(central_param: ndarray = array([-1.8, 0., 0.7]), param_limits: ndarray = array([[-3., 0.], [-2., 2.], [-2., 2.]]), name: str | None = None, **kwargs)[source]

Bases: JaxModel

Describes the dynamics of the coronavirus with an SEIR compartment model: susceptible [S], exposed [E], infectious [I], and recovered [R].

\begin{eqnarray} \frac{d[S]}{dt} = & -q_1[S][I] \\ \frac{d[E]}{dt} = & q_1[S][I] - q_2[E] \\ \frac{d[I]}{dt}= & q_2[E] - q_3[I] \\ \frac{d[R]}{dt}= & q_3[I] \end{eqnarray}

subject to

\[{\left([S](t=0), \ [E](t=0), \ [I](t=0), \ [R](t=0)\right)}^\intercal= \left(999, \ 0, \ 1, \ 0\right)^\intercal.\]
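The system above can be sketched in plain Python to make the flow between compartments concrete. This is only an illustration under stated assumptions: the library itself solves the problem with diffrax, whereas this sketch uses a hand-written fixed-step Runge-Kutta (RK4) integrator, and the rates q1, q2, q3, the end time, and the step size are hypothetical values chosen for the example, not the model's defaults.

```python
def seir_rhs(state, q1, q2, q3):
    """Right-hand side of the SEIR system for state = (S, E, I, R)."""
    s, e, i, _r = state
    new_exposed = q1 * s * i      # flow S -> E
    new_infectious = q2 * e       # flow E -> I
    new_recovered = q3 * i        # flow I -> R
    return (
        -new_exposed,
        new_exposed - new_infectious,
        new_infectious - new_recovered,
        new_recovered,
    )


def rk4_step(state, dt, q1, q2, q3):
    """Advance the state by one classical fourth-order Runge-Kutta step."""
    def shifted(k, factor):
        return tuple(x + factor * dt * dx for x, dx in zip(state, k))

    k1 = seir_rhs(state, q1, q2, q3)
    k2 = seir_rhs(shifted(k1, 0.5), q1, q2, q3)
    k3 = seir_rhs(shifted(k2, 0.5), q1, q2, q3)
    k4 = seir_rhs(shifted(k3, 1.0), q1, q2, q3)
    return tuple(
        x + dt / 6.0 * (a + 2.0 * b + 2.0 * c + d)
        for x, a, b, c, d in zip(state, k1, k2, k3, k4)
    )


def simulate(q1, q2, q3, t_end=10.0, dt=0.01):
    """Integrate from the initial condition (999, 0, 1, 0) up to t_end."""
    state = (999.0, 0.0, 1.0, 0.0)
    for _ in range(round(t_end / dt)):
        state = rk4_step(state, dt, q1, q2, q3)
    return state


# Hypothetical rates, chosen only for illustration.
final_state = simulate(q1=0.001, q2=0.5, q3=0.2)
```

Because the four right-hand sides sum to zero, the total population [S] + [E] + [I] + [R] stays at 1000 for all times, which is a convenient sanity check for any solver applied to this system.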

Note

  • ODE Solver: The ODE problem is solved with diffrax, a JAX-based ODE solver library: https://github.com/patrick-kidger/diffrax.

  • Automatic Differentiation: The derivatives are calculated automatically with JAX by deriving from the class JaxModel, which automatically sets jacobian().

  • JIT compilation: Inheriting from JaxModel also enables JIT compilation and optimization of the forward and jacobian methods. This usually results in a significant execution speedup. It also allows the model to run on the GPU.
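The two JAX features named in the note can be illustrated on a toy function. The sketch below does not use JaxModel or any eulerpi internals. Here toy_forward is a hypothetical stand-in for a model's forward map, and jax.jacfwd is one of several JAX transformations that could produce a Jacobian; which one the library uses is not stated here.

```python
import jax
import jax.numpy as jnp


def toy_forward(param):
    """Hypothetical forward map from 2 parameters to 3 data values."""
    return jnp.array([param[0] * param[1], jnp.sin(param[0]), param[1] ** 2])


# JIT-compile the forward map and its automatically derived Jacobian.
fast_forward = jax.jit(toy_forward)
fast_jacobian = jax.jit(jax.jacfwd(toy_forward))

param = jnp.array([1.0, 2.0])
data = fast_forward(param)      # array of shape (3,)
jac = fast_jacobian(param)      # array of shape (3, 2)
```

The Jacobian has one row per data dimension and one column per parameter dimension, so for the Corona model (data_dim 4, param_dim 3) the analogous matrix would have shape (4, 3).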

classmethod forward(log_param)[source]

Executes the forward pass of the model to obtain data from a parameter.

Parameters:

log_param (np.ndarray) – The parameter for which the data should be generated.

Returns:

The data generated from the parameter.

Return type:

np.ndarray

Examples:

import numpy as np
from eulerpi.examples.heat import Heat
from eulerpi.core.model import JaxModel
from jax import vmap

# instantiate the heat model
model = Heat()

# define a 3D example parameter for the heat model
example_param = np.array([1.4, 1.6, 0.5])

# the forward simulation is achieved by using the forward method of the model
sim_result = model.forward(example_param)

# in a more realistic scenario, we would like to perform the forward pass on multiple parameters at once
multiple_params = np.array([[1.5, 1.5, 0.5],
                            [1.4, 1.4, 0.6],
                            [1.6, 1.6, 0.4],
                            model.central_param,
                            [1.5, 1.4, 0.4]])

# try to use jax vmap to perform the forward pass on multiple parameters at once
if isinstance(model, JaxModel):
    multiple_sim_results = vmap(model.forward, in_axes=0)(multiple_params)

# if the model is not a jax model, we can use numpy vectorize to perform the forward pass
else:
    multiple_sim_results = np.vectorize(model.forward, signature="(n)->(m)")(multiple_params)

CENTRAL_PARAM = array([-1.8, 0., 0.7])

PARAM_LIMITS = array([[-3., 0.], [-2., 2.], [-2., 2.]])

data_dim: int | None = 4

param_dim: int | None = 3

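As a quick illustration of how the two class attributes relate, the following sketch checks that a parameter lies inside the limits. It assumes that each row of PARAM_LIMITS is a [lower, upper] pair for the corresponding parameter dimension. The helper within_limits is hypothetical and not part of the library, and plain lists are used in place of NumPy arrays to keep the example dependency-free.

```python
# Values copied from the Corona class attributes.
CENTRAL_PARAM = [-1.8, 0.0, 0.7]
PARAM_LIMITS = [[-3.0, 0.0], [-2.0, 2.0], [-2.0, 2.0]]


def within_limits(param, limits):
    """Hypothetical helper: True if every entry lies in its [lower, upper] range."""
    if len(param) != len(limits):
        raise ValueError("param and limits must have the same length")
    return all(lower <= value <= upper for value, (lower, upper) in zip(param, limits))


central_ok = within_limits(CENTRAL_PARAM, PARAM_LIMITS)
outside_ok = within_limits([-3.5, 0.0, 0.7], PARAM_LIMITS)
```

Under that assumption the central parameter passes the check, while a parameter whose first entry is -3.5 does not.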
class CoronaArtificial(central_param: ndarray = array([-1.8, 0., 0.7]), param_limits: ndarray = array([[-2.5, -1.], [-0.75, 0.75], [0., 1.5]]), name: str | None = None, **kwargs)[source]

Bases: Corona, ArtificialModelInterface

generate_artificial_params(num_samples)[source]

This method must be overridden and return a NumPy array of num_samples parameters.

Parameters:

num_samples (int) – The number of parameters to generate.

Returns:

The generated parameters.

Return type:

np.ndarray

Raises:

NotImplementedError – If the method is not overridden in a subclass.

PARAM_LIMITS = array([[-2.5, -1.], [-0.75, 0.75], [0., 1.5]])
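
One way an override of generate_artificial_params could produce num_samples parameters is to draw each entry uniformly within the limits above. This is a minimal sketch under stated assumptions, not the sampling scheme CoronaArtificial actually uses. The function name sample_uniform_params, the uniform distribution, and the seed argument are all hypothetical, and the result is a nested list rather than the NumPy array the interface expects.

```python
import random

# Values copied from the CoronaArtificial class attribute.
ARTIFICIAL_PARAM_LIMITS = [[-2.5, -1.0], [-0.75, 0.75], [0.0, 1.5]]


def sample_uniform_params(num_samples, limits, seed=0):
    """Hypothetical sampler: one uniform draw per parameter dimension."""
    rng = random.Random(seed)  # seeded for reproducibility
    return [
        [rng.uniform(lower, upper) for lower, upper in limits]
        for _ in range(num_samples)
    ]


samples = sample_uniform_params(5, ARTIFICIAL_PARAM_LIMITS)
```

Each of the five samples has three entries, one per parameter dimension, and wrapping the result in np.array would give an array of shape (num_samples, 3).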