eulerpi.core.models.jax_model module
- class JaxModel(central_param: ndarray, param_limits: ndarray, name: str | None = None, **kwargs)
Bases: BaseModel
The JaxModel is a base class for models using the JAX library.
It automatically creates the jacobian method based on the forward method. Additionally, it JIT-compiles the forward and jacobian methods with JAX for faster execution.
Note
To use this class, you have to implement your forward method using JAX, e.g. jax.numpy. Don't overwrite the __init__ method of JaxModel without calling the super constructor; otherwise, your forward method won't be jitted.
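The following is a minimal sketch of the subclassing pattern described in the note, not eulerpi's implementation. MiniJaxModel and Quadratic are hypothetical names, and so are the attributes fw and bw: the sketch assumes the base class creates its jitted wrappers in the constructor, which is why a subclass that overrides __init__ must call super().__init__().

```python
import jax
import jax.numpy as jnp


class MiniJaxModel:
    """Hypothetical stand-in for JaxModel (not the eulerpi class)."""

    def __init__(self, central_param, param_limits, name=None):
        self.central_param = jnp.asarray(central_param)
        self.param_limits = jnp.asarray(param_limits)
        self.name = name or type(self).__name__
        # The jitted wrappers are only created here, so subclasses must call super().__init__().
        self.fw = jax.jit(self.forward)
        self.bw = jax.jit(jax.jacrev(self.forward))

    def forward(self, param):
        raise NotImplementedError


class Quadratic(MiniJaxModel):
    """Hypothetical model whose forward pass uses jax.numpy so JAX can trace it."""

    def __init__(self, central_param, param_limits, scale=2.0):
        super().__init__(central_param, param_limits)  # sets up fw and bw
        self.scale = scale  # read when fw/bw are first traced

    def forward(self, param):
        return self.scale * jnp.array([param[0] ** 2, param[0] * param[1]])


model = Quadratic(central_param=[1.0, 2.0], param_limits=[[0.0, 3.0], [0.0, 3.0]])
p = jnp.array([1.0, 2.0])
data = model.fw(p)  # 2 * [1 * 1, 1 * 2] = [2, 4]
jac = model.bw(p)   # rows [4 * p0, 0] and [2 * p1, 2 * p0] = [4, 0] and [4, 2]
```

Because jax.jit traces lazily, attributes such as self.scale only need to exist before the first call; their values are baked into the compiled function at that point, so later changes are not picked up in this sketch.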
- forward_and_jacobian(param: ndarray) → Tuple[ndarray, ndarray]
Evaluates the forward pass and the jacobian of the model at the same time. This can be more efficient than calling the forward() and jacobian() methods separately.
- Parameters:
param (np.ndarray) – The parameter at which the forward pass and the jacobian should be evaluated.
- Returns:
The data and the jacobian for a given parameter.
- Return type:
Tuple[np.ndarray, np.ndarray]
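A minimal sketch of how a combined forward-and-jacobian evaluation can share work, assuming a hypothetical two-parameter forward function. It uses jax.vjp, which runs the forward pass once and returns a pullback; vmapping the pullback over the output basis vectors yields the jacobian rows. This mirrors how jax.jacrev is built, but it is not eulerpi's code.

```python
import jax
import jax.numpy as jnp


def forward(param):
    # Hypothetical forward model: param_dim = 2, data_dim = 3.
    return jnp.array([jnp.sin(param[0]), param[0] * param[1], jnp.exp(param[1])])


@jax.jit
def forward_and_jacobian(param):
    # jax.vjp evaluates the forward pass once and returns its output plus a pullback.
    data, pullback = jax.vjp(forward, param)
    # Pulling back each output basis vector e_i gives row i of the jacobian.
    basis = jnp.eye(data.shape[0], dtype=data.dtype)
    (jac,) = jax.vmap(pullback)(basis)
    return data, jac


data, jac = forward_and_jacobian(jnp.array([0.5, -1.0]))  # shapes (3,) and (3, 2)
```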
- static forward_method(self, param: ndarray) → ndarray
This method is called by the jitted forward method. It is not intended to be called directly.
- Parameters:
param (np.ndarray) – The parameter for which the data should be generated.
- Returns:
The data generated from the parameter.
- Return type:
np.ndarray
- forward_vectorized(params: ndarray) → ndarray
A vectorized version of the forward() method.
- Parameters:
params (np.ndarray) – An array of parameters with shape (n, self.param_dim).
- Returns:
The data generated from the parameters, with shape (n, self.data_dim).
- Return type:
np.ndarray
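A sketch of a vectorized forward pass using jax.vmap, assuming a hypothetical forward function for a single parameter of shape (param_dim,) = (2,). vmap maps it over the leading axis, so an (n, param_dim) batch yields (n, data_dim) data without a Python loop; whether eulerpi uses vmap internally is not stated here.

```python
import jax
import jax.numpy as jnp


def forward(param):
    # Hypothetical forward model for one parameter: param_dim = 2, data_dim = 3.
    return jnp.array([param[0] + param[1], param[0] * param[1], param[1] ** 2])


# vmap maps forward over axis 0: (n, param_dim) -> (n, data_dim); jit compiles the batched version.
forward_vectorized = jax.jit(jax.vmap(forward))

params = jnp.array([[1.0, 2.0], [3.0, 4.0], [0.0, -1.0]])
data = forward_vectorized(params)
print(data.shape)  # (3, 3)
```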
- classmethod init_fw_and_bw()
Computes the jitted methods for the subclass(es). As an unintended side effect, this also happens for all intermediate classes, e.g. for Corona when defining class CoronaArtificial(Corona).
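The sketch below illustrates class-level jitting and why intermediate classes are affected. It assumes, hypothetically, that init_fw_and_bw is triggered from __init_subclass__, i.e. once for every subclass definition; Base and the Corona method bodies are made-up placeholders, and fw and bw are assumed attribute names.

```python
import jax
import jax.numpy as jnp


class Base:
    """Hypothetical base class; each subclass gets its own jitted functions."""

    @staticmethod
    def forward_method(param):
        raise NotImplementedError

    @classmethod
    def init_fw_and_bw(cls):
        # Wrap this class's own forward_method and its jacobian with jit, store them on cls.
        cls.fw = staticmethod(jax.jit(cls.forward_method))
        cls.bw = staticmethod(jax.jit(jax.jacrev(cls.forward_method)))

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Assumed trigger: runs once per subclass definition, intermediate classes included.
        cls.init_fw_and_bw()


class Corona(Base):
    @staticmethod
    def forward_method(param):
        return jnp.array([param[0] * param[1]])


class CoronaArtificial(Corona):  # defining this runs the hook again, now for the child
    @staticmethod
    def forward_method(param):
        return jnp.array([param[0] + 2.0 * param[1]])
```

Under this assumed trigger, Corona receives its own jitted functions as soon as it is defined, even if only CoronaArtificial is ever instantiated.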
- jacobian(param: ndarray) → ndarray
Jacobian of the forward pass with respect to the parameters.
- Parameters:
param (np.ndarray) – The parameter for which the jacobian should be evaluated.
- Returns:
The jacobian of the variables returned by the forward() method with respect to the parameters.
- Return type:
np.ndarray
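A sketch of the jacobian computation with JAX's automatic differentiation, for a hypothetical forward function with param_dim = 2 and data_dim = 3. The result has one row per output and one column per parameter. jax.jacrev and jax.jacfwd compute the same matrix; the JAX documentation recommends jacfwd for tall jacobians (more outputs than parameters) and jacrev for wide ones.

```python
import jax
import jax.numpy as jnp


def forward(param):
    # Hypothetical forward model: param_dim = 2, data_dim = 3.
    return jnp.array([param[0] * param[1], jnp.sin(param[0]), param[1] ** 3])


jacobian = jax.jit(jax.jacrev(forward))      # reverse-mode autodiff
jacobian_fwd = jax.jit(jax.jacfwd(forward))  # forward-mode autodiff, same matrix

J = jacobian(jnp.array([1.0, 2.0]))  # shape (3, 2): J[i, j] = d forward_i / d param_j
```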
- add_autodiff(_cls)
Decorator to automatically create the jacobian method based on the forward method. Additionally, it JIT-compiles the forward and jacobian methods with JAX.
- Parameters:
_cls – The class to decorate.
- Returns:
The decorated class, with the jacobian method added and the forward and jacobian methods JIT-compiled with JAX.
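A hedged sketch of a class decorator in the spirit of add_autodiff. The name add_autodiff_sketch, the TwoOutputModel class, and the assumption that forward is a staticmethod taking a single parameter array are all hypothetical; the real decorator may differ.

```python
import jax
import jax.numpy as jnp


def add_autodiff_sketch(_cls):
    """Hypothetical decorator: jit forward and add a jitted jacobian derived from it."""
    raw_forward = _cls.forward  # a staticmethod looked up on the class is a plain function
    # staticmethod(...) keeps the jitted callables from being treated as instance methods.
    _cls.forward = staticmethod(jax.jit(raw_forward))
    _cls.jacobian = staticmethod(jax.jit(jax.jacrev(raw_forward)))
    return _cls


@add_autodiff_sketch
class TwoOutputModel:
    @staticmethod
    def forward(param):
        # Hypothetical forward model with two outputs.
        return jnp.array([jnp.exp(param[0]), param[0] * param[1]])


p = jnp.array([0.0, 3.0])
data = TwoOutputModel.forward(p)   # [exp(0), 0 * 3] = [1, 0]
jac = TwoOutputModel.jacobian(p)   # [[exp(0), 0], [3, 0]] = [[1, 0], [3, 0]]
```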