Source code for eulerpi.core.data_transformations.data_normalization
from typing import Tuple

import jax.numpy as jnp
from jax import jit, tree_util
from jax.scipy.linalg import solve_triangular

from .affine_transformation import AffineTransformation


@jit
def compute_normalization(data) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute the affine map that whitens ``data``.

    The returned pair ``(normalizing_matrix, shift_vector)`` satisfies
    ``normalizing_matrix @ x + shift_vector == L^{-1} (x - mean)``, where
    ``mean`` is the mean of ``data`` along axis 0 and ``L`` is the lower
    Cholesky factor of the covariance computed by
    ``jnp.cov(data, rowvar=False)``.

    Args:
        data (jnp.ndarray): Samples along axis 0. A 1-D array is treated as
            samples of a single variable.

    Returns:
        Tuple[jnp.ndarray, jnp.ndarray]: The normalizing matrix (squeezed to a
        0-d array when the data has a single variable) and the shift vector.
    """
    # atleast_1d: for 1-D input the mean is 0-d, and the matmul below rejects
    # 0-d operands; promoting it to shape (1,) makes single-variable data work.
    mean_vector = jnp.atleast_1d(jnp.mean(data, axis=0))
    # jnp.cov yields a 0-d array for a single variable; atleast_2d restores the
    # (1, 1) matrix shape that cholesky requires.
    cov = jnp.atleast_2d(jnp.cov(data, rowvar=False))
    L = jnp.linalg.cholesky(cov)
    # L is lower triangular, so invert it by solving L @ X = I with a
    # triangular solve instead of a general matrix inverse.
    # NOTE(review): inside jit a non-positive-definite covariance cannot raise;
    # presumably the result contains NaNs in that case — confirm for callers.
    identity = jnp.eye(L.shape[0], dtype=L.dtype)
    normalizing_matrix = solve_triangular(L, identity, lower=True)
    shift_vector = -normalizing_matrix @ mean_vector
    # Use jnp.squeeze to reduce dimensions if normalizing_matrix is (1, 1)
    normalizing_matrix = jnp.squeeze(normalizing_matrix)
    return normalizing_matrix, shift_vector
class DataNormalization(AffineTransformation):
    """Affine transformation that normalizes (whitens) data.

    The transformation subtracts the mean of the data and then multiplies by
    the inverse of the Cholesky factor of the data's covariance matrix, both
    computed from the data handed to the constructor.
    """

    def __init__(self, data: jnp.ndarray):
        """Fit the normalization to ``data`` and set up the affine map.

        Args:
            data (jnp.ndarray): The data from which the mean vector and the
                normalizing matrix are calculated.
        """
        # compute_normalization returns (normalizing_matrix, shift_vector),
        # exactly the positional arguments AffineTransformation expects.
        super().__init__(*compute_normalization(data))

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        """Rebuild a DataNormalization from its flattened JAX pytree form.

        ``DataNormalization.__init__`` expects raw data rather than the stored
        transformation parameters, so the instance is allocated without
        calling it and is then initialized via ``AffineTransformation.__init__``
        with ``children``. ``aux_data`` is not used.
        """
        restored = cls.__new__(cls)  # allocate without DataNormalization.__init__
        AffineTransformation.__init__(restored, *children)
        return restored
# Tell JAX how to flatten and rebuild DataNormalization instances so they can
# be passed through JAX transformations and tree_util utilities as pytrees.
tree_util.register_pytree_node(
    nodetype=DataNormalization,
    flatten_func=DataNormalization._tree_flatten,
    unflatten_func=DataNormalization._tree_unflatten,
)