refactor: batch formation (#283)
Co-authored-by: anna-grim <[email protected]>
anna-grim and anna-grim authored Nov 27, 2024
1 parent 91795f0 commit e6dc7e0
Showing 6 changed files with 184 additions and 58 deletions.
33 changes: 33 additions & 0 deletions src/deep_neurographs/fragments_graph.py
@@ -492,6 +492,39 @@ def list_proposals(self):
        """
        return list(self.proposals)

+    def proposal_connected_component(self, proposal):
+        """
+        Extracts the connected component that "proposal" belongs to in the
+        proposal induced subgraph.
+        Parameters
+        ----------
+        proposal : frozenset
+            Proposal used as the root to extract its connected component
+            in the proposal induced subgraph.
+        Returns
+        -------
+        Set[frozenset]
+            Set of proposals in the connected component that "proposal"
+            belongs to in the proposal induced subgraph.
+        """
+        queue = [proposal]
+        visited = set()
+        while len(queue) > 0:
+            # Visit proposal
+            p = queue.pop()
+            visited.add(p)
+
+            # Update queue
+            for i in p:
+                for j in self.nodes[i]["proposals"]:
+                    p_ij = frozenset({i, j})
+                    if p_ij not in visited:
+                        queue.append(p_ij)
+        return visited
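
The new method collects every proposal reachable from the seed proposal through shared nodes. As a standalone illustration (outside the diff), here is a minimal sketch of the same idea in which the node-to-proposal-partner mapping is a plain dict; this stand-in structure and the toy values are assumptions for illustration, not the FragmentsGraph API.

    # Standalone sketch: "node_proposals" maps each node id to the node ids it
    # shares a proposal with (a stand-in for self.nodes[i]["proposals"]).
    def proposal_component(node_proposals, proposal):
        queue = [proposal]
        visited = set()
        while queue:
            p = queue.pop()
            visited.add(p)
            for i in p:
                for j in node_proposals[i]:
                    p_ij = frozenset({i, j})
                    if p_ij not in visited:
                        queue.append(p_ij)
        return visited

    # Toy example: proposals {1,2} and {2,3} share node 2, so they form one
    # component; {7,8} is a separate component.
    node_proposals = {1: [2], 2: [1, 3], 3: [2], 7: [8], 8: [7]}
    print(proposal_component(node_proposals, frozenset({1, 2})))
    # -> {frozenset({1, 2}), frozenset({2, 3})}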

    # -- KDTree --
    def init_kdtree(self, node_type):
        """
82 changes: 52 additions & 30 deletions src/deep_neurographs/inference.py
@@ -132,6 +132,7 @@ def __init__(
            self.model_path,
            self.ml_config.model_type,
            self.graph_config.search_radius,
+            batch_size=self.ml_config.batch_size,
            confidence_threshold=self.ml_config.threshold,
            device=device,
            downsample_factor=self.ml_config.downsample_factor,
@@ -535,62 +536,66 @@ def __init__(
        if self.is_gnn and "cuda" in device:
            self.model = self.model.to(self.device)

-    def run(self, neurograph, proposals):
+    def run(self, fragments_graph, proposals):
        """
        Runs inference by forming batches of proposals, then performing the
        following steps for each batch: (1) generate features, (2) classify
        proposals by running model, and (3) adding each accepted proposal as
-        an edge to "neurograph" if it does not create a cycle.
+        an edge to "fragments_graph" if it does not create a cycle.
        Parameters
        ----------
-        neurograph : NeuroGraph
+        fragments_graph : FragmentsGraph
            Graph that inference will be performed on.
        proposals : list
            Proposals to be classified as accept or reject.
        Returns
        -------
-        NeuroGraph
+        FragmentsGraph
            Updated graph with accepted proposals added as edges.
        list
            Accepted proposals.
        """
        # Initializations
-        assert not gutil.cycle_exists(neurograph), "Graph contains cycle!"
+        assert not gutil.cycle_exists(fragments_graph), "Graph has cycle!"
        if self.is_gnn:
            proposals = set(proposals)
        else:
-            proposals = sort_proposals(neurograph, proposals)
+            proposals = sort_proposals(fragments_graph, proposals)

        # Main
+        flagged = get_large_proposal_components(fragments_graph, 4)
        with tqdm(total=len(proposals), desc="Inference") as pbar:
            accepts = list()
            while len(proposals) > 0:
                # Predict
-                batch = self.get_batch(neurograph, proposals)
-                dataset = self.get_batch_dataset(neurograph, batch)
+                batch = self.get_batch(fragments_graph, proposals, flagged)
+                dataset = self.get_batch_dataset(fragments_graph, batch)
                preds = self.predict(dataset)

                # Update graph
-                for p in get_accepts(neurograph, preds, self.threshold):
-                    neurograph.merge_proposal(p)
+                for p in get_accepts(fragments_graph, preds, self.threshold):
+                    fragments_graph.merge_proposal(p)
                    accepts.append(p)
                pbar.update(len(batch["proposals"]))
-        neurograph.absorb_reducibles()
-        return neurograph, accepts
+        fragments_graph.absorb_reducibles()
+        return fragments_graph, accepts

-    def get_batch(self, neurograph, proposals):
+    def get_batch(self, fragments_graph, proposals, flagged_proposals):
        """
        Generates a batch of proposals.
        Parameters
        ----------
-        neurograph : NeuroGraph
+        fragments_graph : FragmentsGraph
            Graph that proposals were generated from.
-        proposals : list
+        proposals : List[frozenset]
            Proposals for which batch is to be generated from.
+        flagged_proposals : List[frozenset]
+            List of proposals that are part of a "large" connected component
+            in the proposal induced subgraph of "fragments_graph".
        Returns
        -------
@@ -600,20 +605,22 @@ def get_batch(self, neurograph, proposals):
        """
        if self.is_gnn:
-            return gnn_util.get_batch(neurograph, proposals, self.batch_size)
+            return gnn_util.get_batch(
+                fragments_graph, proposals, self.batch_size, flagged_proposals
+            )
        else:
            batch = {"proposals": proposals[0:self.batch_size], "graph": None}
            del proposals[0:self.batch_size]
            return batch
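
For the non-GNN branch above, batches are simply fixed-size slices consumed from the front of the (length-sorted) proposal list. A small self-contained sketch of that behavior; the proposal values and batch size are illustrative assumptions.

    # Fixed-size batching sketch; proposal values and batch_size are illustrative.
    proposals = [frozenset({i, i + 1}) for i in range(0, 10, 2)]
    batch_size = 2
    while len(proposals) > 0:
        batch = {"proposals": proposals[0:batch_size], "graph": None}
        del proposals[0:batch_size]
        print(batch["proposals"])
    # Prints three batches of sizes 2, 2, and 1.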

-    def get_batch_dataset(self, neurograph, batch):
+    def get_batch_dataset(self, fragments_graph, batch):
        """
        Generates features and initializes dataset that can be input to a
        machine learning model.
        Parameters
        ----------
-        neurograph : NeuroGraph
+        fragments_graph : FragmentsGraph
            Graph that inference will be performed on.
        batch : list
            Proposals to be classified.
@@ -623,10 +630,12 @@ def get_batch_dataset(self, neurograph, batch):
            ...
        """
-        features = self.feature_generator.run(neurograph, batch, self.radius)
+        features = self.feature_generator.run(
+            fragments_graph, batch, self.radius
+        )
        computation_graph = batch["graph"] if type(batch) is dict else None
        dataset = ml_util.init_dataset(
-            neurograph,
+            fragments_graph,
            features,
            self.is_gnn,
            computation_graph=computation_graph,
@@ -694,14 +703,14 @@ def predict_with_gnn(model, data, device=None):
    return toCPU(preds[0:len(data["proposal"]["y"]), 0])


-def get_accepts(neurograph, preds, threshold, high_threshold=0.9):
+def get_accepts(fragments_graph, preds, threshold, high_threshold=0.9):
    """
    Determines which proposals to accept based on prediction scores and the
    specified threshold.
    Parameters
    ----------
-    neurograph : NeuroGraph
+    fragments_graph : FragmentsGraph
        Graph that proposals belong to.
    preds : dict
        Dictionary that maps proposal ids to probability generated from
@@ -713,20 +722,20 @@ def get_accepts(neurograph, preds, threshold, high_threshold=0.9):
    Returns
    -------
    list
-        Proposals to be added as edges to "neurograph".
+        Proposals to be added as edges to "fragments_graph".
    """
    # Partition proposals into best and the rest
    preds = {k: v for k, v in preds.items() if v > threshold}
    best_proposals, proposals = separate_best(
-        preds, neurograph.simple_proposals(), high_threshold
+        preds, fragments_graph.simple_proposals(), high_threshold
    )

    # Determine which proposals to accept
    accepts = list()
-    accepts.extend(filter_proposals(neurograph, best_proposals))
-    accepts.extend(filter_proposals(neurograph, proposals))
-    neurograph.remove_edges_from(map(tuple, accepts))
+    accepts.extend(filter_proposals(fragments_graph, best_proposals))
+    accepts.extend(filter_proposals(fragments_graph, proposals))
+    fragments_graph.remove_edges_from(map(tuple, accepts))
    return accepts
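
get_accepts first drops predictions at or below the confidence threshold, then splits the survivors into a high-confidence group (handled first) and the rest. A toy, self-contained sketch of that partition; the scores and thresholds are illustrative, and the restriction of the high-confidence group to simple proposals is omitted for brevity.

    # Two-stage thresholding sketch (simplified: every proposal is treated as
    # "simple"); scores and thresholds are illustrative.
    preds = {frozenset({1, 2}): 0.95, frozenset({3, 4}): 0.70, frozenset({5, 6}): 0.40}
    threshold, high_threshold = 0.6, 0.9
    preds = {k: v for k, v in preds.items() if v > threshold}
    best_proposals = [p for p, v in preds.items() if v > high_threshold]
    proposals = [p for p in preds if p not in best_proposals]
    print(best_proposals)  # [frozenset({1, 2})]
    print(proposals)       # [frozenset({3, 4})]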


@@ -795,13 +804,13 @@ def filter_proposals(graph, proposals):
    return accepts


-def sort_proposals(neurograph, proposals):
+def sort_proposals(fragments_graph, proposals):
    """
    Sorts proposals by length.
    Parameters
    ----------
-    neurograph : NeuroGraph
+    fragments_graph : FragmentsGraph
        Graph that proposals were generated from.
    proposals : list[frozenset]
        List of proposals.
@@ -812,5 +821,18 @@ def sort_proposals(neurograph, proposals):
        Sorted proposals.
    """
-    idxs = np.argsort([neurograph.proposal_length(p) for p in proposals])
+    idxs = np.argsort([fragments_graph.proposal_length(p) for p in proposals])
    return [proposals[idx] for idx in idxs]


+# --- Batch Formation ---
+def get_large_proposal_components(fragments_graph, k):
+    flagged_proposals = set()
+    visited = set()
+    for p in fragments_graph.list_proposals():
+        if p not in visited:
+            component = fragments_graph.proposal_connected_component(p)
+            if len(component) > k:
+                flagged_proposals = flagged_proposals.union(component)
+            visited = visited.union(component)
+    return flagged_proposals
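
This new helper implements the flagging used by the refactored batch formation: any proposal whose connected component in the proposal-induced subgraph contains more than k proposals is flagged, and the flagged set is passed through get_batch to gnn_util.get_batch (how it is used inside gnn_util.get_batch is not shown in this diff). A short sketch of the flagging step, reusing the stand-in proposal_component and node_proposals from the sketch after proposal_connected_component above; the threshold and toy data are illustrative assumptions.

    # Flagging sketch: components with more than k proposals are flagged so the
    # batching code can handle them separately.
    def large_components(node_proposals, proposals, k):
        flagged, visited = set(), set()
        for p in proposals:
            if p not in visited:
                component = proposal_component(node_proposals, p)
                if len(component) > k:
                    flagged |= component
                visited |= component
        return flagged

    proposals = [frozenset({1, 2}), frozenset({2, 3}), frozenset({7, 8})]
    print(large_components(node_proposals, proposals, k=1))
    # -> {frozenset({1, 2}), frozenset({2, 3})}  (the two-proposal component)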