-
Notifications
You must be signed in to change notification settings - Fork 81
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Parallel inference of LGSSM in the EM algorithm (+ some bug fixes) #336
base: main
Are you sure you want to change the base?
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
…he whole LGSSM codebase
|
||
Please note that we adopt the convention of Murphy, K. P. (2022), "Probabilistic machine learning: Advanced topics", | ||
rather than Särkkä, S. (2013), "Bayesian Filtering and Smoothing" for indexing parameters of LGSSM, where we start | ||
initial index at 0 instead of 1, which is not exactly in line with the former book. This tends to be a source of | ||
confusion sometimes. As such, $F_0$, $B_0$, $b_0$, $Q_0$ are always ignored and the prior specified by $m$ and $S$ | ||
is used as the distribution of the initial state. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are several conflicts of indexing style. I added this note here to make sure that this code base follows the notation of Murphy (2023), which is
By the way, I personally prefer Sarkka's style
MultivariateNormalFullCovariance as MVN) | ||
MultivariateNormalFullCovariance as MVN, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are a quite few lines of changes introduced by applying black to the modified files. Maybe it's better to first merge a separate PR of applying black, to make the diff easier to read here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from dynamax.linear_gaussian_ssm.inference import preprocess_args, _get_one_param, _get_params, _log_likelihood | ||
|
||
|
||
def _get_one_param(x, dim, t): | ||
"""Helper function to get one parameter at time t.""" | ||
if callable(x): | ||
return x(t) | ||
elif x.ndim == dim + 1: | ||
return x[t] | ||
else: | ||
return x | ||
|
||
def _get_params(params, num_timesteps, t): | ||
"""Helper function to get parameters at time t.""" | ||
assert not callable(params.emissions.cov), "Emission covariance cannot be a callable." | ||
|
||
F = _get_one_param(params.dynamics.weights, 2, t) | ||
b = _get_one_param(params.dynamics.bias, 1, t) | ||
Q = _get_one_param(params.dynamics.cov, 2, t) | ||
H = _get_one_param(params.emissions.weights, 2, t+1) | ||
d = _get_one_param(params.emissions.bias, 1, t+1) | ||
|
||
if len(params.emissions.cov.shape) == 1: | ||
R = _get_one_param(params.emissions.cov, 1, t+1) | ||
elif len(params.emissions.cov.shape) > 2: | ||
R = _get_one_param(params.emissions.cov, 2, t+1) | ||
elif params.emissions.cov.shape[0] != num_timesteps: | ||
R = _get_one_param(params.emissions.cov, 2, t+1) | ||
elif params.emissions.cov.shape[1] != num_timesteps: | ||
R = _get_one_param(params.emissions.cov, 1, t+1) | ||
else: | ||
R = _get_one_param(params.emissions.cov, 2, t+1) | ||
warnings.warn( | ||
"Emission covariance has shape (N,N) where N is the number of timesteps. " | ||
"The covariance will be interpreted as static and non-diagonal. To " | ||
"specify a dynamic and diagonal covariance, pass it as a 3D array.") | ||
|
||
return F, b, Q, H, d, R |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I remoted these kinds of duplicated utility functions in parallel_inference.py
and used the one defined in inference.py
|
||
from jax.config import config | ||
|
||
config.update("jax_enable_x64", True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tests for marginal likelihood were quite unstable for float32, probably due instability of log det computation. I'd suggest enabling float64 as default for that reason.
""" | ||
if R.ndim == 2: | ||
S = H @ Q @ H.T + R | ||
return -MVN(jnp.zeros_like(y), S).log_prob(y) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A bug fix: the bias term was missed here.
# Get parameters and inputs for time index t | ||
F, B, b, Q = _get_params(params, num_timesteps, t)[:4] | ||
u = inputs[t] | ||
# Get parameters and inputs for time index t + 1 | ||
F_next, B_next, b_next, Q_next = _get_params(params, num_timesteps, t + 1)[:4] | ||
u_next = inputs[t + 1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A bug fix: calculation of the mean on the next time step requires parameters at the next time step (unless you use Sarkka (2013)'s indexing instead of Murphy (2023)'s).
@@ -12,86 +12,111 @@ | |||
from dynamax.linear_gaussian_ssm.inference_test import flatten_diagonal_emission_cov | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test cases are updated to check if the parallel inference can handle inputs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, the synthetic data for testing the time-varying case was too simplistic to capture some bugs while I was developing the code. So I updated the test case so that it has more time variation of parameters.
Hi, I wanted to use parallel filtering and smoothing of LGSSM for the EM algorithm so I updated the parallel inference functions to the level of feature parity with serial filtering and smoothing.
During the implementation, I found a couple of bugs as well so this PR includes the bug fix as well. (They are joint sampling logic in inference.py and missing emission bias term in the log likelihood of parallel_inference.py).
I thought this branch is almost ready for PR but it seems that I am having a large conflict due to the recent diagonal covariance PR. I will mark the PR as ready when the conflict is resolved.Now ready for review!