"""Tests for the Kernel Density Estimation module"""importjaximportjax.scipy.statsasjstatsimportmatplotlib.pyplotaspltimportnumpyasnpimportpytestfrommatplotlibimportcmfromeulerpi.core.dense_gridimportgenerate_regular_gridfromeulerpi.core.kdeimportcalc_kernel_width,eval_kde_cauchy,eval_kde_gauss
def kernel_estimators():
    """Yield every kernel density estimator the tests are parametrized over.

    The Cauchy estimator is produced first, followed by the Gauss estimator.
    """
    yield from (eval_kde_cauchy, eval_kde_gauss)
def kde_data_test_set():
    """Yield ``(data, stdevs, grid_bounds)`` triples for the KDE data tests.

    Each triple consists of:

    * ``data``: a 2D array with one data point per row,
    * ``stdevs``: a 1D array with one value per data dimension,
    * ``grid_bounds``: a nested list ``[[x_min, x_max], [y_min, y_max]]``.
    """
    data_list = [
        np.array([[0.5, 2.0]]),
        np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]]),
    ]
    stdev_list = [np.array([2.0, 1.0]), np.array([0.5, 1.0])]
    grid_bounds_list = [[[-2.5, 3.5], [0.5, 3.5]], [[-1.0, 3.0], [-1, 2.0]]]

    # Walk the three parallel lists in lockstep instead of indexing each one.
    yield from zip(data_list, stdev_list, grid_bounds_list)
# Mark the test as expected to fail because the function is not implemented yet, but it definitely should be implemented soon!
@pytest.mark.xfail(reason="Not implemented yet, but is important!")
def test_calc_kernel_width():
    """Placeholder test for ``calc_kernel_width``; always fails for now.

    The kernel widths are computed for three 2D data points, but no expected
    values are checked yet. The final assertion fails unconditionally, which
    the ``xfail`` marker reports as an expected failure.
    """
    samples = np.array(
        [
            [0.0, 1.0],
            [-1.0, 2.0],
            [1.0, 3.0],
        ]
    )
    kernel_widths = calc_kernel_width(samples)

    # No reference values exist yet, so fail on purpose.
    assert 0 == 1
@pytest.mark.parametrize("evalKDE", kernel_estimators())
@pytest.mark.parametrize("batch", [False, True])
def test_KDE_batch(batch, evalKDE):
    """Check the output shape of a kernel density estimator.

    Random 2D data points are generated and the estimator is evaluated either
    at a single point or at a batch of points.

    Args:
        batch: If True, evaluate at several random points at once; otherwise
            evaluate at one fixed point.
        evalKDE: Kernel density estimator under test, called as
            ``evalKDE(data, eval_points, stdevs)``.
    """
    data_dim = 2
    num_data_points = 3
    n_samples = 5

    # Random data points and the kernel widths derived from them.
    data = np.random.rand(num_data_points, data_dim)
    stdevs = calc_kernel_width(data)

    # Pick the evaluation point(s) together with the shape the result must
    # have: the batch dimension has to survive, a single point gives a scalar.
    if batch:
        evalPoint = np.random.rand(n_samples, data_dim)
        expected_shape = (n_samples,)
    else:
        evalPoint = np.array([0.5] * data_dim)
        expected_shape = ()

    evaluated = evalKDE(data, evalPoint, stdevs)
    assert evaluated.shape == expected_shape
@pytest.mark.parametrize("data, stdevs, grid_bounds", kde_data_test_set())
@pytest.mark.parametrize("evalKDE", kernel_estimators())
def test_KDE_data(evalKDE, data, stdevs, grid_bounds, resolution=33):
    """Evaluate a kernel density estimator on a regular 2D grid and plot it.

    Args:
        evalKDE: Kernel density estimator under test, called as
            ``evalKDE(data, eval_points, stdevs)``.
        data: Data points the density estimate is built from.
        stdevs: Kernel widths passed through to the estimator.
        grid_bounds: ``[[x_min, x_max], [y_min, y_max]]`` limits of the grid.
        resolution: Number of grid points along each axis. (Default value = 33)
    """
    xGrid = np.linspace(*grid_bounds[0], resolution)
    yGrid = np.linspace(*grid_bounds[1], resolution)
    xMesh, yMesh = np.meshgrid(xGrid, yGrid)
    mesh_points = np.stack([xMesh, yMesh], axis=-1)

    # Only the evaluation points are vectorized, not the data points or the
    # kernel widths. Mapping over axis 0 hands the estimator one mesh row
    # (a batch of evaluation points) per call.
    evaluated = jax.vmap(evalKDE, in_axes=(None, 0, None))(
        data, mesh_points, stdevs
    )

    # One density value per mesh node; the surface plot needs exactly this.
    assert evaluated.shape == xMesh.shape

    fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
    try:
        ax.plot_surface(
            xMesh,
            yMesh,
            evaluated,
            alpha=0.75,
            cmap=cm.coolwarm,
            linewidth=0,
            antialiased=False,
        )
        plt.show()
    finally:
        # Release the figure so parametrized runs do not accumulate open
        # figures.
        plt.close(fig)
# WARNING: The following code only works for the simplest case. Equidistant grid, same number of points in each dimension, ...
def integrate(z, x, y):
    """Integrate values sampled on a 2D grid with the trapezoidal rule.

    Args:
        z: 2D array of function values. Its first axis is integrated against
            ``y``; the axis that remains is then integrated against ``x``.
        x: Sample positions belonging to the second axis of ``z``.
        y: Sample positions belonging to the first axis of ``z``.

    Returns:
        The value of the double integral.
    """
    # ``np.trapz`` is deprecated since NumPy 2.0 in favour of
    # ``np.trapezoid``; prefer the new name and fall back on older releases.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    inner = trapezoid(z, y, axis=0)
    return trapezoid(inner, x, axis=0)
@pytest.mark.parametrize("dim", [1, 2], ids=["1D", "2D"])
def test_kde_convergence_gauss(dim, num_grid_points=100, num_data_points=10000):
    """Test whether the Gauss KDE converges to the true distribution.

    Samples are drawn from a standard normal distribution, the Gauss KDE is
    evaluated on a regular grid over ``[-5, 5]`` in every dimension, and the
    integrated absolute difference to the exact normal pdf has to stay below
    a threshold. Both densities are plotted afterwards.

    Args:
        dim: Dimension of the data (1 or 2).
        num_grid_points: Number of grid points per dimension.
        num_data_points: Number of random samples drawn.
    """
    # ``np.trapz`` is deprecated since NumPy 2.0 in favour of
    # ``np.trapezoid``; prefer the new name and fall back on older releases.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    # Generate random numbers from a normal distribution
    data = np.random.randn(num_data_points, dim)
    stdevs = calc_kernel_width(data)

    # Define the grid (kept under its own name so the parameter is not
    # shadowed).
    grid_shape = np.array([num_grid_points for _ in range(dim)], dtype=np.int32)
    limits = np.array([[-5, 5] for _ in range(dim)])
    grid = generate_regular_grid(grid_shape, limits, flatten=True)

    # Evaluate the KDE and the true distribution on the grid
    kde_on_grid = eval_kde_gauss(data, grid, stdevs)
    exact_on_grid = jstats.multivariate_normal.pdf(
        grid, np.zeros(dim), np.eye(dim)
    )
    diff = np.abs(kde_on_grid - exact_on_grid)  # difference between the two

    fig = None
    if dim == 1:
        grid = grid[:, 0]
        error = trapezoid(diff, grid)  # Calculate the error
        assert error < 0.1  # ~0.06 for 100 grid points, 1000 data points

        # Use a dedicated figure instead of whatever figure is current.
        fig, ax = plt.subplots()
        ax.plot(grid, kde_on_grid)
        ax.plot(grid, exact_on_grid)
    elif dim == 2:
        rows, cols = grid_shape[0], grid_shape[1]

        # Calculate the error
        diff = diff.reshape(rows, cols)
        x = np.linspace(limits[0, 0], limits[0, 1], rows)
        y = np.linspace(limits[1, 0], limits[1, 1], cols)
        error = integrate(diff, x, y)
        assert error < 0.15  # ~0.13 for 100 grid points, 1000 data points

        # Surface plot of the exact density and the KDE
        fig = plt.figure()
        ax = fig.add_subplot(111, projection="3d")
        grid_2d = grid.reshape(rows, cols, dim)
        for label, values in (("exact", exact_on_grid), ("kde", kde_on_grid)):
            surf = ax.plot_surface(
                grid_2d[:, :, 0],
                grid_2d[:, :, 1],
                values.reshape(rows, cols),
                alpha=0.7,
                label=label,
            )
            # NOTE(review): presumably a workaround so labelled 3D surfaces
            # can appear in a legend; it relies on private matplotlib
            # attributes -- confirm it is still needed.
            surf._edgecolors2d = surf._edgecolor3d
            surf._facecolors2d = surf._facecolor3d

    plt.show()
    if fig is not None:
        # Release the figure so it does not leak into later tests.
        plt.close(fig)