diff --git a/deepsensor/model/convnp.py b/deepsensor/model/convnp.py
index 57c7b789..bd0ab468 100644
--- a/deepsensor/model/convnp.py
+++ b/deepsensor/model/convnp.py
@@ -539,10 +539,10 @@ def std(self, task: Task):
     def alpha(
         self, dist: AbstractMultiOutputDistribution
     ) -> Union[np.ndarray, List[np.ndarray]]:
-        if self.config["likelihood"] not in ["spikes-beta", "bernoulli-gamma"]:
+        if self.config["likelihood"] not in ["spikes-beta"]:
             raise NotImplementedError(
                 f"ConvNP.alpha method not supported for likelihood {self.config['likelihood']}. "
-                f"Try changing the likelihood to a mixture model, e.g. 'spikes-beta' or 'bernoulli-gamma'."
+                f"Valid likelihoods: 'spikes-beta'."
             )
         alpha = dist.slab.alpha
         alpha = self._cast_numpy_and_squeeze(alpha)
@@ -576,10 +576,10 @@ def alpha(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
     def beta(
         self, dist: AbstractMultiOutputDistribution
     ) -> Union[np.ndarray, List[np.ndarray]]:
-        if self.config["likelihood"] not in ["spikes-beta", "bernoulli-gamma"]:
+        if self.config["likelihood"] not in ["spikes-beta"]:
             raise NotImplementedError(
                 f"ConvNP.beta method not supported for likelihood {self.config['likelihood']}. "
-                f"Try changing the likelihood to a mixture model, e.g. 'spikes-beta' or 'bernoulli-gamma'."
+                f"Valid likelihoods: 'spikes-beta'."
             )
         beta = dist.slab.beta
         beta = self._cast_numpy_and_squeeze(beta)
@@ -608,6 +608,80 @@ def beta(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
         dist = self(task)
         return self.beta(dist)
 
+    @dispatch
+    def k(
+        self, dist: AbstractMultiOutputDistribution
+    ) -> Union[np.ndarray, List[np.ndarray]]:
+        if self.config["likelihood"] not in ["bernoulli-gamma"]:
+            raise NotImplementedError(
+                f"ConvNP.k method not supported for likelihood {self.config['likelihood']}. "
+                f"Valid likelihoods: 'bernoulli-gamma'."
+            )
+        k = dist.slab.k
+        k = self._cast_numpy_and_squeeze(k)
+        return self._maybe_concat_multi_targets(k)
+
+    @dispatch
+    def k(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
+        """
+        k parameter values of model's distribution at target locations in task.
+
+        Returned numpy arrays have shape ``(N_features, *N_targets)``.
+
+        .. note::
+            This method only works for models that return a distribution with
+            a ``dist.slab.k`` attribute, e.g. models with a Bernoulli-Gamma
+            likelihood, where it returns the k values of the slab component
+            of the mixture model.
+
+        Args:
+            task (:class:`~.data.task.Task`):
+                The task containing the context and target data.
+
+        Returns:
+            :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
+                k values.
+        """
+        dist = self(task)
+        return self.k(dist)
+
+    @dispatch
+    def scale(
+        self, dist: AbstractMultiOutputDistribution
+    ) -> Union[np.ndarray, List[np.ndarray]]:
+        if self.config["likelihood"] not in ["bernoulli-gamma"]:
+            raise NotImplementedError(
+                f"ConvNP.scale method not supported for likelihood {self.config['likelihood']}. "
+                f"Valid likelihoods: 'bernoulli-gamma'."
+            )
+        scale = dist.slab.scale
+        scale = self._cast_numpy_and_squeeze(scale)
+        return self._maybe_concat_multi_targets(scale)
+
+    @dispatch
+    def scale(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
+        """
+        Scale parameter values of model's distribution at target locations in task.
+
+        Returned numpy arrays have shape ``(N_features, *N_targets)``.
+
+        .. note::
+            This method only works for models that return a distribution with
+            a ``dist.slab.scale`` attribute, e.g. models with a Bernoulli-Gamma
+            likelihood, where it returns the scale values of the slab component
+            of the mixture model.
+
+        Args:
+            task (:class:`~.data.task.Task`):
+                The task containing the context and target data.
+
+        Returns:
+            :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
+                Scale values.
+        """
+        dist = self(task)
+        return self.scale(dist)
+
     @dispatch
     def mixture_probs(self, dist: AbstractMultiOutputDistribution):
         if self.N_mixture_components == 1:
diff --git a/tests/test_model.py b/tests/test_model.py
index e9e9f75d..5193c480 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -193,7 +193,7 @@ def test_prediction_shapes_lowlevel(self, n_target_sets):
                     n_targets * dim_y_combined * n_target_dims,
                 ),
             )
-        if likelihood in ["cnp-spikes-beta"]:
+        if likelihood in ["cnp-spikes-beta", "bernoulli-gamma"]:
             mixture_probs = model.mixture_probs(task)
             if isinstance(mixture_probs, (list, tuple)):
                 for p, dim_y in zip(mixture_probs, tl.target_dims):
@@ -215,6 +215,7 @@ def test_prediction_shapes_lowlevel(self, n_target_sets):
                     ),
                 )
 
+        if likelihood in ["cnp-spikes-beta"]:
             x = model.alpha(task)
             if isinstance(x, (list, tuple)):
                 for p, dim_y in zip(x, tl.target_dims):
                     assert_shape(p, (dim_y, *expected_obs_shape))
             else:
                 assert_shape(x, (dim_y_combined, *expected_obs_shape))
@@ -229,6 +230,21 @@ def test_prediction_shapes_lowlevel(self, n_target_sets):
             else:
                 assert_shape(x, (dim_y_combined, *expected_obs_shape))
 
+        if likelihood in ["bernoulli-gamma"]:
+            x = model.k(task)
+            if isinstance(x, (list, tuple)):
+                for p, dim_y in zip(x, tl.target_dims):
+                    assert_shape(p, (dim_y, *expected_obs_shape))
+            else:
+                assert_shape(x, (dim_y_combined, *expected_obs_shape))
+
+            x = model.scale(task)
+            if isinstance(x, (list, tuple)):
+                for p, dim_y in zip(x, tl.target_dims):
+                    assert_shape(p, (dim_y, *expected_obs_shape))
+            else:
+                assert_shape(x, (dim_y_combined, *expected_obs_shape))
+
         # Scalars
         if likelihood in ["cnp", "gnp"]:
             # Methods for Gaussian likelihoods only
@@ -451,61 +467,75 @@ def test_highlevel_predict_coords_align_with_X_t_offgrid(self):
     def test_highlevel_predict_with_pred_params_pandas(self):
         """
         Test that passing ``pred_params`` to ``.predict`` works with
-        a spikes-beta likelihood for prediction to pandas.
+        mixture model likelihoods for off-grid prediction to pandas.
""" tl = TaskLoader(context=self.da, target=self.da) - model = ConvNP( - self.dp, - tl, - unet_channels=(5, 5, 5), - verbose=False, - likelihood="cnp-spikes-beta", - ) - task = tl("2020-01-01", context_sampling=10, target_sampling=10) - # Off-grid prediction - X_t = np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]]) + likelihoods = ["cnp-spikes-beta", "bernoulli-gamma"] + expected_pred_params = [ + ["mean", "std", "variance", "alpha", "beta"], + ["mean", "std", "variance", "k", "scale"], + ] - # Check that nothing breaks and the correct parameters are returned - pred_params = ["mean", "std", "variance", "alpha", "beta"] - pred = model.predict(task, X_t=X_t, pred_params=pred_params) - for pred_param in pred_params: - assert pred_param in pred["var"] + for likelihood, pred_params in zip(likelihoods, expected_pred_params): + model = ConvNP( + self.dp, + tl, + unet_channels=(5, 5, 5), + verbose=False, + likelihood=likelihood, + ) + task = tl("2020-01-01", context_sampling=10) + + # Off-grid prediction + X_t = np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]]) - # Test mixture probs special case - pred_params = ["mixture_probs"] - pred = model.predict(task, X_t=self.da, pred_params=pred_params) - for component in range(model.N_mixture_components): - pred_param = f"mixture_probs_{component}" - assert pred_param in pred["var"] + # Check that nothing breaks and the correct parameters are returned + pred = model.predict(task, X_t=X_t, pred_params=pred_params) + for pred_param in pred_params: + assert pred_param in pred["var"] + + # Test mixture probs special case + pred_params = ["mixture_probs"] + pred = model.predict(task, X_t=self.da, pred_params=pred_params) + for component in range(model.N_mixture_components): + pred_param = f"mixture_probs_{component}" + assert pred_param in pred["var"] def test_highlevel_predict_with_pred_params_xarray(self): """ Test that passing ``pred_params`` to ``.predict`` works with - a spikes-beta likelihood for prediction to xarray. + mixture model likelihoods for gridded prediction to xarray. 
""" tl = TaskLoader(context=self.da, target=self.da) - model = ConvNP( - self.dp, - tl, - unet_channels=(5, 5, 5), - verbose=False, - likelihood="cnp-spikes-beta", - ) - task = tl("2020-01-01", context_sampling=10, target_sampling=10) - # Check that nothing breaks and the correct parameters are returned - pred_params = ["mean", "std", "variance", "alpha", "beta"] - pred = model.predict(task, X_t=self.da, pred_params=pred_params) - for pred_param in pred_params: - assert pred_param in pred["var"] - - # Test mixture probs special case - pred_params = ["mixture_probs"] - pred = model.predict(task, X_t=self.da, pred_params=pred_params) - for component in range(model.N_mixture_components): - pred_param = f"mixture_probs_{component}" - assert pred_param in pred["var"] + likelihoods = ["cnp-spikes-beta", "bernoulli-gamma"] + expected_pred_params = [ + ["mean", "std", "variance", "alpha", "beta"], + ["mean", "std", "variance", "k", "scale"], + ] + + for likelihood, pred_params in zip(likelihoods, expected_pred_params): + model = ConvNP( + self.dp, + tl, + unet_channels=(5, 5, 5), + verbose=False, + likelihood=likelihood, + ) + task = tl("2020-01-01", context_sampling=10) + + # Check that nothing breaks and the correct parameters are returned + pred = model.predict(task, X_t=self.da, pred_params=pred_params) + for pred_param in pred_params: + assert pred_param in pred["var"] + + # Test mixture probs special case + pred_params = ["mixture_probs"] + pred = model.predict(task, X_t=self.da, pred_params=pred_params) + for component in range(model.N_mixture_components): + pred_param = f"mixture_probs_{component}" + assert pred_param in pred["var"] def test_highlevel_predict_with_invalid_pred_params(self): """Test that passing ``pred_params`` to ``.predict`` works."""