eulerpi.core.models.jax_extension module
- value_and_jacfwd(fun: Callable[[Array], Array]) → Callable[[Array], Callable[[Array], Array]]
Returns a function that computes both the value and the Jacobian of the passed function, using forward-mode automatic differentiation (AD).
- Parameters:
fun (Callable[[jnp.ndarray], jnp.ndarray]) – The function to augment with its Jacobian.
- Returns:
A function that, for an input array, returns a tuple of the value of fun and its Jacobian at that input. The Jacobian is computed with forward-mode AD.
- Return type:
Callable[[jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]
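The entry above does not show how the value and the forward-mode Jacobian are obtained together. The following is a minimal sketch of one way to do it with public JAX calls only (`jax.jvp` and `jax.vmap`). It pushes every standard basis vector through a single batched Jacobian-vector product, and the primal value comes out of the same pass. It is not necessarily how eulerpi implements `value_and_jacfwd`. The names `value_and_jacfwd_sketch` and `example_fun` are hypothetical, and the sketch assumes a floating-point input array.

```python
import jax
import jax.numpy as jnp


def value_and_jacfwd_sketch(fun):
    """Hypothetical sketch: return x -> (fun(x), Jacobian of fun at x) via forward mode."""

    def wrapped(x):
        n = x.size
        # One-hot tangent directions e_0, ..., e_{n-1}, each shaped like x.
        basis = jnp.eye(n, dtype=x.dtype).reshape((n,) + x.shape)

        def pushforward(tangent):
            # jax.jvp returns (fun(x), J @ tangent) in a single forward pass.
            return jax.jvp(fun, (x,), (tangent,))

        # Batch the JVPs over all basis directions. The primal value does not
        # depend on the tangent, so it is returned unbatched (out_axes=None);
        # the Jacobian-vector products are stacked along the last axis.
        y, jac = jax.vmap(pushforward, out_axes=(None, -1))(basis)
        # The last axis indexes the flattened input (C order); restore x's shape.
        return y, jac.reshape(y.shape + x.shape)

    return wrapped


def example_fun(x):
    # Hypothetical test function R^2 -> R^2.
    return jnp.stack([x[0] ** 2 * x[1], jnp.sin(x[1])])


value, jacobian = value_and_jacfwd_sketch(example_fun)(jnp.array([1.0, 2.0]))
```

Forward mode costs one Jacobian-vector product per input element, so it tends to suit functions with fewer inputs than outputs. The result layout `y.shape + x.shape` follows the same convention as `jax.jacfwd`.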
- value_and_jacrev(fun: Callable[..., Array]) → Callable[[Array], Callable[[Array], Array]]
Returns a function that computes both the value and the Jacobian of the passed function, using reverse-mode automatic differentiation (AD).
- Parameters:
fun (Callable[..., jnp.ndarray]) – The function to augment with its Jacobian.
- Returns:
A function that, for an input array, returns a tuple of the value of fun and its Jacobian at that input. The Jacobian is computed with reverse-mode AD.
- Return type:
Callable[[jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]
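The reverse-mode counterpart can be sketched in a similar way. This is a minimal sketch assuming only public JAX calls (`jax.vjp` and `jax.vmap`), and it is not necessarily eulerpi's implementation. A single call to `jax.vjp` evaluates the function and records a pullback. Batching that pullback over one-hot cotangents then yields the Jacobian one row (one output component) at a time. The names `value_and_jacrev_sketch` and `example_fun` are hypothetical. Unlike the documented `fun`, which may take several arguments, this sketch assumes a single floating-point array argument and differentiates with respect to it.

```python
import jax
import jax.numpy as jnp


def value_and_jacrev_sketch(fun):
    """Hypothetical sketch: return x -> (fun(x), Jacobian of fun at x) via reverse mode."""

    def wrapped(x):
        # One forward pass records fun(x) and a pullback v -> v^T J.
        y, pullback = jax.vjp(fun, x)
        m = y.size
        # One-hot cotangents e_0, ..., e_{m-1}, each shaped like fun(x).
        basis = jnp.eye(m, dtype=y.dtype).reshape((m,) + y.shape)
        # Batch the VJPs: entry i is e_i^T J, the gradient of output component i.
        # vjp returns one cotangent per primal argument, hence the tuple unpacking.
        (jac,) = jax.vmap(pullback)(basis)
        # The leading axis indexes the flattened output (C order); restore y's shape.
        return y, jac.reshape(y.shape + x.shape)

    return wrapped


def example_fun(x):
    # Hypothetical test function R^3 -> R (scalar output).
    return jnp.sum(x ** 2) + x[0] * x[2]


value, jacobian = value_and_jacrev_sketch(example_fun)(jnp.array([1.0, 2.0, 3.0]))
```

Reverse mode costs one vector-Jacobian product per output element. For a scalar-valued function such as `example_fun`, a single pullback therefore gives the whole gradient, which is why reverse mode tends to suit functions with fewer outputs than inputs.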