Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix 10593 -- add --keep option to experiment remove #10632

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions dvc/commands/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ def add_parser(subparsers, parent_parser):
hide_subparsers_from_help(experiments_subparsers)


def add_keep_selection_flag(experiments_subcmd_parser):
    """Attach the boolean ``--keep`` switch to an experiments subcommand parser.

    The value is stored as ``keep`` on the parsed namespace and stays
    ``False`` unless ``--keep`` is passed on the command line.
    """
    keep_help = (
        "Keep the selected experiments instead of removing them (use it with `--rev` and `--num` or with experiment names)."
    )
    keep_options = {"action": "store_true", "default": False, "help": keep_help}
    experiments_subcmd_parser.add_argument("--keep", **keep_options)


def add_rev_selection_flags(
experiments_subcmd_parser, command: str, default: bool = True
):
Expand Down
4 changes: 3 additions & 1 deletion dvc/commands/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def run(self):
num=self.args.num,
queue=self.args.queue,
git_remote=self.args.git_remote,
keep_selected=self.args.keep,
)
if removed:
ui.write(f"Removed experiments: {humanize.join(map(repr, removed))}")
Expand All @@ -44,7 +45,7 @@ def run(self):


def add_parser(experiments_subparsers, parent_parser):
from . import add_rev_selection_flags
from . import add_keep_selection_flag, add_rev_selection_flags

EXPERIMENTS_REMOVE_HELP = "Remove experiments."
experiments_remove_parser = experiments_subparsers.add_parser(
Expand All @@ -57,6 +58,7 @@ def add_parser(experiments_subparsers, parent_parser):
)
remove_group = experiments_remove_parser.add_mutually_exclusive_group()
add_rev_selection_flags(experiments_remove_parser, "Remove", False)
add_keep_selection_flag(experiments_remove_parser)
remove_group.add_argument(
"--queue", action="store_true", help="Remove all queued experiments."
)
Expand Down
32 changes: 32 additions & 0 deletions dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def remove( # noqa: C901, PLR0912
num: int = 1,
queue: bool = False,
git_remote: Optional[str] = None,
keep_selected: bool = False, # keep the experiments instead of removing them
) -> list[str]:
removed: list[str] = []
if not any([exp_names, queue, all_commits, rev]):
Expand All @@ -43,6 +44,36 @@ def remove( # noqa: C901, PLR0912

exp_ref_list: list[ExpRefInfo] = []
queue_entry_list: list[QueueEntry] = []

if keep_selected:
# In keep_selected mode, identify all experiments and remove the unselected ones
all_exp_refs = exp_refs(repo.scm, git_remote)

if exp_names:
selected_exp_names = (
set(exp_names) if isinstance(exp_names, list) else {exp_names}
)
elif rev:
selected_exp_names = set(
_resolve_exp_by_baseline(
repo, [rev] if isinstance(rev, str) else rev, num, git_remote
).keys()
)
else:
selected_exp_names = set()

# Identify experiments to remove: all experiments - selected experiments
unselected_exp_refs = [
ref for ref in all_exp_refs if ref.name not in selected_exp_names
]
removed = [ref.name for ref in unselected_exp_refs]

# Remove the unselected experiments
if unselected_exp_refs:
_remove_commited_exps(repo.scm, unselected_exp_refs, git_remote)

return removed

if exp_names:
results: dict[str, ExpRefAndQueueEntry] = (
celery_queue.get_ref_and_entry_by_names(exp_names, git_remote)
Expand Down Expand Up @@ -83,6 +114,7 @@ def remove( # noqa: C901, PLR0912

removed_refs = [str(r) for r in exp_ref_list]
notify_refs_to_studio(repo, git_remote, removed=removed_refs)

return removed


Expand Down
165 changes: 165 additions & 0 deletions tests/func/experiments/test_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,168 @@ def test_remove_multi_rev(tmp_dir, scm, dvc, exp_stage):

assert scm.get_ref(str(baseline_exp_ref)) is None
assert scm.get_ref(str(new_exp_ref)) is None


def test_keep_selected_by_name(tmp_dir, scm, dvc, exp_stage):
    """Keeping one experiment by name removes every other experiment."""
    # Setup: run three named experiments and remember their git refs.
    refs = {}
    for value, name in enumerate(("exp1", "exp2", "exp3"), start=1):
        results = dvc.experiments.run(
            exp_stage.addressing, params=[f"foo={value}"], name=name
        )
        refs[name] = first(exp_refs_by_rev(scm, first(results)))

    # Sanity check: all experiment refs exist before the removal.
    for ref in refs.values():
        assert scm.get_ref(str(ref)) is not None

    # Keep "exp2" and remove the others.
    removed = dvc.experiments.remove(exp_names=["exp2"], keep_selected=True)
    # Compare sorted so the test does not depend on the order in which the
    # experiment refs happen to be listed.
    assert sorted(removed) == ["exp1", "exp3"]

    # Only the kept experiment must still have a ref.
    assert scm.get_ref(str(refs["exp1"])) is None
    assert scm.get_ref(str(refs["exp2"])) is not None
    assert scm.get_ref(str(refs["exp3"])) is None


def test_keep_selected_multiple_by_name(tmp_dir, scm, dvc, exp_stage):
    """Keeping two experiments by name removes only the remaining one."""
    # Setup: run three named experiments, collecting their refs in run order.
    ref_by_name = {}
    for foo_value, exp_name in ((1, "exp1"), (2, "exp2"), (3, "exp3")):
        run_results = dvc.experiments.run(
            exp_stage.addressing, params=[f"foo={foo_value}"], name=exp_name
        )
        ref_by_name[exp_name] = first(exp_refs_by_rev(scm, first(run_results)))

    # Precondition: every experiment ref is present.
    for exp_name in ("exp1", "exp2", "exp3"):
        assert scm.get_ref(str(ref_by_name[exp_name])) is not None

    # Keep "exp1" and "exp2"; "exp3" is the only one left to remove.
    removed = dvc.experiments.remove(exp_names=["exp1", "exp2"], keep_selected=True)
    assert removed == ["exp3"]

    # The kept experiments survive, the unselected one is gone.
    assert scm.get_ref(str(ref_by_name["exp1"])) is not None
    assert scm.get_ref(str(ref_by_name["exp2"])) is not None
    assert scm.get_ref(str(ref_by_name["exp3"])) is None


def test_keep_selected_all_by_name(tmp_dir, scm, dvc, exp_stage):
    """Keeping every existing experiment by name removes nothing."""
    exp_names = ["exp1", "exp2", "exp3"]

    # Setup: run one experiment per name (foo=1, foo=2, foo=3).
    created_refs = []
    for position, exp_name in enumerate(exp_names):
        run_results = dvc.experiments.run(
            exp_stage.addressing, params=[f"foo={position + 1}"], name=exp_name
        )
        created_refs.append(first(exp_refs_by_rev(scm, first(run_results))))

    # Precondition: all three refs exist.
    for exp_ref in created_refs:
        assert scm.get_ref(str(exp_ref)) is not None

    # Keep all three experiments: there is nothing left to remove.
    removed = dvc.experiments.remove(
        exp_names=["exp1", "exp2", "exp3"], keep_selected=True
    )
    assert removed == []

    # Every experiment must still be present afterwards.
    for exp_ref in created_refs:
        assert scm.get_ref(str(exp_ref)) is not None


def test_keep_selected_by_nonexistent_name(tmp_dir, scm, dvc, exp_stage):
    """Keeping a name that matches no experiment removes all experiments.

    NOTE(review): this pins the behaviour that a misspelled name passed with
    ``keep_selected=True`` silently deletes every experiment -- confirm this
    is intended rather than an error condition.
    """
    # Setup: run three named experiments and remember their git refs.
    refs = {}
    for value, name in enumerate(("exp1", "exp2", "exp3"), start=1):
        results = dvc.experiments.run(
            exp_stage.addressing, params=[f"foo={value}"], name=name
        )
        refs[name] = first(exp_refs_by_rev(scm, first(results)))

    # Sanity check: all experiment refs exist before the removal.
    for ref in refs.values():
        assert scm.get_ref(str(ref)) is not None

    # The only name to keep does not exist, so nothing is protected.
    removed = dvc.experiments.remove(exp_names=["nonexistent"], keep_selected=True)
    # Compare sorted so the test does not depend on the order in which the
    # experiment refs happen to be listed.
    assert sorted(removed) == ["exp1", "exp2", "exp3"]

    # No experiment ref may remain.
    for ref in refs.values():
        assert scm.get_ref(str(ref)) is None


def test_keep_selected_by_rev(tmp_dir, scm, dvc, exp_stage):
    """Keeping by ``rev`` with ``num=1`` removes experiments of older commits."""
    # Setup: "exp1" is run on the initial commit, then a new commit is made.
    results = dvc.experiments.run(exp_stage.addressing, params=["foo=1"], name="exp1")
    exp1_ref = first(exp_refs_by_rev(scm, first(results)))
    scm.commit("commit1")

    # "exp2" is run on top of "commit1", which stays the current revision.
    new_results = dvc.experiments.run(
        exp_stage.addressing, params=["foo=2"], name="exp2"
    )
    exp2_ref = first(exp_refs_by_rev(scm, first(new_results)))
    new_rev = scm.get_rev()

    # Ensure experiments exist
    assert scm.get_ref(str(exp1_ref)) is not None
    assert scm.get_ref(str(exp2_ref)) is not None

    # Keep the experiment from the new revision
    removed = dvc.experiments.remove(rev=new_rev, num=1, keep_selected=True)
    assert removed == ["exp1"]

    # Check remaining experiments
    assert scm.get_ref(str(exp2_ref)) is not None
    assert scm.get_ref(str(exp1_ref)) is None


def test_keep_selected_by_rev_multiple(tmp_dir, scm, dvc, exp_stage):
    """Keeping by ``rev`` with ``num=2`` keeps the two newest baselines' runs."""
    # Setup: run three experiments, committing after each one so that every
    # experiment is based on a different commit.
    refs = {}
    last_baseline_rev = None
    for value, name in enumerate(("exp1", "exp2", "exp3"), start=1):
        # Revision the upcoming experiment is based on; after the loop this
        # holds the revision captured right before "exp3" was run.
        last_baseline_rev = scm.get_rev()
        results = dvc.experiments.run(
            exp_stage.addressing, params=[f"foo={value}"], name=name
        )
        refs[name] = first(exp_refs_by_rev(scm, first(results)))
        scm.commit(f"commit{value}")

    # Ensure experiments exist
    for ref in refs.values():
        assert scm.get_ref(str(ref)) is not None

    # Keep the experiments selected by the last two revisions, starting from
    # the revision "exp3" was run on; "exp1" is the only one removed.
    removed = dvc.experiments.remove(
        rev=last_baseline_rev, num=2, keep_selected=True
    )
    assert removed == ["exp1"]

    # Check remaining experiments
    assert scm.get_ref(str(refs["exp3"])) is not None
    assert scm.get_ref(str(refs["exp2"])) is not None
    assert scm.get_ref(str(refs["exp1"])) is None
Loading