diff --git a/ghstack/submit.py b/ghstack/submit.py index 1e801af..e3cddeb 100644 --- a/ghstack/submit.py +++ b/ghstack/submit.py @@ -378,7 +378,7 @@ def run(self) -> List[DiffMeta]: ] self.push_updates(diffs_to_submit) if new_head := rebase_index.get( - GitCommitHash(self.sh.git("rev-parse", "HEAD")) + old_head := GitCommitHash(self.sh.git("rev-parse", "HEAD")) ): self.sh.git("reset", "--soft", new_head) # TODO: print out commit hashes for things we rebased but not accessible @@ -390,11 +390,31 @@ def run(self) -> List[DiffMeta]: # TODO: Do a separate check for this if h.commit_id not in diff_meta_index: continue + new_orig = diff_meta_index[h.commit_id].orig self.check_invariants_for_diff( h.commit_id, - diff_meta_index[h.commit_id].orig, + new_orig, pre_branch_state_index.get(h.commit_id), ) + # Test that orig commits are accessible from HEAD, if the old + # commits were accessible. And if the commit was not + # accessible, it better not be accessible now! + if self.sh.git( + "merge-base", "--is-ancestor", h.commit_id, old_head, exitcode=True + ): + assert new_head is not None + assert self.sh.git( + "merge-base", "--is-ancestor", new_orig, new_head, exitcode=True + ) + else: + if new_head is not None: + assert not self.sh.git( + "merge-base", + "--is-ancestor", + new_orig, + new_head, + exitcode=True, + ) # NB: earliest first, which is the intuitive order for unit testing return list(reversed(diffs_to_submit)) @@ -1296,11 +1316,17 @@ def assert_eq(a: Any, b: Any) -> None: assert_eq(base_commit.tree, user_parent_commit.tree) # 6. Orig commit was correctly pushed - assert_eq(orig_commit.commit_id, GitCommitHash( - self.sh.git( - "rev-parse", self.remote_name + "/" + branch_orig(self.username, elaborated_orig_diff.ghnum) - ) - )) + assert_eq( + orig_commit.commit_id, + GitCommitHash( + self.sh.git( + "rev-parse", + self.remote_name + + "/" + + branch_orig(self.username, elaborated_orig_diff.ghnum), + ) + ), + ) # 7. Branches are either unchanged, or parent (no force pushes) # NB: head is always merged in as first parent diff --git a/test_ghstack.py b/test_ghstack.py index 09da546..5cee2d3 100644 --- a/test_ghstack.py +++ b/test_ghstack.py @@ -153,15 +153,7 @@ def gh( revs: Sequence[str] = (), stack: bool = True, ) -> List[ghstack.submit.DiffMeta]: - """ - ghstack.submit.parse_revs( - revs, - base_ref="origin/master", - sh=self.sh, - ) - """ - - kwargs = dict( + r = ghstack.submit.main( msg=msg, username="ezyang", github=self.github, @@ -179,8 +171,8 @@ def gh( stack=stack, check_invariants=True, ) - - return ghstack.submit.main(**kwargs) + self.check_global_github_invariants() + return r def gh_land(self, pull_request: str) -> None: return ghstack.land.main( @@ -201,6 +193,32 @@ def gh_unlink(self) -> None: remote_name="origin", ) + def check_global_github_invariants(self) -> None: + r = self.github.graphql( + """ + query { + repository(name: "pytorch", owner: "pytorch") { + pullRequests { + nodes { + baseRefName + headRefName + closed + } + } + } + } + """ + ) + # No refs may be reused for multiple open PRs + seen_refs = set() + for pr in r["data"]["repository"]["pullRequests"]["nodes"]: + if pr["closed"]: + continue + assert pr["baseRefName"] not in seen_refs + seen_refs.add(pr["baseRefName"]) + assert pr["headRefName"] not in seen_refs + seen_refs.add(pr["headRefName"]) + def dump_github(self) -> str: r = self.github.graphql( """