Implement unconstraining transform for LKJCorr #7380

Open. Wants to merge 4 commits into base: main
11 changes: 9 additions & 2 deletions pymc/distributions/multivariate.py
@@ -69,7 +69,12 @@
     rv_size_is_none,
     to_tuple,
 )
-from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform
+from pymc.distributions.transforms import (
+    CholeskyCorr,
+    Interval,
+    ZeroSumTransform,
+    _default_transform,
+)
 from pymc.logprob.abstract import _logprob
 from pymc.math import kron_diag, kron_dot
 from pymc.pytensorf import normalize_rng_param
@@ -1579,7 +1584,9 @@ def logp(value, n, eta):

 @_default_transform.register(_LKJCorr)
 def lkjcorr_default_transform(op, rv):
-    return MultivariateIntervalTransform(-1.0, 1.0)
Member

Can you delete this transform class as well? It was a (wrong) patch to the problem you're solving

Author

Can do. Just to confirm, you don't consider MultivariateIntervalTransform to be part of pymc's public API?

Member

Nope, can be removed without worries

Author

Ok - great
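
One way to see why an elementwise interval transform may not be enough here: bounding every off-diagonal entry in (-1, 1) does not by itself make the matrix positive definite. The sketch below is a pure-Python illustration for the 3 x 3 case; the helper names `det3` and `is_positive_definite` are hypothetical and not part of PyMC.

```python
def det3(r12, r13, r23):
    """Determinant of the symmetric 3 x 3 matrix with unit diagonal."""
    return 1.0 + 2.0 * r12 * r13 * r23 - r12**2 - r13**2 - r23**2


def is_positive_definite(r12, r13, r23):
    # Sylvester's criterion: every leading principal minor must be positive.
    # The first minor is 1, the second is 1 - r12**2, the third is det3.
    return (1.0 - r12**2 > 0.0) and (det3(r12, r13, r23) > 0.0)


# Both triples have all entries strictly inside (-1, 1).
ok = is_positive_definite(0.3, 0.2, 0.1)
bad = is_positive_definite(0.9, 0.9, -0.9)
print(ok, bad)
```

The second triple stays inside the interval yet gives a negative determinant, so a sampler restricted only by the interval can still propose invalid correlation matrices.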

+    _, _, _, n, *_ = rv.owner.inputs
+    n = pt.get_scalar_constant_value(n)  # Safely extract scalar value without eval
+    return CholeskyCorr(n)


 class LKJCorr:
170 changes: 170 additions & 0 deletions pymc/distributions/transforms.py
@@ -45,6 +45,7 @@
     "log",
     "sum_to_1",
     "circular",
+    "CholeskyCorr",
     "CholeskyCovPacked",
     "Chain",
     "ZeroSumTransform",
@@ -138,6 +139,175 @@ def log_jac_det(self, value, *inputs):
         return pt.sum(y, axis=-1)


+class CholeskyCorr(Transform):
+    """
+    Transforms unconstrained real numbers to the off-diagonal elements of
+    a Cholesky decomposition of a correlation matrix.

+    This ensures that the resulting correlation matrix is positive definite.

+    #### Mathematical Details

+    This bijector provides a change of variables from unconstrained reals to a
+    parameterization of the CholeskyLKJ distribution. The CholeskyLKJ distribution
+    [1] is a distribution on the set of Cholesky factors of positive definite
+    correlation matrices. The CholeskyLKJ probability density function is
+    obtained from the LKJ density on n x n matrices as follows:

+        1 = int p(A | eta) dA
+          = int Z(eta) * det(A) ** (eta - 1) dA
+          = int Z(eta) L_ii ** {(n - i - 1) + 2 * (eta - 1)} ^dL_ij (0 <= i < j < n)

+    where Z(eta) is the normalizer; the matrix L is the Cholesky factor of the
+    correlation matrix A; and ^dL_ij denotes the wedge product (or differential)
+    of the strictly lower triangular entries of L. The entries L_ij are
+    constrained such that each entry lies in [-1, 1] and the norm of each row is
+    1. The norm includes the diagonal; which is not included in the wedge product.
+    To preserve uniqueness, we further specify that the diagonal entries are
+    positive.

+    The image of unconstrained reals under the `CorrelationCholesky` bijector is
+    the set of correlation matrices which are positive definite. A [correlation
+    matrix](https://en.wikipedia.org/wiki/Correlation_and_dependence#Correlation_matrices)
+    can be characterized as a symmetric positive semidefinite matrix with 1s on
+    the main diagonal.

+    For a lower triangular matrix `L` to be a valid Cholesky-factor of a positive
+    definite correlation matrix, it is necessary and sufficient that each row of
+    `L` have unit Euclidean norm [1]. To see this, observe that if `L_i` is the
+    `i`th row of the Cholesky factor corresponding to the correlation matrix `R`,
+    then the `i`th diagonal entry of `R` satisfies:

+        1 = R_i,i = L_i . L_i = ||L_i||^2

+    where '.' is the dot product of vectors and `||...||` denotes the Euclidean
+    norm.

+    Furthermore, observe that `R_i,j` lies in the interval `[-1, 1]`. By the
+    Cauchy-Schwarz inequality:

+        |R_i,j| = |L_i . L_j| <= ||L_i|| ||L_j|| = 1

+    This is a consequence of the fact that `R` is symmetric positive definite with
+    1s on the main diagonal.

+    We choose the mapping from x in `R^{m}` to `R^{n^2}` where `m` is the
+    `(n - 1)`th triangular number; i.e. `m = 1 + 2 + ... + (n - 1)`.

+        L_ij = x_i,j / s_i (for i < j)
+        L_ii = 1 / s_i

+    where s_i = sqrt(1 + x_i,0^2 + x_i,1^2 + ... + x_(i,i-1)^2). We can check that
+    the required constraints on the image are satisfied.

+    #### Examples

+    ```python
+    transform = CholeskyCorr(n=3)
+    x = pt.as_tensor_variable([0.0, 0.0, 0.0])
+    y = transform.forward(x).eval()
+    # y will be the off-diagonal elements of the Cholesky factor

+    x_reconstructed = transform.backward(y).eval()
+    # x_reconstructed should closely match the original x
+    ```

+    #### References
+    - [Stan Manual. Section 24.2. Cholesky LKJ Correlation Distribution.](https://mc-stan.org/docs/2_18/functions-reference/cholesky-lkj-correlation-distribution.html)
+    - Lewandowski, D., Kurowicka, D., & Joe, H. (2009). "Generating random correlation matrices based on vines and extended onion method." *Journal of Multivariate Analysis, 100*(5), 1989-2001.
+    """

+    name = "cholesky-corr"

+    def __init__(self, n, validate_args=False):
+        """
+        Initialize the CholeskyCorr transform.

+        Parameters
+        ----------
+        n : int
+            Size of the correlation matrix.
+        validate_args : bool, default False
+            Whether to validate input arguments.
+        """
+        self.n = n
+        self.m = int(n * (n - 1) / 2)  # Number of off-diagonal elements
+        self.tril_r_idxs, self.tril_c_idxs = self._generate_tril_indices()
+        self.triu_r_idxs, self.triu_c_idxs = self._generate_triu_indices()
Member

See below, not sure we need to cache these. __init__ is probably unnecessary

+        super().__init__(validate_args=validate_args)

+    def _generate_tril_indices(self):
+        row_indices, col_indices = np.tril_indices(self.n, -1)
Member (ricardoV94, Jun 24, 2024)

Not sure if it matters but there is a pt.tril_indices and pt.triu_indices so no need to eval n. If it's already restricted to be constant elsewhere (like the logp), then it's fine either way

Member

I think it's good practice to use the pt version, even if n is fixed

Author

I originally tried to use the pt version, but one of the function calls required constant values. However, I've made so many changes since then that this might no longer be the case. I'll try the pt version again and see if I can get it to work.
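
Whichever library supplies the indices, their ordering matters when a flat vector is packed into one triangle and read back from the other. The following is a minimal pure-Python sketch, assuming row-major enumeration as produced by `np.tril_indices(n, -1)` and `np.triu_indices(n, 1)`; the helper names are hypothetical.

```python
def strict_tril_pairs(n):
    """(row, col) pairs strictly below the diagonal, enumerated row by row."""
    return [(i, j) for i in range(n) for j in range(i)]


def strict_triu_pairs(n):
    """(row, col) pairs strictly above the diagonal, enumerated row by row."""
    return [(i, j) for i in range(n) for j in range(i + 1, n)]


for n in (3, 4):
    lower = strict_tril_pairs(n)
    upper = strict_triu_pairs(n)
    # Reflect the lower-triangle pairs across the diagonal, keeping their order.
    mirrored = [(j, i) for i, j in lower]
    print(n, len(lower), mirrored == upper)
```

Both triangles hold n * (n - 1) / 2 entries, but the mirrored lower-triangle order matches the upper-triangle order only for n = 3 here, not for n = 4. A round trip that packs with one index set and unpacks with the other is therefore worth testing at n >= 4.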

+        return (row_indices, col_indices)

+    def _generate_triu_indices(self):
+        row_indices, col_indices = np.triu_indices(self.n, 1)
+        return (row_indices, col_indices)

+    def forward(self, x, *inputs):
+        """
+        Forward transform: Unconstrained real numbers to Cholesky factors.

+        Parameters
+        ----------
+        x : tensor
+            Unconstrained real numbers.

+        Returns
+        -------
+        tensor
+            Transformed Cholesky factors.
+        """
+        # Initialize a zero matrix
+        chol = pt.zeros((self.n, self.n), dtype=x.dtype)

+        # Assign the unconstrained values to the lower triangular part
+        chol = pt.set_subtensor(chol[self.tril_r_idxs, self.tril_c_idxs], x)

+        # Normalize each row to have unit L2 norm
+        row_norms = pt.sqrt(pt.sum(chol**2, axis=1, keepdims=True))
+        chol = chol / row_norms

+        return chol[self.tril_r_idxs, self.tril_c_idxs]

+    def backward(self, y, *inputs):
+        """
+        Backward transform: Cholesky factors to unconstrained real numbers.

+        Parameters
+        ----------
+        y : tensor
+            Cholesky factors.

+        Returns
+        -------
+        tensor
+            Unconstrained real numbers.
+        """
+        # Reconstruct the full Cholesky matrix
+        chol = pt.zeros((self.n, self.n), dtype=y.dtype)
+        chol = pt.set_subtensor(chol[self.triu_r_idxs, self.triu_c_idxs], y)
+        chol = chol + pt.transpose(chol) + pt.eye(self.n, dtype=y.dtype)

+        # Perform Cholesky decomposition
+        chol = pt.linalg.cholesky(chol)

+        # Extract the unconstrained parameters by normalizing
+        row_norms = pt.sqrt(pt.sum(chol**2, axis=1))
+        unconstrained = chol / row_norms[:, None]

+        return unconstrained[self.tril_r_idxs, self.tril_c_idxs]

+    def log_jac_det(self, y, *inputs):
+        """
+        Compute the log determinant of the Jacobian.

+        The Jacobian determinant for normalization is the product of row norms.
+        """
+        row_norms = pt.sqrt(pt.sum(y**2, axis=1))
+        return -pt.sum(pt.log(row_norms), axis=-1)

 class CholeskyCovPacked(Transform):
     """
     Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the
153 changes: 153 additions & 0 deletions tests/distributions/test_transform.py
@@ -24,6 +24,7 @@
 import pymc as pm
 import pymc.distributions.transforms as tr

+from pymc.distributions.transforms import CholeskyCorr
 from pymc.logprob.basic import transformed_conditional_logp
 from pymc.logprob.transforms import Transform
 from pymc.pytensorf import floatX, jacobian
@@ -673,3 +674,155 @@ def test_deprecated_ndim_supp_transforms():

     with pytest.warns(FutureWarning, match="deprecated"):
         assert tr.multivariate_sum_to_1 == tr.sum_to_1


+def test_lkjcorr_transform_round_trip():
+    """
+    Test that applying the forward transform followed by the backward transform
+    retrieves the original unconstrained parameters, and that sampled matrices are positive definite.
+    """
+    with pm.Model() as model:
+        rho = pm.LKJCorr("rho", n=3, eta=2)

+        trace = pm.sample(
+            100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False
+        )

+    # Extract the sampled correlation matrices
+    rho_samples = trace["rho"]
+    num_samples = rho_samples.shape[0]

+    for i in range(num_samples):
+        sample_matrix = rho_samples[i]

+        # Check if the sampled matrix is positive definite
+        try:
+            np.linalg.cholesky(sample_matrix)
+        except np.linalg.LinAlgError:
+            pytest.fail(f"Sampled correlation matrix at index {i} is not positive definite.")

+        # Perform round-trip transform: forward and then backward
+        transform = CholeskyCorr(n=3)
+        unconstrained = transform.forward(pt.as_tensor_variable(sample_matrix)).eval()
+        reconstructed = transform.backward(unconstrained).eval()

+        # Assert that the original and reconstructed unconstrained parameters are close
+        assert_allclose(sample_matrix, reconstructed, atol=1e-6)


+def test_lkjcorr_log_jac_det():
+    """
+    Verify that the computed log determinant of the Jacobian matches the expected value
+    obtained from PyTensor's automatic differentiation with a non-trivial input.
+    """
+    n = 3
+    transform = CholeskyCorr(n=n)

+    # Create a non-trivial sample unconstrained vector
+    x = np.random.randn(int(n * (n - 1) / 2)).astype(pytensor.config.floatX)
+    x_tensor = pt.as_tensor_variable(x)

+    # Perform forward transform to obtain Cholesky factors
+    y = transform.forward(x_tensor)

+    # Compute the log determinant using the transform's method
+    computed_log_jac_det = transform.log_jac_det(y).eval()

+    # Define the backward function
+    backward = transform.backward

+    # Compute the Jacobian matrix using PyTensor's automatic differentiation
+    backward_transformed = backward(y)
+    jacobian_matrix = pt.jacobian(backward_transformed, y)

+    # Compile the function to compute the Jacobian matrix
+    jacobian_func = pytensor.function([], jacobian_matrix)
+    jacobian_val = jacobian_func()

+    # Compute the log determinant of the Jacobian matrix
+    actual_log_jac_det = np.log(np.abs(np.linalg.det(jacobian_val)))

+    # Compare the two
+    assert_allclose(computed_log_jac_det, actual_log_jac_det, atol=1e-6)
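
The same kind of check as the log-Jacobian test above can be sketched in pure Python for the mapping stated in the class docstring. Under that mapping (with `j < i`, an assumption), each row is transformed independently, and the matrix determinant lemma gives log|det J| = -sum_i (i + 2) * log(s_i), where row i has i free entries. This describes the docstring formulas only, not necessarily the `log_jac_det` in this diff. All function names are hypothetical.

```python
import math


def to_cholesky_rows(x, n):
    """Docstring mapping: row i (i = 1..n-1) gets x_ij / s_i for j < i."""
    y, pos = [], 0
    for i in range(1, n):
        row = x[pos:pos + i]
        s = math.sqrt(1.0 + sum(v * v for v in row))
        y.extend(v / s for v in row)
        pos += i
    return y


def analytic_log_jac_det(x, n):
    """Closed form under this mapping: -sum_i (i + 2) * log(s_i)."""
    total, pos = 0.0, 0
    for i in range(1, n):
        row = x[pos:pos + i]
        s = math.sqrt(1.0 + sum(v * v for v in row))
        total -= (i + 2) * math.log(s)
        pos += i
    return total


def numeric_log_abs_det(x, n, h=1e-6):
    """log|det J| with J[r][c] = d y_r / d x_c from central differences."""
    m = len(x)
    columns = []
    for c in range(m):
        hi, lo = list(x), list(x)
        hi[c] += h
        lo[c] -= h
        y_hi, y_lo = to_cholesky_rows(hi, n), to_cholesky_rows(lo, n)
        columns.append([(a - b) / (2.0 * h) for a, b in zip(y_hi, y_lo)])
    jac = [[columns[c][r] for c in range(m)] for r in range(m)]

    # Gaussian elimination with partial pivoting; accumulate log|pivot|.
    log_det = 0.0
    for k in range(m):
        pivot = max(range(k, m), key=lambda r: abs(jac[r][k]))
        jac[k], jac[pivot] = jac[pivot], jac[k]
        log_det += math.log(abs(jac[k][k]))
        for r in range(k + 1, m):
            factor = jac[r][k] / jac[k][k]
            for c in range(k, m):
                jac[r][c] -= factor * jac[k][c]
    return log_det


x = [0.3, -0.7, 1.1, 0.2, -0.4, 0.9]  # m = 6 values for n = 4
analytic = analytic_log_jac_det(x, n=4)
numeric = numeric_log_abs_det(x, n=4)
print(analytic, numeric)
```

A finite-difference check like this is independent of any autodiff machinery, which makes it a useful second opinion when a hand-written `log_jac_det` and an autodiff Jacobian disagree.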


+@pytest.mark.parametrize("n", [2, 4, 5])
+def test_lkjcorr_transform_various_sizes(n):
+    """
+    Test the CholeskyCorr transform with various sizes of correlation matrices.
+    """
+    transform = CholeskyCorr(n=n)
+    unconstrained_size = int(n * (n - 1) / 2)

+    # Generate random unconstrained real numbers
+    x = np.random.randn(unconstrained_size).astype(pytensor.config.floatX)
+    x_tensor = pt.as_tensor_variable(x)

+    # Perform forward transform
+    y = transform.forward(x_tensor).eval()

+    # Perform backward transform
+    reconstructed = transform.backward(y).eval()

+    # Assert that the original and reconstructed unconstrained parameters are close
+    assert_allclose(x, reconstructed, atol=1e-6)


+def test_lkjcorr_invalid_n():
+    """
+    Test that initializing CholeskyCorr with invalid 'n' values raises appropriate errors.
+    """
+    with pytest.raises(ValueError):
+        # 'n' must be an integer greater than 1
+        CholeskyCorr(n=1)

+    with pytest.raises(TypeError):
+        # 'n' must be an integer
+        CholeskyCorr(n="three")


+def test_lkjcorr_positive_definite():
+    """
+    Ensure that all sampled correlation matrices are positive definite.
+    """
+    with pm.Model() as model:
+        rho = pm.LKJCorr("rho", n=4, eta=2)

+        trace = pm.sample(
+            100, tune=100, chains=1, cores=1, progressbar=False, return_inferencedata=False
+        )

+    # Extract the sampled correlation matrices
+    rho_samples = trace["rho"]
+    num_samples = rho_samples.shape[0]

+    for i in range(num_samples):
+        sample_matrix = rho_samples[i]

+        # Check if the sampled matrix is positive definite
+        try:
+            np.linalg.cholesky(sample_matrix)
+        except np.linalg.LinAlgError:
+            pytest.fail(f"Sampled correlation matrix at index {i} is not positive definite.")
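
The properties that the positive-definiteness test above checks by sampling can also be checked deterministically for the docstring mapping. The sketch below builds the full lower-triangular factor from unconstrained values, forms `R = L L^T`, and confirms the unit diagonal and the Cauchy-Schwarz bound on the off-diagonal entries. It is pure Python with hypothetical helper names, and it assumes the `j < i` reading of the docstring formulas.

```python
import math


def cholesky_factor(x, n):
    """Full n x n lower-triangular factor with unit-norm rows and positive diagonal."""
    factor = [[0.0] * n for _ in range(n)]
    pos = 0
    for i in range(n):  # row 0 has no free entries, so L_00 = 1
        row = x[pos:pos + i]
        s = math.sqrt(1.0 + sum(v * v for v in row))
        for j, v in enumerate(row):
            factor[i][j] = v / s
        factor[i][i] = 1.0 / s
        pos += i
    return factor


def gram(factor):
    """R = L L^T, computed entry by entry as dot products of rows."""
    n = len(factor)
    return [
        [sum(factor[i][k] * factor[j][k] for k in range(n)) for j in range(n)]
        for i in range(n)
    ]


x = [0.5, -1.2, 2.0, 0.1, 0.7, -3.0]  # m = 6 values for n = 4
L = cholesky_factor(x, n=4)
R = gram(L)
for row in R:
    print([round(v, 4) for v in row])
```

Because every diagonal entry of `L` is strictly positive, `R` is positive definite by construction, so no sampled value can fail a Cholesky decomposition for reasons other than floating-point error.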


+def test_lkjcorr_round_trip_various_sizes():
+    """
+    Perform round-trip transformation tests for various sizes of correlation matrices.
+    """
+    for n in [2, 3, 4]:
+        transform = CholeskyCorr(n=n)
+        unconstrained_size = int(n * (n - 1) / 2)

+        # Generate random unconstrained real numbers
+        x = np.random.randn(unconstrained_size).astype(pytensor.config.floatX)
+        x_tensor = pt.as_tensor_variable(x)

+        # Perform forward transform
+        y = transform.forward(x_tensor).eval()

+        # Perform backward transform
+        reconstructed = transform.backward(y).eval()

+        # Assert that the original and reconstructed unconstrained parameters are close
+        assert_allclose(x, reconstructed, atol=1e-6)