diff --git a/hdbscan/hdbscan_.py b/hdbscan/hdbscan_.py index 6ce07964..181df732 100644 --- a/hdbscan/hdbscan_.py +++ b/hdbscan/hdbscan_.py @@ -1299,6 +1299,12 @@ def generate_branch_detection_data(self): branches within clusters. This data is only useful if you are intending to use functions from ``hdbscan.branches``. """ + if self._min_spanning_tree is None: + raise ValueError("Branch prediction requires a minimum spanning tree; please re-run " + "with `branch_repdiction_data=True` or at least `gen_min_spanning_tree=True` " + "and this this function to generate the required information for branch " + "branch detection." + ) if self.metric in FAST_METRICS: min_samples = self.min_samples or self.min_cluster_size if self.metric in KDTREE_VALID_METRICS: diff --git a/hdbscan/tests/test_branches.py b/hdbscan/tests/test_branches.py index 5a6d9a36..c9bd281c 100644 --- a/hdbscan/tests/test_branches.py +++ b/hdbscan/tests/test_branches.py @@ -184,9 +184,7 @@ def test_branch_detection_data_with_unsupported_input(): def test_generate_branch_detection_data(): """Generate branch detection data function does not re-generate MST.""" c = HDBSCAN(min_cluster_size=5).fit(X) - c.generate_branch_detection_data() - assert c.branch_detection_data_ is not None - assert_raises(AttributeError, lambda: c.minimum_spanning_tree_) + assert_raises(ValueError, c.generate_branch_detection_data) # --- Detecting Branches @@ -287,15 +285,12 @@ def test_badargs(): c = HDBSCAN(min_cluster_size=5, branch_detection_data=True).fit(X) c_nofit = HDBSCAN(min_cluster_size=5, branch_detection_data=True) c_nobranch = HDBSCAN(min_cluster_size=5, gen_min_span_tree=True).fit(X) - c_nomst = HDBSCAN(min_cluster_size=5).fit(X) - c_nomst.generate_branch_detection_data() assert_raises(AttributeError, detect_branches_in_clusters, "fail") assert_raises(AttributeError, detect_branches_in_clusters, None) assert_raises(AttributeError, detect_branches_in_clusters, "fail") assert_raises(ValueError, detect_branches_in_clusters, c_nofit) assert_raises(AttributeError, detect_branches_in_clusters, c_nobranch) - assert_raises(ValueError, detect_branches_in_clusters, c_nomst) assert_raises(ValueError, detect_branches_in_clusters, c, min_branch_size=-1) assert_raises(ValueError, detect_branches_in_clusters, c, min_branch_size=0) assert_raises(ValueError, detect_branches_in_clusters, c, min_branch_size=1)