-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* bug: circular arg fixed, class rename * refactor: simplified evaluation and editted documentation --------- Co-authored-by: anna-grim <[email protected]>
- Loading branch information
Showing
3 changed files
with
104 additions
and
126 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
@author: Anna Grim | ||
@email: [email protected] | ||
Evaluates performance of edge classifiation model. | ||
Evaluates performance of proposal classifiation model. | ||
""" | ||
|
||
|
@@ -19,95 +19,63 @@ | |
] | ||
|
||
|
||
def init_stats(): | ||
def run_evaluation(fragments_graph, proposals, accepts): | ||
""" | ||
Initializes a dictionary that stores stats computes by routines in this | ||
module. | ||
Parameters | ||
---------- | ||
None | ||
Returns | ||
------- | ||
dict | ||
Dictionary that stores stats computes by routines in this module. | ||
""" | ||
return dict([(metric, []) for metric in METRICS_LIST]) | ||
|
||
|
||
def run_evaluation(neurograph, accepts, proposals): | ||
""" | ||
Runs an evaluation on the accuracy of the predictions generated by an edge | ||
Evaluates the accuracy of predictions made by a proposal | ||
classication model. | ||
Parameters | ||
---------- | ||
neurographs : list[NeuroGraph] | ||
Predicted neurographs. | ||
accepts : list | ||
fragments_graphs : FragmentsGraph | ||
Graph generated from fragments of a predicted segmentation. | ||
proposals : list[frozenset] | ||
Proposals classified by model. | ||
accepts : list[frozenset] | ||
Accepted proposals. | ||
proposals : list | ||
Proposals that were classified as either accept or reject. | ||
Returns | ||
------- | ||
dict | ||
Dictionary that stores the accuracy of the edge classification model | ||
on all edges (i.e. "Overall"), simple edges, and complex edges. The | ||
metrics contained in this dictionary are identical to "METRICS_LIST"]. | ||
Dictionary that stores statistics calculated for all proposal | ||
predictions and separately for simple and complex proposals, as | ||
specified in "METRICS_LIST". | ||
""" | ||
# Initializations | ||
stats = { | ||
"Overall": init_stats(), | ||
"Simple": init_stats(), | ||
"Complex": init_stats(), | ||
} | ||
stats = dict() | ||
simple_proposals = fragments_graph.simple_proposals() | ||
complex_proposals = fragments_graph.complex_proposals() | ||
|
||
# Evaluation | ||
overall_stats = get_stats(neurograph, proposals, accepts) | ||
|
||
simple_stats = get_stats( | ||
neurograph, neurograph.simple_proposals(), accepts | ||
) | ||
|
||
complex_stats = get_stats( | ||
neurograph, neurograph.complex_proposals(), accepts | ||
) | ||
|
||
# Store results | ||
for metric in METRICS_LIST: | ||
stats["Overall"][metric].append(overall_stats[metric]) | ||
stats["Simple"][metric].append(simple_stats[metric]) | ||
stats["Complex"][metric].append(complex_stats[metric]) | ||
stats["Overall"] = get_stats(fragments_graph, proposals, accepts) | ||
stats["Simple"] = get_stats(fragments_graph, simple_proposals, accepts) | ||
stats["Complex"] = get_stats(fragments_graph, complex_proposals, accepts) | ||
return stats | ||
|
||
|
||
def get_stats(neurograph, proposals, accepts): | ||
def get_stats(fragments_graph, proposals, accepts): | ||
""" | ||
Accuracy of the predictions generated by an edge classication model on a | ||
given block and "edge_type" (e.g. overall, simple, or complex). | ||
Computes statistics that reflect the accuracy of the predictions made by | ||
a proposal classication model. | ||
Parameters | ||
---------- | ||
neurograph : NeuroGraph | ||
Predicted neurograph | ||
proposals : set[frozenset] | ||
Set of edge proposals for a given "edge_type". | ||
fragments_graph : FragmentsGraph | ||
Graph generated from fragments of a predicted segmentation. | ||
proposals : list[frozenset] | ||
List of proposals of a specified "proposal_type". | ||
accepts : numpy.ndarray | ||
Binary predictions of edges generated by classifcation model. | ||
Accepted proposals. | ||
Returns | ||
------- | ||
dict | ||
Results of evaluation where the keys are identical to "METRICS_LIST". | ||
""" | ||
n_pos = len([e for e in proposals if e in neurograph.target_edges]) | ||
n_pos = len([e for e in proposals if e in fragments_graph.gt_accepts]) | ||
a_baseline = n_pos / (len(proposals) if len(proposals) > 0 else 1) | ||
tp, fp, a, p, r, f1 = get_accuracy(neurograph, proposals, accepts) | ||
tp, fp, a, p, r, f1 = get_accuracy(fragments_graph, proposals, accepts) | ||
stats = { | ||
"# splits fixed": tp, | ||
"# merges created": fp, | ||
|
@@ -120,80 +88,63 @@ def get_stats(neurograph, proposals, accepts): | |
return stats | ||
|
||
|
||
def get_accuracy(neurograph, proposals, accepts): | ||
def get_accuracy(fragments_graph, proposals, accepts): | ||
""" | ||
Computes the following metrics for a given set of predicted edges: | ||
(1) true positives, (2) false positive, (3) precision, (4) recall, and | ||
(5) f1-score. | ||
Computes the following metrics for a given set of predicted proposals: | ||
(1) true positives, (2) false positive, (3) accuracy, (4) precision, | ||
(5) recall, and (6) f1-score. | ||
Parameters | ||
---------- | ||
neurograph : NeuroGraph | ||
Predicted neurograph | ||
fragments_graph : FragmentsGraph | ||
Graph generated from fragments of a predicted segmentation. | ||
proposals : set[frozenset] | ||
Set of edge proposals for a given "edge_type". | ||
List of proposals of a specified "proposal_type". | ||
accepts : list | ||
Accepted proposals. | ||
Returns | ||
------- | ||
float | ||
Number of true positives. | ||
float | ||
Number of false positives. | ||
float | ||
Precision. | ||
float | ||
Recall. | ||
float | ||
F1-score. | ||
float, float, float, float, float, float | ||
Number true positives, number of false positives, accuracy, precision, | ||
recall, and F1-score. | ||
""" | ||
tp, tn, fp, fn = get_accuracy_counts(neurograph, proposals, accepts) | ||
tp, tn, fp, fn = get_detection_cnts(fragments_graph, proposals, accepts) | ||
a = (tp + tn) / len(proposals) if len(proposals) else 1 | ||
p = 1 if tp + fp == 0 else tp / (tp + fp) | ||
r = 1 if tp + fn == 0 else tp / (tp + fn) | ||
f1 = (2 * r * p) / max(r + p, 1e-3) | ||
return tp, fp, a, p, r, f1 | ||
|
||
|
||
def get_accuracy_counts(neurograph, proposals, accepts): | ||
def get_detection_cnts(fragments_graph, proposals, accepts): | ||
""" | ||
Computes the following values: (1) true positives, (2) false positive, and | ||
(3) false negatives. | ||
Computes the following values: (1) true positives, (2) true negatives, | ||
(3) false positive, and (4) false negatives. | ||
Parameters | ||
---------- | ||
neurograph : NeuroGraph | ||
Predicted neurograph | ||
fragments_graph : FragmentsGraph | ||
Graph generated from fragments of a predicted segmentation. | ||
proposals : set[frozenset] | ||
Set of edge proposals for a given "edge_type". | ||
List of proposals of a specified "proposal_type". | ||
accepts : list | ||
Accepted proposals. | ||
Returns | ||
------- | ||
float | ||
Number of true positives. | ||
float | ||
Number of false positives. | ||
float | ||
Number of false negatives. | ||
float, float, float, float | ||
Number of true positives, true negatives, false positives, and false | ||
negatives. | ||
""" | ||
tp = 0 | ||
tn = 0 | ||
fp = 0 | ||
fn = 0 | ||
for edge in proposals: | ||
if edge in neurograph.target_edges: | ||
if edge in accepts: | ||
tp += 1 | ||
else: | ||
fn += 1 | ||
tp, tn, fp, fn = 0, 0, 0, 0 | ||
for p in proposals: | ||
if p in fragments_graph.gt_accepts: | ||
tp += 1 if p in accepts else 0 | ||
fn += 1 if p not in accepts else 0 | ||
else: | ||
if edge in accepts: | ||
fp += 1 | ||
else: | ||
tn += 1 | ||
fp += 1 if p in accepts else 0 | ||
tn += 1 if p not in accepts else 0 | ||
return tp, tn, fp, fn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters