Skip to content

Commit

Permalink
Refactor gnn training (#252)
Browse files Browse the repository at this point in the history
* minor upds

* refactor: training pipeline

* feat: find gcs image path

* feat: feature generation in trainer

* feat: validation sets in training

* bug: hgraph forward passes with missing edge types

* refactor: hgnn trainer

* feat: functional training pipeline

* bug: set validation data

* refactor: combined train engine and pipeline

* refactor: inference pipeline, evaluation

---------

Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Sep 27, 2024
1 parent 2c7bbad commit 3c51a61
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 77 deletions.
2 changes: 1 addition & 1 deletion src/deep_neurographs/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def get_profile(img, xyz_arr, process_id=None, window=[5, 5, 5]):
"""
profile = []
for xyz in xyz_arr:
if type(img) == ts.TensorStore:
if type(img) is ts.TensorStore:
profile.append(np.max(util.read_tensorstore(img, xyz, window)))
else:
profile.append(np.max(util.get_chunk(img, xyz, window)))
Expand Down
65 changes: 0 additions & 65 deletions src/deep_neurographs/machine_learning/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,74 +82,9 @@ def run_evaluation(neurograph, accepts, proposals):
stats["Overall"][metric].append(overall_stats[metric])
stats["Simple"][metric].append(simple_stats[metric])
stats["Complex"][metric].append(complex_stats[metric])

return stats


def run_evaluation_blocks(neurographs, blocks, accepts):
    """
    Runs an evaluation on the accuracy of the predictions generated by an edge
    classification model for a given list of blocks.

    Parameters
    ----------
    neurographs : dict[str, NeuroGraph]
        Predicted neurographs, indexed by block_id.
    blocks : list[str]
        List of block_ids that indicate which predictions to evaluate.
    accepts : dict[str, list]
        Accepted proposals, indexed by block_id.

    Returns
    -------
    tuple[dict, dict]
        A pair "(stats, avg_wgts)". "stats" holds the accuracy of the edge
        classification model on all proposals ("Overall"), simple proposals
        ("Simple"), and complex proposals ("Complex"); each entry is a
        sub-dictionary whose keys are the metrics in "METRICS_LIST" and whose
        values are lists with one entry per block. "avg_wgts" maps the same
        three categories to a list with the number of proposals in each
        block, in the same order as "blocks".

    """
    categories = ("Overall", "Simple", "Complex")
    avg_wgts = {category: [] for category in categories}
    stats = {category: init_stats() for category in categories}
    for block_id in blocks:
        neurograph = neurographs[block_id]
        block_accepts = accepts[block_id]

        # Compute each proposal subset once; it is needed for both the
        # accuracy computation and the per-block weight.
        proposals = {
            "Overall": neurograph.proposals,
            "Simple": neurograph.simple_proposals(),
            "Complex": neurograph.complex_proposals(),
        }

        # Compute accuracy
        block_stats = {
            category: get_stats(
                neurograph, proposals[category], block_accepts
            )
            for category in categories
        }

        # Store results
        for category in categories:
            avg_wgts[category].append(len(proposals[category]))
            for metric in METRICS_LIST:
                stats[category][metric].append(block_stats[category][metric])

    return stats, avg_wgts


def get_stats(neurograph, proposals, accepts):
"""
Accuracy of the predictions generated by an edge classication model on a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def generate_chunks(neurograph, proposals, img, labels):
def get_chunk(img, labels, voxel_1, voxel_2, thread_id=None):
# Extract chunks
midpoint = geometry.get_midpoint(voxel_1, voxel_2).astype(int)
if type(img) == ts.TensorStore:
if type(img) is ts.TensorStore:
chunk = util.read_tensorstore(img, midpoint, CHUNK_SIZE)
labels_chunk = util.read_tensorstore(labels, midpoint, CHUNK_SIZE)
else:
Expand Down
17 changes: 10 additions & 7 deletions src/deep_neurographs/machine_learning/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
search_radius,
batch_size=BATCH_SIZE,
confidence_threshold=CONFIDENCE_THRESHOLD,
downsample_factor=0,
downsample_factor=1,
):
"""
Initializes an inference engine by loading images and setting class
Expand Down Expand Up @@ -122,8 +122,10 @@ def run(self, neurograph, proposals):
preds = self.run_model(dataset)

# Update graph
batch_accepts = get_accepted_proposals(neurograph, preds)
for proposal in map(frozenset, batch_accepts):
batch_accepts = get_accepted_proposals(
neurograph, preds, self.threshold
)
for proposal in batch_accepts:
neurograph.merge_proposal(proposal)

# Finish
Expand Down Expand Up @@ -222,7 +224,7 @@ def run_model(self, dataset):

# Filter preds
idxs = dataset.idxs_proposals["idx_to_edge"]
return {idxs[i]: p for i, p in enumerate(preds) if p > self.threshold}
return {idxs[i]: p for i, p in enumerate(preds)}


# --- run machine learning model ---
Expand Down Expand Up @@ -257,7 +259,7 @@ def run_gnn_model(data, model, model_type):


# --- Accepting proposals ---
def get_accepted_proposals(neurograph, preds, high_threshold=0.9):
def get_accepted_proposals(neurograph, preds, threshold, high_threshold=0.9):
"""
Determines which proposals to accept based on prediction scores and the
specified threshold.
Expand All @@ -280,6 +282,7 @@ def get_accepted_proposals(neurograph, preds, high_threshold=0.9):
"""
# Partition proposals into best and the rest
preds = {k: v for k, v in preds.items() if v > threshold}
best_proposals, proposals = separate_best(
preds, neurograph.simple_proposals(), high_threshold
)
Expand Down Expand Up @@ -359,8 +362,8 @@ def filter_proposals(graph, proposals):
created_cycle, _ = gutil.creates_cycle(subgraph, (i, j))
if not created_cycle:
graph.add_edge(i, j)
accepts.append((i, j))
graph.remove_edges_from(accepts)
accepts.append(frozenset({i, j}))
graph.remove_edges_from(map(tuple, accepts))
return accepts


Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@ def branch_contained(self, xyz_list):

def to_voxels(self, node_or_xyz, shift=False):
shift = self.origin if shift else np.zeros((3))
if type(node_or_xyz) == int:
if type(node_or_xyz) is int:
coord = img_util.to_voxels(self.nodes[node_or_xyz]["xyz"])
else:
coord = img_util.to_voxels(node_or_xyz)
Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/utils/img_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,4 +479,4 @@ def find_img_path(bucket_name, img_root, dataset_name):
for subdir in util.list_gcs_subdirectories(bucket_name, img_root):
if dataset_name in subdir:
return subdir + "whole-brain/fused.zarr/"
raise(f"Dataset not found in {bucket_name} - {img_root}")
raise f"Dataset not found in {bucket_name} - {img_root}"
2 changes: 1 addition & 1 deletion src/deep_neurographs/utils/swc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ def to_graph(swc_dict, swc_id=None, set_attrs=False):
graph.add_edges_from(zip(swc_dict["id"][1:], swc_dict["pid"][1:]))
if set_attrs:
xyz = swc_dict["xyz"]
if type(swc_dict["xyz"]) == np.ndarray:
if type(swc_dict["xyz"]) is np.ndarray:
xyz = util.numpy_to_hashable(swc_dict["xyz"])
graph = __add_attributes(swc_dict, graph)
xyz_to_node = dict(zip(xyz, swc_dict["id"]))
Expand Down

0 comments on commit 3c51a61

Please sign in to comment.