# Source code for tests.test_jax

"""
Smoke tests for JAX: run a small matrix multiplication on the CPU and, if one is available, on the GPU.
"""

import jax
import jax.numpy as jnp
import pytest
from jax import random


def test_jax_cpu():
    """Test whether jax can run on the cpu by executing some jax code."""
    cpu = jax.devices("cpu")[0]
    # ``jax.default_device`` is a context manager. Assigning to it would only
    # rebind the module attribute (breaking it for everyone else) and would not
    # move the computation onto the CPU.
    with jax.default_device(cpu):
        do_matrix_multiplication()
@pytest.mark.xfail(
    # JAX uses the GPU backend by default when one is usable, so any other
    # default backend (cpu, tpu, ...) means no GPU is visible to JAX.
    jax.default_backend() != "gpu",
    reason="XFAIL means that no GPU was visible to jax or the matrix multiplication failed, maybe jax[cuda] or cuda + cudnn not installed",
)
def test_jax_gpu():
    """Test whether jax can run on the gpu by executing some jax code.

    The test is expected to fail (xfail) when no GPU is visible to JAX, e.g.
    if no NVIDIA GPU is present or the CUDA and cuDNN libraries are not
    installed; ``jax.devices("gpu")`` then raises an error.
    """
    gpu = jax.devices("gpu")[0]
    # Select the GPU via the context manager; assigning to
    # ``jax.default_device`` would not change where the work runs.
    with jax.default_device(gpu):
        do_matrix_multiplication()
# Note: jax.random uses explicit PRNG keys, so the random matrices differ from what numpy.random would generate.
def do_matrix_multiplication(n: int = 20):
    """Multiply two random ``n`` x ``n`` matrices generated with ``jax.random``.

    The matrices are drawn from a fixed PRNG key, so the result is reproducible.

    Args:
        n (int, optional): number of entries per dimension, defaults to 20

    Returns:
        jax.Array: the ``(n, n)`` matrix product, fully computed.
    """
    key = random.PRNGKey(0)  # Fixed seed: not really random, but reproducible.
    keyx, keyy = random.split(key)
    xj = random.normal(keyx, (n, n))
    yj = random.normal(keyy, (n, n))
    # JAX dispatches work asynchronously. Block on the result so that any
    # device/runtime failure is raised here instead of being silently lost.
    return jnp.dot(xj, yj).block_until_ready()