tests.test_jax module

Test some JAX functionality.

do_matrix_multiplication(n: int = 20)

Perform a simple matrix-matrix multiplication of two random n × n matrices generated with NumPy.
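The function body is not reproduced on this page. What follows is a minimal sketch of what such a function might look like, under stated assumptions. The name `do_matrix_multiplication_sketch`, the fixed seed, and the use of `jnp.dot` for the product are hypothetical choices, not taken from the module's source.

```python
import numpy as np
import jax.numpy as jnp


def do_matrix_multiplication_sketch(n: int = 20) -> jnp.ndarray:
    """Hypothetical sketch: multiply two random n x n NumPy matrices with JAX."""
    # Generate two random n x n matrices on the host with NumPy (seed is an assumption).
    rng = np.random.default_rng(seed=0)
    a = rng.random((n, n))
    b = rng.random((n, n))
    # jnp.dot moves the NumPy inputs to the default JAX device and multiplies them.
    # Note: JAX uses float32 by default, so the float64 inputs are downcast.
    result = jnp.dot(a, b)
    # JAX dispatches asynchronously; block_until_ready waits for the result.
    return result.block_until_ready()


product = do_matrix_multiplication_sketch(20)
print(product.shape)  # (20, 20)
```

Calling `block_until_ready()` matters in a test. Without it, errors raised on the device might only surface later, when the result is first read.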

Parameters:

n (int, optional) – number of entries per dimension, defaults to 20

Returns:

test_jax_cpu()

Test whether JAX can run on the CPU by executing some JAX code.
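One way to check that JAX really executes on the CPU is to place the input on a CPU device explicitly and then inspect where the result lives. This is a minimal sketch under that assumption. The helper `run_on_cpu` is hypothetical and not the module's actual test code.

```python
import jax
import jax.numpy as jnp


def run_on_cpu() -> float:
    """Hypothetical helper: run a small JAX computation explicitly on the CPU."""
    # jax.devices("cpu") lists CPU devices; it raises RuntimeError if the backend is missing.
    cpu = jax.devices("cpu")[0]
    # device_put commits the array to the CPU, so the computation below runs there.
    x = jax.device_put(jnp.arange(4.0), cpu)
    y = jnp.sum(x * x)  # 0 + 1 + 4 + 9
    # Confirm that the result actually lives on the CPU device.
    assert cpu in y.devices()
    return float(y)


print(run_on_cpu())  # 14.0
```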

test_jax_gpu()

Test whether JAX can run on the GPU by executing some JAX code. This test may fail if no NVIDIA GPU is available or if the CUDA and cuDNN libraries are not installed.
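Because the GPU test depends on hardware and on CUDA/cuDNN, it can help to detect the GPU backend first and run the computation only when one is found. This is a minimal sketch. The helpers `gpu_available` and `run_on_gpu` are hypothetical. The sketch relies on `jax.devices("gpu")` raising `RuntimeError` when no GPU backend can be initialised.

```python
import jax
import jax.numpy as jnp


def gpu_available() -> bool:
    """Hypothetical helper: report whether JAX can see a GPU backend."""
    try:
        # Raises RuntimeError if no GPU backend initialises (no NVIDIA GPU, or CUDA/cuDNN missing).
        return len(jax.devices("gpu")) > 0
    except RuntimeError:
        return False


def run_on_gpu() -> float:
    """Hypothetical helper: run a small matrix product on the first GPU."""
    gpu = jax.devices("gpu")[0]
    x = jax.device_put(jnp.ones((3, 3)), gpu)
    # ones(3, 3) @ ones(3, 3) is a 3x3 matrix of 3s, so the sum is 27.
    return float(jnp.sum(x @ x))


if gpu_available():
    print(run_on_gpu())  # 27.0
else:
    print("No GPU backend found; a GPU test would fail here or should be skipped.")
```

In a pytest suite, a similar check could feed `pytest.mark.skipif(not gpu_available(), reason=...)`. Machines without a GPU would then report the test as skipped rather than failed.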

Returns: