Skip to content

Commit

Permalink
Update inference.py
Browse files Browse the repository at this point in the history
updates saved swcs
  • Loading branch information
anna-grim authored Nov 19, 2024
1 parent 652b44a commit f9867d6
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/deep_neurographs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def build_graph(self, fragments_pointer):
# Save valid labels and current graph
swcs_path = os.path.join(self.output_dir, "processed-swcs.zip")
labels_path = os.path.join(self.output_dir, "valid_labels.txt")
n_saved = self.graph.to_zipped_swcs(swcs_path, min_size=100)
n_saved = self.graph.to_zipped_swcs(swcs_path)
self.graph.save_labels(labels_path)
self.report(f"# SWCs Saved: {n_saved}")

Expand Down Expand Up @@ -342,11 +342,14 @@ def save_results(self, round_id=None):
suffix = f"-{round_id}" if round_id else ""
filename = f"corrected-processed-swcs{suffix}.zip"
path = os.path.join(self.output_dir, filename)
self.graph.to_zipped_swcs(path, min_size=200)
self.graph.to_zipped_swcs(path)
self.save_connections(round_id=round_id)
self.write_metadata()

# Save result on s3
filename = f"corrected-processed-swcs-s3.zip"
path = os.path.join(self.output_dir, filename)
self.graph.to_zipped_swcs(path)
if self.save_to_s3_bool:
self.save_to_s3()

Expand All @@ -365,7 +368,7 @@ def save_to_s3(self):
"""
bucket_name = self.s3_dict["bucket_name"]
for filename in os.listdir(self.output_dir):
if filename != "processed-swcs.zip":
if "processed-swcs.zip" not in filename:
local_path = os.path.join(self.output_dir, filename)
s3_path = os.path.join(self.s3_dict["prefix"], filename)
util.write_to_s3(local_path, bucket_name, s3_path)
Expand Down

0 comments on commit f9867d6

Please sign in to comment.