from typing import Optional
import jax.numpy as jnp
import numpy as np
from eulerpi.core.model import ArtificialModelInterface, JaxModel
class Linear(JaxModel, ArtificialModelInterface):
    """Two-parameter toy model with a linear forward map.

    ``forward`` maps ``(p0, p1)`` to ``(10 * p0, -2 * p1 - 2)``. Artificial
    parameters for tests are drawn uniformly from ``[0, 1)`` in each
    dimension.
    """

    param_dim = 2
    data_dim = 2

    # One [lower, upper] row per parameter dimension.
    PARAM_LIMITS = np.array([[-0.2, 1.2], [-0.2, 1.2]])
    CENTRAL_PARAM = np.array([0.5, 0.5])

    def __init__(
        self,
        central_param: Optional[np.ndarray] = None,
        param_limits: Optional[np.ndarray] = None,
        name: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Create the model, falling back to the class-level defaults.

        Args:
            central_param: Central parameter vector. Defaults to a copy of
                ``CENTRAL_PARAM``.
            param_limits: Parameter limits array. Defaults to a copy of
                ``PARAM_LIMITS``.
            name: Optional model name, forwarded to the base class.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        # Copy the class constants instead of binding them as default
        # arguments: a default ndarray would be shared by every instance, so
        # an in-place edit on one model would silently change the default
        # for all future models.
        if central_param is None:
            central_param = self.CENTRAL_PARAM.copy()
        if param_limits is None:
            param_limits = self.PARAM_LIMITS.copy()
        super().__init__(central_param, param_limits, name=name, **kwargs)

    @classmethod
    def forward(cls, param):
        """Evaluate the model for one parameter vector.

        Args:
            param: Indexable parameter vector; entries 0 and 1 are used.

        Returns:
            ``jnp.array([10 * param[0], -2 * param[1] - 2])``.
        """
        return jnp.array([param[0] * 10, (-2.0) * param[1] - 2.0])

    def generate_artificial_params(self, num_samples: int) -> np.ndarray:
        """Draw parameters uniformly from ``[0, 1)`` per dimension.

        Uses NumPy's global random state (``np.random.rand``).

        Args:
            num_samples: Number of parameter vectors to draw.

        Returns:
            Array of shape ``(num_samples, param_dim)``.
        """
        return np.random.rand(num_samples, self.param_dim)
class Exponential(JaxModel):
    """Two-parameter toy model with an exponential forward map.

    ``forward`` maps ``(p0, p1)`` to ``(p0 * e, exp(p1))``.
    """

    param_dim = 2
    data_dim = 2

    # One [lower, upper] row per parameter dimension.
    PARAM_LIMITS = np.array([[1.0, 2.0], [1.0, 2.0]])
    # NOTE(review): this central point sits on the lower corner of
    # PARAM_LIMITS rather than at its midpoint (1.5, 1.5) -- confirm intended.
    CENTRAL_PARAM = np.array([1.0, 1.0])

    def __init__(
        self,
        central_param: Optional[np.ndarray] = None,
        param_limits: Optional[np.ndarray] = None,
        name: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Create the model, falling back to the class-level defaults.

        Args:
            central_param: Central parameter vector. Defaults to a copy of
                ``CENTRAL_PARAM``.
            param_limits: Parameter limits array. Defaults to a copy of
                ``PARAM_LIMITS``.
            name: Optional model name, forwarded to the base class.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        # Copy rather than share the class-level arrays so that mutating one
        # instance's arrays cannot alter the defaults used by other instances.
        if central_param is None:
            central_param = self.CENTRAL_PARAM.copy()
        if param_limits is None:
            param_limits = self.PARAM_LIMITS.copy()
        super().__init__(central_param, param_limits, name=name, **kwargs)

    @classmethod
    def forward(cls, param):
        """Evaluate the model for one parameter vector.

        Args:
            param: Indexable parameter vector; entries 0 and 1 are used.

        Returns:
            ``jnp.array([param[0] * e, exp(param[1])])``.
        """
        return jnp.array([param[0] * jnp.exp(1), jnp.exp(param[1])])
class LinearODE(JaxModel, ArtificialModelInterface):
    """Toy model observing ``p0 * exp(p1 * t)`` at ``t = 1`` and ``t = 2``.

    ``p0 * exp(p1 * t)`` is the closed-form solution of ``dx/dt = p1 * x``
    with ``x(0) = p0``. Artificial parameters are sampled uniformly from
    ``TRUE_PARAM_LIMITS``.
    """

    param_dim = 2
    data_dim = 2

    # One [lower, upper] row per parameter dimension.
    PARAM_LIMITS = np.array([[-2.0, 4.0], [-2.0, 4.0]])
    CENTRAL_PARAM = np.array([1.5, 1.5])
    # Box from which generate_artificial_params draws its samples.
    TRUE_PARAM_LIMITS = np.array([[1.0, 2.0], [1.0, 2.0]])

    def __init__(
        self,
        central_param: Optional[np.ndarray] = None,
        param_limits: Optional[np.ndarray] = None,
        name: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Create the model, falling back to the class-level defaults.

        Args:
            central_param: Central parameter vector. Defaults to a copy of
                ``CENTRAL_PARAM``.
            param_limits: Parameter limits array. Defaults to a copy of
                ``PARAM_LIMITS``.
            name: Optional model name, forwarded to the base class.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        # Copy rather than share the class-level arrays so that mutating one
        # instance's arrays cannot alter the defaults used by other instances.
        if central_param is None:
            central_param = self.CENTRAL_PARAM.copy()
        if param_limits is None:
            param_limits = self.PARAM_LIMITS.copy()
        super().__init__(central_param, param_limits, name=name, **kwargs)

    @classmethod
    def forward(cls, param):
        """Evaluate ``param[0] * exp(param[1] * t)`` at ``t = 1.0`` and ``2.0``.

        Args:
            param: Indexable parameter vector; entries 0 and 1 are used.

        Returns:
            Length-2 ``jnp`` array of the two observations.
        """
        return jnp.array(
            [
                param[0] * jnp.exp(param[1] * 1.0),
                param[0] * jnp.exp(param[1] * 2.0),
            ]
        )

    def generate_artificial_params(self, num_samples: int) -> np.ndarray:
        """Draw parameters uniformly from ``TRUE_PARAM_LIMITS``.

        Uses NumPy's global random state (``np.random.rand``). With the
        default limits ``[[1, 2], [1, 2]]`` this yields exactly the same
        values as ``np.random.rand(num_samples, 2) + 1``.

        Args:
            num_samples: Number of parameter vectors to draw.

        Returns:
            Array of shape ``(num_samples, param_dim)``.
        """
        # Derive the box from TRUE_PARAM_LIMITS instead of hard-coding
        # "2" and "+ 1", so the sampler cannot drift from the declared limits.
        lower = self.TRUE_PARAM_LIMITS[:, 0]
        upper = self.TRUE_PARAM_LIMITS[:, 1]
        return lower + np.random.rand(num_samples, self.param_dim) * (upper - lower)