Skip to content

Commit

Permalink
Expose k and scale parameters of "bernoulli-gamma" ConvNP lik…
Browse files Browse the repository at this point in the history
…elihood in low-level and high-level prediction interface (closes #123)
  • Loading branch information
tom-andersson committed Jul 28, 2024
1 parent 88a9818 commit 52caf11
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 48 deletions.
82 changes: 78 additions & 4 deletions deepsensor/model/convnp.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,10 +539,10 @@ def std(self, task: Task):
def alpha(
self, dist: AbstractMultiOutputDistribution
) -> Union[np.ndarray, List[np.ndarray]]:
if self.config["likelihood"] not in ["spikes-beta", "bernoulli-gamma"]:
if self.config["likelihood"] not in ["spikes-beta"]:
raise NotImplementedError(
f"ConvNP.alpha method not supported for likelihood {self.config['likelihood']}. "
f"Try changing the likelihood to a mixture model, e.g. 'spikes-beta' or 'bernoulli-gamma'."
f"Valid likelihoods: 'spikes-beta'."
)
alpha = dist.slab.alpha
alpha = self._cast_numpy_and_squeeze(alpha)
Expand Down Expand Up @@ -576,10 +576,10 @@ def alpha(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
def beta(
self, dist: AbstractMultiOutputDistribution
) -> Union[np.ndarray, List[np.ndarray]]:
if self.config["likelihood"] not in ["spikes-beta", "bernoulli-gamma"]:
if self.config["likelihood"] not in ["spikes-beta"]:
raise NotImplementedError(
f"ConvNP.beta method not supported for likelihood {self.config['likelihood']}. "
f"Try changing the likelihood to a mixture model, e.g. 'spikes-beta' or 'bernoulli-gamma'."
f"Valid likelihoods: 'spikes-beta'."
)
beta = dist.slab.beta
beta = self._cast_numpy_and_squeeze(beta)
Expand Down Expand Up @@ -608,6 +608,80 @@ def beta(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
dist = self(task)
return self.beta(dist)

@dispatch
def k(
self, dist: AbstractMultiOutputDistribution
) -> Union[np.ndarray, List[np.ndarray]]:
if self.config["likelihood"] not in ["bernoulli-gamma"]:
raise NotImplementedError(
f"ConvNP.k method not supported for likelihood {self.config['likelihood']}. "
f"Valid likelihoods: 'bernoulli-gamma'."
)
k = dist.slab.k
k = self._cast_numpy_and_squeeze(k)
return self._maybe_concat_multi_targets(k)

@dispatch
def k(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
"""
k parameter values of model's distribution at target locations in task.
Returned numpy arrays have shape ``(N_features, *N_targets)``.
.. note::
This method only works for models that return a distribution with
a ``dist.slab.k`` attribute, e.g. models with a Beta or
Bernoulli-Gamma likelihood, where it returns the k values of
the slab component of the mixture model.
Args:
task (:class:`~.data.task.Task`):
The task containing the context and target data.
Returns:
:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
k values.
"""
dist = self(task)
return self.k(dist)

@dispatch
def scale(
self, dist: AbstractMultiOutputDistribution
) -> Union[np.ndarray, List[np.ndarray]]:
if self.config["likelihood"] not in ["bernoulli-gamma"]:
raise NotImplementedError(
f"ConvNP.scale method not supported for likelihood {self.config['likelihood']}. "
f"Valid likelihoods: 'bernoulli-gamma'."
)
scale = dist.slab.scale
scale = self._cast_numpy_and_squeeze(scale)
return self._maybe_concat_multi_targets(scale)

@dispatch
def scale(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
"""
Scale parameter values of model's distribution at target locations in task.
Returned numpy arrays have shape ``(N_features, *N_targets)``.
.. note::
This method only works for models that return a distribution with
a ``dist.slab.scale`` attribute, e.g. models with a Beta or
Bernoulli-Gamma likelihood, where it returns the scale values of
the slab component of the mixture model.
Args:
task (:class:`~.data.task.Task`):
The task containing the context and target data.
Returns:
:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
Scale values.
"""
dist = self(task)
return self.scale(dist)

@dispatch
def mixture_probs(self, dist: AbstractMultiOutputDistribution):
if self.N_mixture_components == 1:
Expand Down
118 changes: 74 additions & 44 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def test_prediction_shapes_lowlevel(self, n_target_sets):
n_targets * dim_y_combined * n_target_dims,
),
)
if likelihood in ["cnp-spikes-beta"]:
if likelihood in ["cnp-spikes-beta", "bernoulli-gamma"]:
mixture_probs = model.mixture_probs(task)
if isinstance(mixture_probs, (list, tuple)):
for p, dim_y in zip(mixture_probs, tl.target_dims):
Expand All @@ -215,6 +215,7 @@ def test_prediction_shapes_lowlevel(self, n_target_sets):
),
)

if likelihood in ["cnp-spikes-beta"]:
x = model.alpha(task)
if isinstance(x, (list, tuple)):
for p, dim_y in zip(x, tl.target_dims):
Expand All @@ -229,6 +230,21 @@ def test_prediction_shapes_lowlevel(self, n_target_sets):
else:
assert_shape(x, (dim_y_combined, *expected_obs_shape))

if likelihood in ["bernoulli-gamma"]:
x = model.k(task)
if isinstance(x, (list, tuple)):
for p, dim_y in zip(x, tl.target_dims):
assert_shape(p, (dim_y, *expected_obs_shape))
else:
assert_shape(x, (dim_y_combined, *expected_obs_shape))

x = model.scale(task)
if isinstance(x, (list, tuple)):
for p, dim_y in zip(x, tl.target_dims):
assert_shape(p, (dim_y, *expected_obs_shape))
else:
assert_shape(x, (dim_y_combined, *expected_obs_shape))

# Scalars
if likelihood in ["cnp", "gnp"]:
# Methods for Gaussian likelihoods only
Expand Down Expand Up @@ -451,61 +467,75 @@ def test_highlevel_predict_coords_align_with_X_t_offgrid(self):
def test_highlevel_predict_with_pred_params_pandas(self):
"""
Test that passing ``pred_params`` to ``.predict`` works with
a spikes-beta likelihood for prediction to pandas.
mixture model likelihoods for off-grid prediction to pandas.
"""
tl = TaskLoader(context=self.da, target=self.da)
model = ConvNP(
self.dp,
tl,
unet_channels=(5, 5, 5),
verbose=False,
likelihood="cnp-spikes-beta",
)
task = tl("2020-01-01", context_sampling=10, target_sampling=10)

# Off-grid prediction
X_t = np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]])
likelihoods = ["cnp-spikes-beta", "bernoulli-gamma"]
expected_pred_params = [
["mean", "std", "variance", "alpha", "beta"],
["mean", "std", "variance", "k", "scale"],
]

# Check that nothing breaks and the correct parameters are returned
pred_params = ["mean", "std", "variance", "alpha", "beta"]
pred = model.predict(task, X_t=X_t, pred_params=pred_params)
for pred_param in pred_params:
assert pred_param in pred["var"]
for likelihood, pred_params in zip(likelihoods, expected_pred_params):
model = ConvNP(
self.dp,
tl,
unet_channels=(5, 5, 5),
verbose=False,
likelihood=likelihood,
)
task = tl("2020-01-01", context_sampling=10)

# Off-grid prediction
X_t = np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]])

# Test mixture probs special case
pred_params = ["mixture_probs"]
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
for component in range(model.N_mixture_components):
pred_param = f"mixture_probs_{component}"
assert pred_param in pred["var"]
# Check that nothing breaks and the correct parameters are returned
pred = model.predict(task, X_t=X_t, pred_params=pred_params)
for pred_param in pred_params:
assert pred_param in pred["var"]

# Test mixture probs special case
pred_params = ["mixture_probs"]
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
for component in range(model.N_mixture_components):
pred_param = f"mixture_probs_{component}"
assert pred_param in pred["var"]

def test_highlevel_predict_with_pred_params_xarray(self):
"""
Test that passing ``pred_params`` to ``.predict`` works with
a spikes-beta likelihood for prediction to xarray.
mixture model likelihoods for gridded prediction to xarray.
"""
tl = TaskLoader(context=self.da, target=self.da)
model = ConvNP(
self.dp,
tl,
unet_channels=(5, 5, 5),
verbose=False,
likelihood="cnp-spikes-beta",
)
task = tl("2020-01-01", context_sampling=10, target_sampling=10)

# Check that nothing breaks and the correct parameters are returned
pred_params = ["mean", "std", "variance", "alpha", "beta"]
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
for pred_param in pred_params:
assert pred_param in pred["var"]

# Test mixture probs special case
pred_params = ["mixture_probs"]
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
for component in range(model.N_mixture_components):
pred_param = f"mixture_probs_{component}"
assert pred_param in pred["var"]
likelihoods = ["cnp-spikes-beta", "bernoulli-gamma"]
expected_pred_params = [
["mean", "std", "variance", "alpha", "beta"],
["mean", "std", "variance", "k", "scale"],
]

for likelihood, pred_params in zip(likelihoods, expected_pred_params):
model = ConvNP(
self.dp,
tl,
unet_channels=(5, 5, 5),
verbose=False,
likelihood=likelihood,
)
task = tl("2020-01-01", context_sampling=10)

# Check that nothing breaks and the correct parameters are returned
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
for pred_param in pred_params:
assert pred_param in pred["var"]

# Test mixture probs special case
pred_params = ["mixture_probs"]
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
for component in range(model.N_mixture_components):
pred_param = f"mixture_probs_{component}"
assert pred_param in pred["var"]

def test_highlevel_predict_with_invalid_pred_params(self):
"""Test that passing ``pred_params`` to ``.predict`` works."""
Expand Down

0 comments on commit 52caf11

Please sign in to comment.