tests.test_transformations module
- class X2Model
Bases: JaxModel, ArtificialModelInterface
- bw()
Jacobian of forward with respect to positional argument 0. Takes the same arguments as forward but returns the Jacobian of the output with respect to the argument at position 0.
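The wording above appears to be the docstring JAX attaches to functions built with jax.jacfwd / jax.jacrev, which suggests bw is such a wrapper around forward. The sketch below assumes reverse mode (jax.jacrev, hinted at by the name "bw") and a forward that maps x to x**2. That mapping is suggested by the class name X2Model but is not shown in these docs, so treat the forward function here as a hypothetical stand-in.

```python
import jax
import jax.numpy as jnp


def forward(param: jnp.ndarray) -> jnp.ndarray:
    # Hypothetical stand-in for X2Model.forward; assumed to map x -> x**2.
    return param ** 2


# Assumed construction of bw: reverse-mode Jacobian w.r.t. positional argument 0.
bw = jax.jacrev(forward, argnums=0)

param = jnp.array([1.5])  # param_dim = 1
jac = bw(param)           # shape (data_dim, param_dim) = (1, 1)
# Analytically d(x**2)/dx = 2x, so jac holds 3.0 for x = 1.5.
```

The Jacobian has shape (data_dim, param_dim), so for a one-dimensional model it is a 1x1 matrix rather than a scalar.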
- classmethod forward(param)
Executes the forward pass of the model to obtain data from a parameter.
- Parameters:
param (np.ndarray) – The parameter for which the data should be generated.
- Returns:
The data generated from the parameter.
- Return type:
np.ndarray
Examples:
```python
import numpy as np
from eulerpi.examples.heat import Heat
from eulerpi.core.model import JaxModel
from jax import vmap

# instantiate the heat model
model = Heat()

# define a 3D example parameter for the heat model
example_param = np.array([1.4, 1.6, 0.5])

# the forward simulation is achieved by using the forward method of the model
sim_result = model.forward(example_param)

# in a more realistic scenario, we would like to perform the forward pass on multiple parameters at once
multiple_params = np.array([[1.5, 1.5, 0.5],
                            [1.4, 1.4, 0.6],
                            [1.6, 1.6, 0.4],
                            model.central_param,
                            [1.5, 1.4, 0.4]])

# try to use jax vmap to perform the forward pass on multiple parameters at once
if isinstance(model, JaxModel):
    multiple_sim_results = vmap(model.forward, in_axes=0)(multiple_params)
# if the model is not a jax model, we can use numpy vectorize to perform the forward pass
else:
    multiple_sim_results = np.vectorize(model.forward, signature="(n)->(m)")(multiple_params)
```
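When a model is not a JaxModel, the example falls back to np.vectorize. The sketch below isolates that path using a hypothetical NumPy forward with X2Model's dimensions (param_dim = data_dim = 1, assumed to compute x**2). It shows what signature="(n)->(m)" means: each row of the parameter array is passed to forward separately, and the per-row outputs are stacked along the leading axis.

```python
import numpy as np


def forward(param: np.ndarray) -> np.ndarray:
    # Hypothetical non-JAX forward: parameter vector (length n=1) -> data vector (length m=1).
    return param ** 2


# Five 1-D parameters stacked row-wise, shape (5, 1).
params = np.array([[0.5], [1.0], [1.5], [2.0], [0.0]])

# "(n)->(m)" declares the core dimensions; np.vectorize loops only over the batch axis.
batched = np.vectorize(forward, signature="(n)->(m)")
results = batched(params)  # shape (5, 1)
```

np.vectorize is a convenience loop rather than true vectorization, which is why the JAX path with vmap is preferred when it is available.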
- fw()
- generate_artificial_params(num_samples: int) → Array
This method must be overridden and return a numpy array of num_samples parameters.
- Parameters:
num_samples (int) – The number of parameters to generate.
- Returns:
The generated parameters.
- Return type:
np.ndarray
- Raises:
NotImplementedError – If the method is not overridden in a subclass.
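For X2Model, one plausible override draws parameters uniformly from the interval given by PARAM_LIMITS. The sketch below is hypothetical: the distribution X2Model actually uses is not documented here, and the seed argument is an added assumption for reproducibility.

```python
import numpy as np

# Copied from the X2Model class attribute: one [lower, upper] row per parameter.
PARAM_LIMITS = np.array([[0.0, 2.0]])


def generate_artificial_params(num_samples: int, seed: int = 0) -> np.ndarray:
    """Hypothetical override: sample num_samples parameters uniformly inside PARAM_LIMITS.

    Returns an array of shape (num_samples, param_dim).
    """
    rng = np.random.default_rng(seed)
    low, high = PARAM_LIMITS[:, 0], PARAM_LIMITS[:, 1]
    # low/high have shape (param_dim,) and broadcast against size=(num_samples, param_dim).
    return rng.uniform(low, high, size=(num_samples, PARAM_LIMITS.shape[0]))


samples = generate_artificial_params(1000)
```

Generator.uniform samples the half-open interval [low, high), so every sample lies inside the documented limits.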
- vj()
- Parameters:
x (jnp.ndarray) – The input to the function.
- Returns:
The value and the Jacobian of the passed function, computed using reverse-mode AD.
- Return type:
Tuple[jnp.ndarray, jnp.ndarray]
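Reverse-mode AD builds the Jacobian one row at a time: a single forward pass yields the value and a pullback (vector-Jacobian product), and pulling back each standard basis vector of the output space gives one Jacobian row. The sketch below implements that with jax.vjp and jax.vmap. The function name value_and_jacrev and the x**2 forward are hypothetical stand-ins; the docs do not show how vj is actually built.

```python
import jax
import jax.numpy as jnp


def forward(param: jnp.ndarray) -> jnp.ndarray:
    # Hypothetical stand-in for X2Model.forward (assumed to be x -> x**2).
    return param ** 2


def value_and_jacrev(fun, x: jnp.ndarray):
    """Return fun(x) and its Jacobian via reverse-mode AD (one VJP per output entry)."""
    y, pullback = jax.vjp(fun, x)
    # Standard basis of the output space, shaped so each slice has y's shape.
    basis = jnp.eye(y.size, dtype=y.dtype).reshape((y.size,) + y.shape)
    # Pulling back basis vector e_i yields row i of the Jacobian.
    (jac,) = jax.vmap(pullback)(basis)
    return y, jac.reshape(y.shape + x.shape)


x = jnp.array([1.5])
value, jac = value_and_jacrev(forward, x)  # value = x**2, jac = 2x
```

Returning both outputs from one call reuses the forward pass, which is the usual reason to prefer a combined value-and-Jacobian function over calling forward and a Jacobian separately.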
- CENTRAL_PARAM = array([1.])
- PARAM_LIMITS = array([[0., 2.]])
- data_dim: int | None = 1
- param_dim: int | None = 1
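The class attributes above fit together: reading PARAM_LIMITS as one [lower, upper] row per parameter dimension (an interpretation inferred from its shape), CENTRAL_PARAM lies inside the admissible interval and, for this class, sits exactly at its midpoint. A quick NumPy check using the documented values:

```python
import numpy as np

# Class attributes as documented for X2Model.
CENTRAL_PARAM = np.array([1.0])
PARAM_LIMITS = np.array([[0.0, 2.0]])
param_dim = 1
data_dim = 1

# Shapes must agree with the declared parameter dimension.
assert CENTRAL_PARAM.shape == (param_dim,)
assert PARAM_LIMITS.shape == (param_dim, 2)

# The central parameter lies inside [lower, upper] in every dimension.
inside = np.all((PARAM_LIMITS[:, 0] <= CENTRAL_PARAM) & (CENTRAL_PARAM <= PARAM_LIMITS[:, 1]))

# For X2Model it coincides with the midpoint of the limits, (0 + 2) / 2 = 1.
midpoint = PARAM_LIMITS.mean(axis=1)
```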