-
Notifications
You must be signed in to change notification settings - Fork 52
/
loss_functions.py
executable file
·48 lines (36 loc) · 1.55 KB
/
loss_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from __future__ import division
import torch
from torch import nn
import numpy as np
def compute_errors_test(gt, pred):
    """Compute dense error metrics between ground truth and a prediction.

    Every element of ``gt`` and ``pred`` is used; there is no validity mask,
    so callers must filter out invalid entries beforehand. A zero in ``gt``
    or ``pred`` makes the ratio/log metrics infinite or NaN.

    Args:
        gt: ground-truth values, either a ``torch.Tensor`` (any device, may
            require grad) or anything ``np.asarray`` accepts. Presumably
            positive depth values -- TODO confirm with callers.
        pred: predicted values, same accepted types as ``gt``; must be
            broadcast-compatible with ``gt`` (normally the same shape).

    Returns:
        Tuple ``(abs_rel, abs_diff, sq_rel, rmse, rmse_log, a1, a2, a3)`` of
        NumPy scalars. ``a1``/``a2``/``a3`` are the fractions of elements
        whose ratio ``max(gt / pred, pred / gt)`` is below 1.25, 1.25**2 and
        1.25**3 respectively.
    """
    # Tensor.numpy() alone fails on CUDA tensors and on tensors that require
    # grad; detach and move to host memory first so any tensor is accepted.
    gt = gt.detach().cpu().numpy() if torch.is_tensor(gt) else np.asarray(gt)
    pred = pred.detach().cpu().numpy() if torch.is_tensor(pred) else np.asarray(pred)

    # Threshold accuracies on the symmetric ratio between the two maps.
    thresh = np.maximum(gt / pred, pred / gt)
    a1 = (thresh < 1.25).mean()
    a2 = (thresh < 1.25 ** 2).mean()
    a3 = (thresh < 1.25 ** 3).mean()

    diff = gt - pred
    rmse = np.sqrt(np.mean(diff ** 2))
    rmse_log = np.sqrt(np.mean((np.log(gt) - np.log(pred)) ** 2))
    abs_diff = np.mean(np.abs(diff))
    abs_rel = np.mean(np.abs(diff) / gt)
    sq_rel = np.mean((diff ** 2) / gt)
    return abs_rel, abs_diff, sq_rel, rmse, rmse_log, a1, a2, a3
def compute_errors_train(gt, pred, valid):
    """Average per-sample error metrics over a batch, restricted to valid entries.

    For each sample of the batch, only the entries selected by the boolean
    mask ``valid[i]`` are compared. Samples whose mask selects nothing are
    skipped and excluded from the average. Counting them in the divisor
    would drag every metric (including the accuracy ratios) towards 0.

    Args:
        gt: batched ground-truth tensor; iterated along its first dimension.
        pred: batched prediction tensor, iterated alongside ``gt``.
        valid: boolean mask used to index each sample of ``gt``/``pred``
            (presumably the same shape as ``gt`` -- TODO confirm with callers).

    Returns:
        List ``[abs_rel, abs_diff, sq_rel, a1, a2, a3]``. Each entry is the
        mean, over the non-empty samples, of that sample's metric; entries
        are 0-dim tensors (still tracked by autograd if ``pred`` requires
        grad). If no sample has a valid entry, or the batch is empty, every
        entry is ``0.0``.
    """
    abs_diff, abs_rel, sq_rel, a1, a2, a3 = 0, 0, 0, 0, 0, 0
    # Number of samples that actually contributed to the sums above.
    n_samples = 0
    for current_gt, current_pred, current_valid in zip(gt, pred, valid):
        valid_gt = current_gt[current_valid]
        valid_pred = current_pred[current_valid]
        if valid_gt.numel() == 0:
            # Nothing to compare: a mean over zero elements would be NaN.
            continue
        n_samples += 1
        thresh = torch.max(valid_gt / valid_pred, valid_pred / valid_gt)
        a1 += (thresh < 1.25).float().mean()
        a2 += (thresh < 1.25 ** 2).float().mean()
        a3 += (thresh < 1.25 ** 3).float().mean()
        abs_diff += torch.mean(torch.abs(valid_gt - valid_pred))
        abs_rel += torch.mean(torch.abs(valid_gt - valid_pred) / valid_gt)
        sq_rel += torch.mean(((valid_gt - valid_pred) ** 2) / valid_gt)
    if n_samples == 0:
        # Same result the original produced when every sample was skipped,
        # and avoids dividing by zero on an empty batch.
        return [0.0] * 6
    return [metric / n_samples for metric in [abs_rel, abs_diff, sq_rel, a1, a2, a3]]