Source code for eulerpi.core.models.jax_extension

from functools import partial
from typing import Callable

import jax
import jax.numpy as jnp


[docs] def value_and_jacfwd( fun: Callable[[jnp.ndarray], jnp.ndarray] ) -> Callable[[jnp.ndarray], Callable[[jnp.ndarray], jnp.ndarray]]: """Returns a function that computes the value and the jacobian of the passed function using forward mode AD. Args: fun(Callable[[jnp.ndarray], jnp.ndarray]): The function to supplement with the jacobian Returns: typing.Callable[[jnp.ndarray], typing.Tuple[jnp.ndarray, jnp.ndarray]]: A function that computes the value and the jacobian of the passed function using forward mode AD. """ def value_and_jacfwd_fun( x: jnp.ndarray, ) -> Callable[[jnp.ndarray], jnp.ndarray]: """ Args: x(jnp.ndarray): The input to the function Returns: typing.Tuple[jnp.ndarray, jnp.ndarray]: The value and the jacobian of the passed function using forward mode AD. """ pushfwd = partial(jax.jvp, fun, (x,)) y, jac = jax.vmap(pushfwd, out_axes=(None, -1))( (jnp.eye(x.shape[0], dtype=x.dtype),) ) return y, jac return value_and_jacfwd_fun
[docs] def value_and_jacrev( fun: Callable[..., jnp.ndarray] ) -> Callable[[jnp.ndarray], Callable[[jnp.ndarray], jnp.ndarray]]: """Returns a function that computes the value and the jacobian of the passed function using reverse mode AD. Args: fun(Callable[..., jnp.ndarray]): The function to supplement with the jacobian Returns: typing.Callable[[jnp.ndarray], typing.Tuple[jnp.ndarray, jnp.ndarray]]: A function that computes the value and the jacobian of the passed function using reverse mode AD. """ def value_and_jacrev_fun(x): """ Args: x(jnp.ndarray): The input to the function Returns: typing.Tuple[jnp.ndarray, jnp.ndarray]: The value and the jacobian of the passed function using reverse mode AD. """ y, pullback = jax.vjp(fun, x) jac = jax.vmap(pullback)(jnp.eye(y.shape[0], dtype=y.dtype))[0] return y, jac return value_and_jacrev_fun