-
Notifications
You must be signed in to change notification settings - Fork 0
/
metrics.py
62 lines (45 loc) · 1.44 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
""" Custom metrics. All take logits and targets. """
from typing import IO
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def threshold(preds, threshold):
    """
    Binarize a tensor of scores.

    Args:
        preds: tensor of scores (e.g. post-activation probabilities).
        threshold: cut-off; an element becomes 1 only if it is strictly
            greater than this value, otherwise 0.

    Returns:
        An ``int32`` tensor of 0/1 values with the same shape as ``preds``.
    """
    # Note: the parameter shadows this function's own name inside the body;
    # that is harmless here because the function never calls itself.
    above_cutoff = preds.gt(threshold)
    return above_cutoff.to(torch.int32)
class Accuracy(nn.Module):
    """
    Pixel-wise accuracy.

    The logits are passed through ``activation``, binarized with
    ``threshold`` (strictly greater than), and compared element-wise with
    ``targets``. The result is the number of equal elements divided by the
    number of elements in ``targets``.
    """

    name = 'accuracy'

    def __init__(self, threshold=0.5, activation=torch.sigmoid):
        """
        Args:
            threshold: cut-off applied after ``activation``; scores strictly
                above it count as positive (1), the rest as 0.
            activation: callable mapping logits to scores
                (default ``torch.sigmoid``).
        """
        super().__init__()
        self.threshold = threshold
        self.activation = activation

    def forward(self, logits, targets):
        """
        Args:
            logits: raw model outputs.
            targets: ground-truth labels. NOTE(review): assumed to be 0/1
                valued and shaped like ``logits``; mismatched shapes would be
                broadcast by ``==`` without any check.

        Returns:
            A 0-dim tensor: matching elements / ``targets.numel()``.
        """
        preds = self.activation(logits)
        preds = threshold(preds, self.threshold)
        correct = (preds == targets).sum()
        # numel() counts elements for any memory layout; the previous
        # targets.view(-1) raised RuntimeError on non-contiguous tensors
        # (e.g. transposed or strided slices).
        return correct / targets.numel()
class IoU(nn.Module):
    """
    Jaccard index, i.e. intersection over union (IoU).

    Logits are mapped to scores with ``activation`` and binarized with
    ``threshold`` (strictly greater than). ``SMOOTH`` is added to both the
    intersection and the union, so two all-zero masks score 1 rather than
    evaluating 0/0.
    """

    name = 'iou'
    SMOOTH = 1e-7

    def __init__(self, threshold=0.5, activation=torch.sigmoid):
        """
        Args:
            threshold: cut-off applied after ``activation``.
            activation: callable mapping logits to scores
                (default ``torch.sigmoid``).
        """
        super().__init__()
        self.threshold = threshold
        self.activation = activation

    def forward(self, logits, targets):
        """
        Args:
            logits: raw model outputs.
            targets: ground-truth mask. NOTE(review): assumed 0/1 valued and
                shaped like ``logits`` — confirm with callers.

        Returns:
            A 0-dim tensor: (intersection + SMOOTH) / (union + SMOOTH).
        """
        binary = threshold(self.activation(logits), self.threshold)
        overlap = (binary * targets).sum()
        # |A| + |B| - |A ∩ B| == |A ∪ B| for 0/1 masks.
        combined = binary.sum() + targets.sum()
        union = combined - overlap
        eps = self.SMOOTH
        return (overlap + eps) / (union + eps)
# Registry of ready-made metric instances keyed by name, each built with its
# constructor defaults. 'iou' and 'jaccard' are aliases for the same metric
# but hold two separate IoU instances.
metrics_index = dict(
    accuracy=Accuracy(),
    iou=IoU(),
    jaccard=IoU(),
)