tests.test_transformations module

class X2Model

Bases: JaxModel, ArtificialModelInterface

bw()

Jacobian of forward with respect to positional argument 0. Takes the same arguments as forward but returns the Jacobian of the output with respect to the argument at position 0.
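As a minimal sketch, assuming (from the class name alone) that X2Model.forward maps a parameter x to x², a bw-style Jacobian can be reproduced with jax.jacrev. The function x_squared below is a hypothetical stand-in, not the test module's actual code.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for X2Model.forward: element-wise square (assumed from the name).
def x_squared(param):
    return param**2

# Reverse-mode Jacobian with respect to positional argument 0, jit-compiled.
bw = jax.jit(jax.jacrev(x_squared, argnums=0))

central_param = jnp.array([1.0])
print(bw(central_param))  # [[2.]]
```

For a 1-dimensional parameter and 1-dimensional output the Jacobian has shape (1, 1), and its single entry is 2x.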

classmethod forward(param)

Executes the forward pass of the model to obtain data from a parameter.

Parameters:

param (np.ndarray) – The parameter for which the data should be generated.

Returns:

The data generated from the parameter.

Return type:

np.ndarray

Examples:

import numpy as np
from eulerpi.examples.heat import Heat
from eulerpi.core.model import JaxModel
from jax import vmap

# instantiate the heat model
model = Heat()

# define a 3D example parameter for the heat model
example_param = np.array([1.4, 1.6, 0.5])

# the forward simulation is achieved by using the forward method of the model
sim_result = model.forward(example_param)

# in a more realistic scenario, we would like to perform the forward pass on multiple parameters at once
multiple_params = np.array([[1.5, 1.5, 0.5],
                            [1.4, 1.4, 0.6],
                            [1.6, 1.6, 0.4],
                            model.central_param,
                            [1.5, 1.4, 0.4]])

# try to use jax vmap to perform the forward pass on multiple parameters at once
if isinstance(model, JaxModel):
    multiple_sim_results = vmap(model.forward, in_axes=0)(multiple_params)

# if the model is not a jax model, we can use numpy vectorize to perform the forward pass
else:
    multiple_sim_results = np.vectorize(model.forward, signature="(n)->(m)")(multiple_params)

fw()

generate_artificial_params(num_samples: int) → Array

This method must be overridden and return a NumPy array of num_samples parameters.

Parameters:

num_samples (int) – The number of parameters to generate.

Returns:

The generated parameters.

Return type:

np.ndarray

Raises:

NotImplementedError – If the method is not overridden in a subclass.
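One plausible override is sketched below, under the assumption that artificial parameters are drawn uniformly within PARAM_LIMITS; the source does not say how X2Model actually samples them. The free function and its seed argument are hypothetical illustrations, not eulerpi's API.

```python
import numpy as np

# Limits taken from X2Model.PARAM_LIMITS: one [lower, upper] row per parameter dimension.
PARAM_LIMITS = np.array([[0.0, 2.0]])

def generate_artificial_params(num_samples: int, seed: int = 0) -> np.ndarray:
    """Hypothetical sketch: draw num_samples parameters uniformly inside PARAM_LIMITS."""
    rng = np.random.default_rng(seed)
    lower, upper = PARAM_LIMITS[:, 0], PARAM_LIMITS[:, 1]
    # Result has shape (num_samples, param_dim); each column is drawn from its own interval.
    return rng.uniform(low=lower, high=upper, size=(num_samples, PARAM_LIMITS.shape[0]))

params = generate_artificial_params(1000)
print(params.shape)  # (1000, 1)
```

Uniform sampling inside the box is only the simplest choice; a real subclass could instead draw from a prior that reflects known parameter values.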

vj()

Parameters:

x (jnp.ndarray) – The input to the function.

Returns:

The value and the Jacobian of the passed function, computed with reverse-mode AD.

Return type:

Tuple[jnp.ndarray, jnp.ndarray]
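A minimal sketch of a value-and-Jacobian helper built on jax.vjp: one reverse pass yields the value plus a pullback, and pulling back each output basis vector gives one Jacobian row. The helper name value_and_jacrev and the x² stand-in for forward are assumptions for illustration, not confirmed names from the test module.

```python
import jax
import jax.numpy as jnp

def value_and_jacrev(fun):
    """Hypothetical helper: return (fun(x), Jacobian of fun at x) using reverse-mode AD."""
    def wrapped(x):
        value, vjp_fun = jax.vjp(fun, x)
        # One cotangent per output entry; pulling back e_i yields row i of the Jacobian.
        basis = jnp.eye(value.size, dtype=value.dtype).reshape((value.size,) + value.shape)
        (jac_rows,) = jax.vmap(vjp_fun)(basis)
        return value, jac_rows.reshape(value.shape + x.shape)
    return wrapped

# Hypothetical stand-in for X2Model.forward (element-wise square, assumed from the name).
vj = jax.jit(value_and_jacrev(lambda x: x**2))

value, jac = vj(jnp.array([1.5]))
print(value, jac)  # [2.25] [[3.]]
```

Reusing the single forward evaluation for both outputs avoids running the model twice, which is the usual reason to pair the value with the Jacobian.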

CENTRAL_PARAM = array([1.])
PARAM_LIMITS = array([[0., 2.]])
data_dim: int | None = 1
param_dim: int | None = 1
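These attributes presumably need to agree with one another: PARAM_LIMITS holds one [lower, upper] row per parameter dimension, and CENTRAL_PARAM is expected to lie inside those limits. The standalone check below uses the values listed above; the function check_model_attributes is a hypothetical illustration, not part of the test module.

```python
import numpy as np

# Values as documented for X2Model.
CENTRAL_PARAM = np.array([1.0])
PARAM_LIMITS = np.array([[0.0, 2.0]])
param_dim = 1

def check_model_attributes(central_param, param_limits, param_dim):
    """Hypothetical consistency check for model class attributes."""
    assert central_param.shape == (param_dim,), "one central value per parameter"
    assert param_limits.shape == (param_dim, 2), "one [lower, upper] row per parameter"
    lower, upper = param_limits[:, 0], param_limits[:, 1]
    assert np.all(lower < upper), "each lower limit must be below its upper limit"
    assert np.all((lower <= central_param) & (central_param <= upper)), "central param outside limits"
    return True

print(check_model_attributes(CENTRAL_PARAM, PARAM_LIMITS, param_dim))  # True
```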
test_calc_gram_determinant()
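The Gram determinant of a Jacobian J is det(JᵀJ); its square root is the volume correction factor for a change of variables and equals |det J| when J is square. The sketch below assumes the tested helper computes this square root (not confirmed by this page); gram_volume_factor is a hypothetical NumPy function, not eulerpi's own implementation.

```python
import numpy as np

def gram_volume_factor(jac: np.ndarray) -> float:
    """Hypothetical helper: sqrt(det(J^T J)) for a Jacobian of shape (data_dim, param_dim)."""
    jac = np.atleast_2d(jac)
    return float(np.sqrt(np.linalg.det(jac.T @ jac)))

# For forward(x) = x**2 (assumed), the Jacobian at x = 1 is [[2.]], so the factor is |2x| = 2.
print(gram_volume_factor(np.array([[2.0]])))

# Non-square case: a map from R^1 to R^2 with Jacobian [[3], [4]] gives sqrt(9 + 16) = 5.
print(gram_volume_factor(np.array([[3.0], [4.0]])))
```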
test_evaluate_density(caplog)