eulerpi.jax_extension module

value_and_jacfwd(fun: Callable[[Array], Array]) → Callable[[Array], Callable[[Array], Array]]

Returns a function that, for an input x, computes both the value fun(x) and the Jacobian of fun at x using forward-mode automatic differentiation (AD).

Parameters:

fun (Callable[[jnp.ndarray], jnp.ndarray]) – The function whose value and Jacobian are to be computed.

Returns:

A function that maps an input x to the tuple (fun(x), Jacobian of fun at x), computed with forward-mode AD.

Return type:

Callable[[jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]
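To illustrate what forward-mode evaluation of the value and Jacobian involves, here is a minimal sketch, not necessarily how eulerpi implements it. It assumes that x is a 1-D float array, and the names value_and_jacfwd_sketch and f are hypothetical. Each standard basis vector of the input space is pushed through jax.jvp, and jax.vmap batches those calls so that the value is computed once and each Jacobian-vector product becomes one column of the Jacobian.

```python
from typing import Callable, Tuple

import jax
import jax.numpy as jnp


def value_and_jacfwd_sketch(
    fun: Callable[[jnp.ndarray], jnp.ndarray],
) -> Callable[[jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:
    # Hypothetical helper for illustration; assumes x is a 1-D float array.
    def value_and_jac(x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        def pushfwd(v: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
            # jax.jvp returns (fun(x), J @ v) for a single tangent vector v.
            return jax.jvp(fun, (x,), (v,))

        # Rows of the identity are the standard basis vectors e_1, ..., e_n.
        basis = jnp.eye(x.size, dtype=x.dtype)
        # Batch the JVPs: out_axes=None keeps one copy of the (unbatched) value,
        # out_axes=-1 stacks each J @ e_j as column j of the Jacobian.
        y, jac = jax.vmap(pushfwd, out_axes=(None, -1))(basis)
        return y, jac

    return value_and_jac


# Hypothetical example function from R^2 to R^2.
f = lambda x: jnp.stack([x[0] ** 2 * x[1], jnp.sin(x[1])])
y, J = value_and_jacfwd_sketch(f)(jnp.array([1.0, 2.0]))
# J has shape (2, 2): J[i, j] is the derivative of output i w.r.t. input j.
```

Forward mode requires one Jacobian-vector product per input dimension, so it is usually the cheaper choice when fun has few inputs and many outputs.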

value_and_jacrev(fun: Callable[..., Array]) → Callable[[Array], Callable[[Array], Array]]

Returns a function that, for an input x, computes both the value fun(x) and the Jacobian of fun at x using reverse-mode automatic differentiation (AD).

Parameters:

fun (Callable[..., jnp.ndarray]) – The function whose value and Jacobian are to be computed.

Returns:

A function that maps an input x to the tuple (fun(x), Jacobian of fun at x), computed with reverse-mode AD.

Return type:

Callable[[jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]
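For comparison, here is a minimal sketch of the reverse-mode counterpart, again not necessarily eulerpi's implementation. It assumes that fun takes a single 1-D float array and returns a 1-D float array. The names value_and_jacrev_sketch and g are hypothetical. A single call to jax.vjp evaluates fun once and returns a pullback, and jax.vmap applies that pullback to every standard basis vector of the output space, producing the Jacobian one row at a time.

```python
from typing import Callable, Tuple

import jax
import jax.numpy as jnp


def value_and_jacrev_sketch(
    fun: Callable[[jnp.ndarray], jnp.ndarray],
) -> Callable[[jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray]]:
    # Hypothetical helper for illustration; assumes fun returns a 1-D float array.
    def value_and_jac(x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # A single forward pass yields the value and a VJP closure (the pullback).
        y, pullback = jax.vjp(fun, x)
        # Rows of the identity are the standard basis vectors of the output space.
        basis = jnp.eye(y.size, dtype=y.dtype)
        # pullback(u) returns a 1-tuple (u @ J,); batching over u = e_i
        # places row i of the Jacobian at position i.
        (jac,) = jax.vmap(pullback)(basis)
        return y, jac

    return value_and_jac


# Hypothetical example function from R^3 to R^2.
g = lambda x: jnp.stack([x[0] * x[1] * x[2], x[0] + jnp.exp(x[2])])
y, J = value_and_jacrev_sketch(g)(jnp.array([1.0, 2.0, 0.0]))
# J has shape (2, 3): one row per output, one column per input.
```

Reverse mode requires one vector-Jacobian product per output dimension, so it is usually the cheaper choice when fun has many inputs and few outputs.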