diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index 359b0743dd..686b063cb9 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -69,7 +69,12 @@
     rv_size_is_none,
     to_tuple,
 )
-from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform
+from pymc.distributions.transforms import (
+    CholeskyCorr,
+    Interval,
+    ZeroSumTransform,
+    _default_transform,
+)
 from pymc.logprob.abstract import _logprob
 from pymc.math import kron_diag, kron_dot
 from pymc.pytensorf import normalize_rng_param
@@ -1579,7 +1584,10 @@ def logp(value, n, eta):
 
 @_default_transform.register(_LKJCorr)
 def lkjcorr_default_transform(op, rv):
-    return MultivariateIntervalTransform(-1.0, 1.0)
+    _, _, _, n, *_ = rv.owner.inputs
+    # n must be a graph constant here; int() unwraps the 0-d array into a Python int
+    n = int(pt.get_scalar_constant_value(n))
+    return CholeskyCorr(n)
 
 
 class LKJCorr:
diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py
index d8998889cf..fe6010b710 100644
--- a/pymc/distributions/transforms.py
+++ b/pymc/distributions/transforms.py
@@ -45,6 +45,7 @@
     "log",
     "sum_to_1",
     "circular",
+    "CholeskyCorr",
     "CholeskyCovPacked",
     "Chain",
     "ZeroSumTransform",
@@ -138,6 +139,178 @@ def log_jac_det(self, value, *inputs):
         return pt.sum(y, axis=-1)
 
 
+class CholeskyCorr(Transform):
+    """
+    Map the packed off-diagonal entries of a correlation matrix to unconstrained reals.
+
+    The constrained value is the vector of the ``n * (n - 1) / 2`` strictly
+    upper-triangular entries (row-major, as given by ``np.triu_indices(n, 1)``)
+    of an ``n x n`` correlation matrix ``R``. The unconstrained value is a real
+    vector ``x`` of the same length. Every ``x`` maps to a positive definite
+    ``R``, so a sampler working on ``x`` can never propose an invalid matrix.
+
+    Following the PyMC convention, ``forward`` maps the constrained value to the
+    unconstrained one and ``backward`` maps it back. Batched values are not
+    supported.
+
+    Notes
+    -----
+    A lower-triangular matrix ``L`` with positive diagonal is the Cholesky factor
+    of a positive definite correlation matrix if and only if every row of ``L``
+    has unit Euclidean norm, because the diagonal of ``R = L @ L.T`` is::
+
+        R[i, i] = L[i] . L[i] = ||L[i]||^2 = 1
+
+    By the Cauchy-Schwarz inequality every off-diagonal entry then satisfies::
+
+        |R[i, j]| = |L[i] . L[j]| <= ||L[i]|| ||L[j]|| = 1
+
+    The unconstrained vector fills the strictly lower triangle of a matrix ``X``
+    (row-major, as given by ``np.tril_indices(n, -1)``). A one is placed on the
+    diagonal and each row is normalised::
+
+        s[i]    = sqrt(1 + X[i, 0]^2 + ... + X[i, i - 1]^2)
+        L[i, j] = X[i, j] / s[i]    (for j < i)
+        L[i, i] = 1 / s[i]
+
+    The Jacobian of ``x -> L`` has determinant ``prod_i s[i] ** -(i + 2)`` and
+    the Jacobian of ``L -> R`` has determinant ``prod_i L[i, i] ** (n - i - 1)``,
+    so for the whole map::
+
+        log|det J| = -(n + 1) * sum_i log(s[i])
+
+    Examples
+    --------
+    .. code-block:: python
+
+        transform = CholeskyCorr(n=3)
+        x = pt.as_tensor_variable([0.0, 0.0, 0.0])
+        corr = transform.backward(x).eval()
+        # corr holds the three off-diagonal correlations, here all zero
+
+        x_reconstructed = transform.forward(pt.as_tensor_variable(corr)).eval()
+        # x_reconstructed closely matches the original x
+
+    References
+    ----------
+    - Stan Functions Reference, "Cholesky LKJ Correlation Distribution".
+      https://mc-stan.org/docs/2_18/functions-reference/cholesky-lkj-correlation-distribution.html
+    - Lewandowski, D., Kurowicka, D., & Joe, H. (2009). "Generating random
+      correlation matrices based on vines and extended onion method."
+      *Journal of Multivariate Analysis, 100*(5), 1989-2001.
+    """
+
+    name = "cholesky-corr"
+
+    def __init__(self, n, validate_args=False):
+        """
+        Initialize the CholeskyCorr transform.
+
+        Parameters
+        ----------
+        n : int
+            Size of the correlation matrix. Must be at least 2.
+        validate_args : bool, default False
+            Stored for API compatibility; the transform itself performs no
+            runtime validation of the values it is given.
+
+        Raises
+        ------
+        TypeError
+            If ``n`` is not an integer scalar.
+        ValueError
+            If ``n`` is smaller than 2.
+        """
+        n_arr = np.asarray(n)
+        if n_arr.ndim != 0 or not np.issubdtype(n_arr.dtype, np.integer):
+            raise TypeError(f"n must be an integer, got {n!r}")
+        if n_arr < 2:
+            raise ValueError(f"n must be at least 2, got {n!r}")
+
+        self.n = int(n_arr)
+        self.m = self.n * (self.n - 1) // 2  # Number of off-diagonal elements
+        self.validate_args = validate_args
+        self.diag_idxs = np.arange(self.n)
+        self.tril_r_idxs, self.tril_c_idxs = np.tril_indices(self.n, -1)
+        self.triu_r_idxs, self.triu_c_idxs = np.triu_indices(self.n, 1)
+
+    def _fill_lower(self, value):
+        """Place ``value`` in the strictly lower triangle of a zero ``n x n`` matrix."""
+        lower = pt.zeros((self.n, self.n), dtype=value.dtype)
+        return pt.set_subtensor(lower[self.tril_r_idxs, self.tril_c_idxs], value)
+
+    def forward(self, value, *inputs):
+        """
+        Map packed correlations to unconstrained reals.
+
+        Parameters
+        ----------
+        value : tensor
+            Strictly upper-triangular entries of a positive definite
+            correlation matrix, in row-major order.
+
+        Returns
+        -------
+        tensor
+            Unconstrained real vector of the same length.
+        """
+        # Rebuild the full symmetric matrix with a unit diagonal
+        corr = pt.zeros((self.n, self.n), dtype=value.dtype)
+        corr = pt.set_subtensor(corr[self.triu_r_idxs, self.triu_c_idxs], value)
+        corr = corr + corr.T + pt.eye(self.n, dtype=value.dtype)
+
+        chol = pt.linalg.cholesky(corr)
+
+        # Undo the row normalisation: L[i, j] / L[i, i] == X[i, j]
+        diag = chol[self.diag_idxs, self.diag_idxs]
+        unscaled = chol / diag[:, None]
+        return unscaled[self.tril_r_idxs, self.tril_c_idxs]
+
+    def backward(self, value, *inputs):
+        """
+        Map unconstrained reals to packed correlations.
+
+        Parameters
+        ----------
+        value : tensor
+            Unconstrained real vector of length ``n * (n - 1) / 2``.
+
+        Returns
+        -------
+        tensor
+            Strictly upper-triangular entries of the resulting positive
+            definite correlation matrix, in row-major order.
+        """
+        lower = self._fill_lower(value)
+
+        # The implicit one on the diagonal keeps every row norm strictly
+        # positive, including the first row, which has no free entries
+        row_norms = pt.sqrt(1 + pt.sum(lower**2, axis=-1, keepdims=True))
+        chol = (lower + pt.eye(self.n, dtype=value.dtype)) / row_norms
+
+        corr = pt.dot(chol, chol.T)
+        return corr[self.triu_r_idxs, self.triu_c_idxs]
+
+    def log_jac_det(self, value, *inputs):
+        """
+        Compute the log absolute determinant of the Jacobian of ``backward``.
+
+        Parameters
+        ----------
+        value : tensor
+            Unconstrained real vector of length ``n * (n - 1) / 2``.
+
+        Returns
+        -------
+        tensor
+            Scalar ``-(n + 1) * sum_i log(s[i])``, with ``s[i]`` the row norms
+            used by ``backward``.
+        """
+        row_sumsq = pt.sum(self._fill_lower(value) ** 2, axis=-1)
+        # log(s[i]) == 0.5 * log1p(row_sumsq[i])
+        return -0.5 * (self.n + 1) * pt.sum(pt.log1p(row_sumsq))
+
+
 class CholeskyCovPacked(Transform):
     """
     Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the
diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py
index 8d464f206a..617d2bf134 100644
--- a/tests/distributions/test_transform.py
+++ b/tests/distributions/test_transform.py
@@ -24,6 +24,7 @@
 import pymc as pm
 import pymc.distributions.transforms as tr
 
+from pymc.distributions.transforms import CholeskyCorr
 from pymc.logprob.basic import transformed_conditional_logp
 from pymc.logprob.transforms import Transform
 from pymc.pytensorf import floatX, jacobian
@@ -673,3 +674,96 @@ def test_deprecated_ndim_supp_transforms():
 
     with pytest.warns(FutureWarning, match="deprecated"):
         assert tr.multivariate_sum_to_1 == tr.sum_to_1
+
+
+def _packed_to_corr(packed, n):
+    """Rebuild a correlation matrix from its packed upper-triangular entries."""
+    rows, cols = np.triu_indices(n, 1)
+    corr = np.eye(n)
+    corr[rows, cols] = packed
+    corr[cols, rows] = packed
+    return corr
+
+
+@pytest.mark.parametrize("n", [2, 3, 4, 5])
+def test_cholesky_corr_round_trip(n):
+    """
+    Mapping unconstrained reals to correlations and back must recover the input,
+    and every image must be a positive definite correlation matrix.
+    """
+    transform = CholeskyCorr(n=n)
+    tol = 1e-7 if pytensor.config.floatX == "float64" else 1e-3
+    x_val = np.random.default_rng(n).normal(size=transform.m)
+    x_val = x_val.astype(pytensor.config.floatX)
+
+    x = pt.vector("x")
+    packed = transform.backward(x)
+    round_trip = pytensor.function([x], [packed, transform.forward(packed)])
+    packed_val, x_reconstructed = round_trip(x_val)
+
+    assert packed_val.shape == x_val.shape
+    # np.linalg.cholesky raises LinAlgError unless its argument is positive definite
+    np.linalg.cholesky(_packed_to_corr(packed_val, n))
+    assert_allclose(x_reconstructed, x_val, rtol=tol, atol=tol)
+
+
+@pytest.mark.parametrize("n", [2, 3, 5])
+def test_cholesky_corr_log_jac_det(n):
+    """
+    The closed-form log Jacobian determinant must match the one obtained from
+    PyTensor's automatic differentiation of the backward transform.
+    """
+    transform = CholeskyCorr(n=n)
+    tol = 1e-7 if pytensor.config.floatX == "float64" else 1e-3
+    x_val = np.random.default_rng(n).normal(size=transform.m)
+    x_val = x_val.astype(pytensor.config.floatX)
+
+    x = pt.vector("x")
+    jac = pytensor.gradient.jacobian(transform.backward(x), x)
+    fn = pytensor.function([x], [transform.log_jac_det(x), jac])
+    computed, jac_val = fn(x_val)
+
+    _, expected = np.linalg.slogdet(jac_val)
+    assert_allclose(computed, expected, rtol=tol, atol=tol)
+
+
+def test_cholesky_corr_invalid_n():
+    """
+    Initializing CholeskyCorr with an invalid ``n`` must raise an appropriate error.
+    """
+    with pytest.raises(ValueError):
+        # 'n' must be an integer greater than 1
+        CholeskyCorr(n=1)
+
+    with pytest.raises(TypeError):
+        # 'n' must be an integer
+        CholeskyCorr(n="three")
+
+
+def test_lkjcorr_samples_positive_definite():
+    """
+    LKJCorr must use CholeskyCorr by default and every sampled correlation matrix
+    must be positive definite.
+    """
+    n = 4
+    with pm.Model() as model:
+        rho = pm.LKJCorr("rho", n=n, eta=2)
+
+        trace = pm.sample(
+            100,
+            tune=100,
+            chains=1,
+            cores=1,
+            progressbar=False,
+            return_inferencedata=False,
+            random_seed=20240,
+        )
+
+    assert isinstance(model.rvs_to_transforms[rho], CholeskyCorr)
+
+    # LKJCorr draws are the packed upper-triangular entries of the matrix
+    for i, packed in enumerate(trace["rho"]):
+        try:
+            np.linalg.cholesky(_packed_to_corr(packed, n))
+        except np.linalg.LinAlgError:
+            pytest.fail(f"Sampled correlation matrix at index {i} is not positive definite.")