Corona SEIR Model
- class Corona(central_param: ndarray = array([-1.8, 0., 0.7]), param_limits: ndarray = array([[-3., 0.], [-2., 2.], [-2., 2.]]), name: str | None = None, **kwargs)[source]
Bases: JaxModel
Describes the dynamics of the corona virus.
\begin{eqnarray}
\frac{d[S]}{dt} &=& -q_1[S][I] \\
\frac{d[E]}{dt} &=& q_1[S][I] - q_2[E] \\
\frac{d[I]}{dt} &=& q_2[E] - q_3[I] \\
\frac{d[R]}{dt} &=& q_3[I]
\end{eqnarray}
subject to
\[{\left([S](t=0), \ [E](t=0), \ [I](t=0), \ [R](t=0)\right)}^\intercal = \left(999, \ 0, \ 1, \ 0\right)^\intercal.\]
Note
ODE Solver: To solve the ODE problem, the JAX-based ODE solver library diffrax is used: https://github.com/patrick-kidger/diffrax.
Automatic Differentiation: The derivatives are calculated automatically with JAX by deriving from the class JaxModel, which automatically sets jacobian().
JIT compilation: Inheriting from JaxModel also enables JIT compilation / optimization for the forward and jacobian methods. This usually results in a significant execution speedup. It also allows the model to run on the GPU.
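The SEIR system above can be sketched in plain Python to make the flows between compartments explicit. This is only an illustration, not the library's implementation: the names `seir_rhs` and `rk4_step` are hypothetical, the rates in `q` are made up, and a fixed-step classical Runge-Kutta scheme stands in for the diffrax solver that the model actually uses.

```python
def seir_rhs(state, q):
    """Right-hand side of the SEIR system for state = (S, E, I, R)."""
    s, e, i, r = state
    q1, q2, q3 = q
    infection = q1 * s * i   # flow S -> E
    onset = q2 * e           # flow E -> I
    recovery = q3 * i        # flow I -> R
    return (-infection, infection - onset, onset - recovery, recovery)


def rk4_step(state, q, dt):
    """Advance the state by one classical Runge-Kutta step of size dt."""
    def shifted(base, slope, factor):
        return tuple(b + factor * k for b, k in zip(base, slope))

    k1 = seir_rhs(state, q)
    k2 = seir_rhs(shifted(state, k1, dt / 2), q)
    k3 = seir_rhs(shifted(state, k2, dt / 2), q)
    k4 = seir_rhs(shifted(state, k3, dt), q)
    return tuple(
        x + dt / 6 * (a + 2 * b + 2 * c + d)
        for x, a, b, c, d in zip(state, k1, k2, k3, k4)
    )


# Initial condition taken from the documentation above.
state0 = (999.0, 0.0, 1.0, 0.0)

# Hypothetical rates, chosen only for illustration.
q = (1e-3, 0.5, 0.2)

state = state0
for _ in range(1000):  # integrate to t = 10 with dt = 0.01
    state = rk4_step(state, q, 0.01)

print(state, sum(state))
```

Because the four right-hand sides sum to zero, the total population [S] + [E] + [I] + [R] stays at 1000 throughout the integration, which is a convenient sanity check for any implementation of this system.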
- classmethod forward(log_param)[source]
Executes the forward pass of the model to obtain data from a parameter.
- Parameters:
log_param (np.ndarray) – The parameter for which the data should be generated.
- Returns:
The data generated from the parameter.
- Return type:
np.ndarray
Examples:
    import numpy as np
    from eulerpi.examples.heat import Heat
    from eulerpi.core.model import JaxModel
    from jax import vmap

    # instantiate the heat model
    model = Heat()

    # define a 3D example parameter for the heat model
    example_param = np.array([1.4, 1.6, 0.5])

    # the forward simulation is achieved by using the forward method of the model
    sim_result = model.forward(example_param)

    # in a more realistic scenario, we would like to perform the forward pass on multiple parameters at once
    multiple_params = np.array([[1.5, 1.5, 0.5],
                                [1.4, 1.4, 0.6],
                                [1.6, 1.6, 0.4],
                                model.central_param,
                                [1.5, 1.4, 0.4]])

    # try to use jax vmap to perform the forward pass on multiple parameters at once
    if isinstance(model, JaxModel):
        multiple_sim_results = vmap(model.forward, in_axes=0)(multiple_params)
    # if the model is not a jax model, we can use numpy vectorize to perform the forward pass
    else:
        multiple_sim_results = np.vectorize(model.forward, signature="(n)->(m)")(multiple_params)
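The notes on automatic differentiation, JIT compilation and batching can be illustrated with plain JAX, independently of eulerpi. The sketch below rests on several assumptions: `toy_forward` is a hypothetical stand-in rather than the actual `Corona.forward`, the argument is assumed to hold base-10 logarithms of the rates (suggested only by the name `log_param`), and a few explicit Euler steps replace the diffrax solver.

```python
import jax
import jax.numpy as jnp


def toy_forward(log_param):
    """Hypothetical stand-in for a model's forward map (not eulerpi code)."""
    q = 10.0 ** log_param                      # assumed log10 parameterization
    state = jnp.array([999.0, 0.0, 1.0, 0.0])  # (S, E, I, R) at t = 0

    def euler_step(state, _):
        s, e, i, r = state
        ds = -q[0] * s * i
        de = q[0] * s * i - q[1] * e
        di = q[1] * e - q[2] * i
        dr = q[2] * i
        return state + 0.01 * jnp.array([ds, de, di, dr]), None

    # 100 explicit Euler steps of size 0.01, i.e. integrate to t = 1.
    final_state, _ = jax.lax.scan(euler_step, state, None, length=100)
    return final_state


fast_forward = jax.jit(toy_forward)               # compiled forward pass
fast_jacobian = jax.jit(jax.jacfwd(toy_forward))  # d(data) / d(log_param)

log_param = jnp.array([-1.8, 0.0, 0.7])  # CENTRAL_PARAM from the documentation
data = fast_forward(log_param)           # shape (4,)
jac = fast_jacobian(log_param)           # shape (4, 3)

# Batched evaluation over several parameters with vmap.
batch = jnp.stack([log_param, log_param + 0.1])
batch_data = jax.vmap(fast_forward)(batch)  # shape (2, 4)
```

The Jacobian has one row per data dimension and one column per parameter, matching `data_dim = 4` and `param_dim = 3` for this model.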
- CENTRAL_PARAM = array([-1.8, 0. , 0.7])
- PARAM_LIMITS = array([[-3., 0.], [-2., 2.], [-2., 2.]])
- data_dim: int | None = 4
- param_dim: int | None = 3
- class CoronaArtificial(central_param: ndarray = array([-1.8, 0., 0.7]), param_limits: ndarray = array([[-2.5, -1.], [-0.75, 0.75], [0., 1.5]]), name: str | None = None, **kwargs)[source]
Bases: Corona, ArtificialModelInterface
- generate_artificial_params(num_samples)[source]
This method must be overridden and return a NumPy array of num_samples parameters.
- Parameters:
num_samples (int) – The number of parameters to generate.
- Returns:
The generated parameters.
- Return type:
np.ndarray
- Raises:
NotImplementedError – If the method is not overwritten in a subclass.
- PARAM_LIMITS = array([[-2.5 , -1. ], [-0.75, 0.75], [ 0. , 1.5 ]])
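One possible way to satisfy the `generate_artificial_params` contract is to sample uniformly inside the parameter limits. The sketch below is an assumption for illustration only: `generate_uniform_params` and its `seed` argument are hypothetical names, and the sampling scheme that `CoronaArtificial` actually uses is not stated in this documentation.

```python
import numpy as np

# Limits copied from CoronaArtificial.PARAM_LIMITS above:
# one (lower, upper) row per parameter.
PARAM_LIMITS = np.array([[-2.5, -1.0], [-0.75, 0.75], [0.0, 1.5]])


def generate_uniform_params(num_samples, limits=PARAM_LIMITS, seed=0):
    """Hypothetical sampler: draw num_samples parameters uniformly within limits."""
    rng = np.random.default_rng(seed)
    lower, upper = limits[:, 0], limits[:, 1]
    # lower and upper broadcast against the requested (num_samples, param_dim) shape.
    return rng.uniform(lower, upper, size=(num_samples, limits.shape[0]))


params = generate_uniform_params(5)
print(params.shape)
```

The result is a two-dimensional array with one row per sample and one column per parameter, which is the layout that the batched forward pass in the example for `forward` expects.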