Skip to content

Commit

Permalink
bug: circular arg fixed, class rename (#279)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Nov 12, 2024
1 parent 4760906 commit 3cec969
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 47 deletions.
9 changes: 4 additions & 5 deletions src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@
from numpy import concatenate
from scipy.spatial import KDTree

from deep_neurographs import generate_proposals, geometry
from deep_neurographs import generate_proposals, geometry, utils
from deep_neurographs.groundtruth_generation import init_targets
from deep_neurographs.utils import graph_util as gutil
from deep_neurographs.utils import img_util, util


class NeuroGraph(nx.Graph):
class FragmentsGraph(nx.Graph):
"""
A class of graphs whose nodes correspond to irreducible nodes from the
predicted swc files.
Expand All @@ -48,7 +47,7 @@ def __init__(self, img_bbox=None, node_spacing=1):
None
"""
super(NeuroGraph, self).__init__()
super(FragmentsGraph, self).__init__()
# General class attributes
self.leaf_kdtree = None
self.node_cnt = 0
Expand Down Expand Up @@ -97,7 +96,7 @@ def set_proxy_soma_ids(self, k):
None
"""
for i in gutil.largest_components(self, k):
for i in utils.graph_util.largest_components(self, k):
self.soma_ids[self.nodes[i]["swc_id"]] = i

def get_leafs(self):
Expand Down
31 changes: 17 additions & 14 deletions src/deep_neurographs/utils/graph_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
@email: [email protected]
Routines for loading fragments and building a neurograph.
Routines for loading fragments and building a fragments_graph.
Terminology
Expand All @@ -31,7 +31,6 @@
from tqdm import tqdm

from deep_neurographs import geometry
from deep_neurographs.neurograph import NeuroGraph
from deep_neurographs.utils import img_util, swc_util, util

MIN_SIZE = 30
Expand Down Expand Up @@ -82,8 +81,7 @@ def __init__(
Returns
-------
FragmentsGraph
FragmentsGraph generated from swc files.
None
"""
self.anisotropy = anisotropy
Expand Down Expand Up @@ -120,6 +118,8 @@ def run(
FragmentsGraph generated from swc files.
"""
from deep_neurographs.neurograph import FragmentsGraph

# Load fragments and extract irreducibles
self.img_bbox = img_util.init_bbox(img_patch_origin, img_patch_shape)
swc_dicts = self.reader.load(fragments_pointer)
Expand All @@ -129,13 +129,13 @@ def run(
if self.progress_bar:
pbar = tqdm(total=len(irreducibles), desc="Combine Graphs")

neurograph = NeuroGraph(node_spacing=self.node_spacing)
fragments_graph = FragmentsGraph(node_spacing=self.node_spacing)
while len(irreducibles):
irreducible_set = irreducibles.pop()
neurograph.add_component(irreducible_set)
fragments_graph.add_component(irreducible_set)
if self.progress_bar:
pbar.update(1)
return neurograph
return fragments_graph

# --- Graph structure extraction ---
def schedule_processes(self, swc_dicts):
Expand Down Expand Up @@ -645,7 +645,8 @@ def compute_dist(graph, i, j):
Returns
-------
Euclidean distance between i and j.
float
Euclidean distance between i and j.
"""
return geometry.dist(graph.nodes[i]["xyz"], graph.nodes[j]["xyz"])
Expand Down Expand Up @@ -686,6 +687,7 @@ def get_leafs(graph):
-------
list
Leaf nodes "graph".
"""
return [i for i in graph.nodes if graph.degree[i] == 1]

Expand Down Expand Up @@ -746,20 +748,21 @@ def count_components(graph):
Graph to be searched.
Returns
-------
Number of connected components.
-------'
int
Number of connected components.
"""
return nx.number_connected_components(graph)


def largest_components(neurograph, k):
def largest_components(graph, k):
"""
Finds the "k" largest connected components in "neurograph".
Finds the "k" largest connected components in "graph".
Parameters
----------
neurograph : NeuroGraph
graph : nx.Graph
Graph to be searched.
k : int
Number of largest connected components to return.
Expand All @@ -773,7 +776,7 @@ def largest_components(neurograph, k):
"""
component_cardinalities = k * [-1]
node_ids = k * [-1]
for nodes in nx.connected_components(neurograph):
for nodes in nx.connected_components(graph):
if len(nodes) > component_cardinalities[-1]:
i = 0
while i < k:
Expand Down
62 changes: 34 additions & 28 deletions src/deep_neurographs/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


def visualize_connected_components(
graph, line_width=3, return_data=False, title=""
graph, width=3, return_data=False, title=""
):
"""
Visualizes the connected components in "graph".
Expand All @@ -25,8 +25,8 @@ def visualize_connected_components(
----------
graph : networkx.Graph
Graph to be visualized.
line_width : int, optional
Line width used to plot "subset". The default is 5.
width : int, optional
Line width used to plot edges in "subset". The default is 5.
return_data : bool, optional
Indication of whether to return data object that is used to generate
plot. The default is False.
Expand All @@ -50,7 +50,7 @@ def visualize_connected_components(
color = colors[cnt % len(colors)]
data.extend(
plot_edges(
graph, subgraph.edges, color=color, line_width=line_width
graph, subgraph.edges, color=color, width=width
)
)
cnt += 1
Expand Down Expand Up @@ -85,16 +85,18 @@ def visualize_graph(graph, title=""):
plot(data, title)


def visualize_proposals(graph, target_graph=None, title="Proposals"):
def visualize_proposals(
graph, color=None, groundtruth_graph=None, title="Proposals"
):
"""
Visualizes a graph with proposals.
Visualizes a graph and its proposals.
Parameters
----------
graph : networkx.Graph
Graph to be visualized.
target_graph : networkx.Graph, optional
Graph generated from ground truth tracings. The default is None.
groundtruth_graph : networkx.Graph, optional
Graph generated from groundtruth tracings. The default is None.
title : str, optional
Title of the plot. Default is "Proposals".
Expand All @@ -106,24 +108,25 @@ def visualize_proposals(graph, target_graph=None, title="Proposals"):
visualize_subset(
graph,
graph.proposals,
color=color,
proposal_subset=True,
target_graph=target_graph,
groundtruth_graph=groundtruth_graph,
title=title,
)


def visualize_targets(
graph, target_graph=None, title="Ground Truth - Accepted Proposals"
def visualize_groundtruth(
graph, groundtruth_graph=None, title="Ground Truth - Accepted Proposals"
):
"""
Visualizes a graph and its ground truth accept proposals.
Visualizes a graph and its groundtruth accepted proposals.
Parameters
----------
graph : networkx.Graph
Graph to be visualized.
target_graph : networkx.Graph, optional
Graph generated from ground truth tracings. The default is None.
groundtruth_graph : networkx.Graph, optional
Graph generated from groundtruth tracings. The default is None.
title : str, optional
Title of the plot. Default is "Ground Truth - Accepted Proposals".
Expand All @@ -136,17 +139,18 @@ def visualize_targets(
graph,
graph.target_edges,
proposal_subset=True,
target_graph=target_graph,
groundtruth_graph=groundtruth_graph,
title=title,
)


def visualize_subset(
graph,
subset,
line_width=5,
color=None,
width=5,
proposal_subset=False,
target_graph=None,
groundtruth_graph=None,
title="",
):
"""
Expand All @@ -158,12 +162,12 @@ def visualize_subset(
Graph to be visualized.
subset : container
Subset of edges or proposals to be visualized.
line_width : int, optional
width : int, optional
Line width used to plot "subset". The default is 5.
proposals_subset : bool, optional
Indication of whether "subset" is a subset of proposals. The default
is False.
target_graph : networkx.Graph, optional
groundtruth_graph : networkx.Graph, optional
Graph generated from ground truth tracings. The default is None.
title : str, optional
Title of the plot. Default is "Proposals".
Expand All @@ -177,13 +181,15 @@ def visualize_subset(
data = plot_edges(graph, graph.edges, color="black")
data.append(plot_nodes(graph))
if proposal_subset:
data.extend(plot_proposals(graph, subset, line_width=line_width))
data.extend(
plot_proposals(graph, subset, color=color, width=width)
)
else:
data.extend(plot_edges(graph, subset, line_width=line_width))
data.extend(plot_edges(graph, subset, width=width))

# Add target graph (if applicable)
if target_graph:
cc = visualize_connected_components(target_graph, return_data=True)
if groundtruth_graph:
cc = visualize_connected_components(groundtruth_graph, return_data=True)
data.extend(cc)
plot(data, title)

Expand All @@ -202,12 +208,12 @@ def plot_nodes(graph):
)


def plot_proposals(graph, proposals, color=None, line_width=5):
def plot_proposals(graph, proposals, color=None, width=5):
# Set preferences
if color is None:
line = dict(width=line_width)
line = dict(width=width)
else:
line = dict(color=color, width=line_width)
line = dict(color=color, width=width)

# Add traces
traces = []
Expand All @@ -225,10 +231,10 @@ def plot_proposals(graph, proposals, color=None, line_width=5):
return traces


def plot_edges(graph, edges, color=None, line_width=3):
def plot_edges(graph, edges, color=None, width=3):
traces = []
line = (
dict(width=5) if color is None else dict(color=color, width=line_width)
dict(width=5) if color is None else dict(color=color, width=width)
)
for i, j in edges:
trace = go.Scatter3d(
Expand Down

0 comments on commit 3cec969

Please sign in to comment.