"""Test some jax functionality"""importjaximportjax.numpyasjnpimportpytestfromjaximportrandom
def test_jax_cpu():
    """Test whether jax can run on the cpu by executing some jax code."""
    cpu_device = jax.devices("cpu")[0]
    # Select the device through JAX's documented ``jax_default_device`` option.
    # Assigning ``jax.default_device = ...`` (as done previously) only rebinds a
    # module attribute, shadowing JAX's ``default_device`` context manager, and
    # does not change where computations run.
    jax.config.update("jax_default_device", cpu_device)
    try:
        do_matrix_multiplication()
    finally:
        # None means "use the system default device"; restore it so later
        # tests are not silently pinned to the cpu.
        jax.config.update("jax_default_device", None)
@pytest.mark.xfail(
    # JAX uses the GPU by default when one is available, so a "cpu" default
    # backend means no GPU is visible to JAX.
    jax.default_backend() == "cpu",
    reason="XFAIL means that no GPU was visible to jax or the matrix multiplication failed, maybe jax[cuda] or cuda + cudnn not installed",
)
def test_jax_gpu():
    """Test whether jax can run on the gpu by executing some jax code.

    The test may fail (reported as XFAIL) if no NVIDIA GPU is available or
    the CUDA and cuDNN libraries are not installed. In that case
    ``jax.devices("gpu")`` raises before any computation is attempted.
    """
    gpu_device = jax.devices("gpu")[0]
    # Select the device through JAX's documented ``jax_default_device`` option.
    # Assigning ``jax.default_device = ...`` only rebinds a module attribute
    # (shadowing JAX's context manager) and does not move the computation.
    jax.config.update("jax_default_device", gpu_device)
    try:
        do_matrix_multiplication()
    finally:
        # None restores the system default device for subsequent tests.
        jax.config.update("jax_default_device", None)
# The random matrices come from JAX's explicit-key PRNG (jax.random); numpy's random generator would produce different values.
def do_matrix_multiplication(n: int = 20):
    """Multiply two random ``n x n`` matrices drawn with ``jax.random``.

    Args:
        n (int, optional): number of entries per dimension, defaults to 20

    Returns:
        The ``(n, n)`` product as a jax array, fully computed on the device.
    """
    # Fixed seed: not really random, but gives reproducible results.
    key = random.PRNGKey(0)
    keyx, keyy = random.split(key)
    xj = random.normal(keyx, (n, n))
    yj = random.normal(keyy, (n, n))
    # JAX dispatches work asynchronously. Blocking forces the multiplication
    # to finish here, so device/runtime errors surface inside the calling
    # test instead of later (or not at all).
    return jnp.dot(xj, yj).block_until_ready()