-
Notifications
You must be signed in to change notification settings - Fork 0
/
metrics_pair.py
78 lines (63 loc) · 2.74 KB
/
metrics_pair.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import logging
import tensorflow as tf
class Recorder:
    """Track running means of a total loss and two component losses.

    Each loss is accumulated in its own ``tf.keras.metrics.Mean``, so the
    values reported by ``log`` and ``score`` are averages over every
    ``record`` call since the last ``reset``.
    """

    def __init__(self):
        self.loss = tf.keras.metrics.Mean()
        self.loss_0 = tf.keras.metrics.Mean()
        self.loss_1 = tf.keras.metrics.Mean()
        # str.format template used by log(): epoch, step, then the three means.
        self.pattern = 'Epoch: {}, step: {}, loss: {:.4f}, loss_0: {:.4f}, loss_1: {:.4f}'

    @staticmethod
    def _reset_metric(metric):
        """Clear the accumulated state of ``metric``.

        Newer Keras metrics expose ``reset_state()``; ``reset_states()`` is the
        older name and no longer exists in Keras 3, so prefer the new one and
        fall back to the old one.
        """
        reset = getattr(metric, 'reset_state', None)
        if reset is None:
            reset = metric.reset_states
        reset()

    def record(self, losses, losses_0, losses_1):
        """Add one step's losses to the running means.

        Args:
            losses: total loss value(s); passed unchanged to ``Mean.update_state``.
            losses_0: first component loss value(s).
            losses_1: second component loss value(s).
        """
        self.loss.update_state(losses)
        self.loss_0.update_state(losses_0)
        self.loss_1.update_state(losses_1)

    def reset(self):
        """Clear all three running means (e.g. at the start of an epoch)."""
        for metric in (self.loss, self.loss_0, self.loss_1):
            self._reset_metric(metric)

    def _results(self):
        """Return ``[loss, loss_0, loss_1]`` means, converted via ``.numpy()``."""
        return [
            self.loss.result().numpy(),
            self.loss_0.result().numpy(),
            self.loss_1.result().numpy(),
        ]

    def score(self):
        """Return the current mean of the total loss.

        The value is already converted with ``.numpy()``; calling ``.numpy()``
        on it again (as the previous implementation did) raised
        ``AttributeError`` because NumPy scalars have no such method.
        """
        return self.loss.result().numpy()

    def log(self, epoch, num_step, prefix='', suffix=''):
        """Log the current means at INFO level via the root ``logging`` logger.

        Args:
            epoch: epoch number shown in the message.
            num_step: step number shown in the message.
            prefix: text prepended to the formatted message.
            suffix: text appended to the formatted message.
        """
        message = self.pattern.format(epoch, num_step, *self._results())
        logging.info(prefix + message + suffix)
class Recorder_3:
    """Track running means of a total loss and three component losses.

    Each loss is accumulated in its own ``tf.keras.metrics.Mean``, so the
    values reported by ``log`` and ``score`` are averages over every
    ``record`` call since the last ``reset``.
    """

    def __init__(self):
        self.loss = tf.keras.metrics.Mean()
        self.loss0 = tf.keras.metrics.Mean()
        self.loss1 = tf.keras.metrics.Mean()
        self.loss2 = tf.keras.metrics.Mean()
        # str.format template used by log(): epoch, step, then the four means.
        self.pattern = 'Epoch: {}, step: {}, loss: {:.4f}, loss_0: {:.4f},loss_1: {:.4f},loss_2: {:.4f}'

    @staticmethod
    def _reset_metric(metric):
        """Clear the accumulated state of ``metric``.

        Newer Keras metrics expose ``reset_state()``; ``reset_states()`` is the
        older name and no longer exists in Keras 3, so prefer the new one and
        fall back to the old one.
        """
        reset = getattr(metric, 'reset_state', None)
        if reset is None:
            reset = metric.reset_states
        reset()

    def record(self, losses, loss_0, loss_1, loss_2):
        """Add one step's losses to the running means.

        Args:
            losses: total loss value(s); passed unchanged to ``Mean.update_state``.
            loss_0: first component loss value(s).
            loss_1: second component loss value(s).
            loss_2: third component loss value(s).
        """
        self.loss.update_state(losses)
        self.loss0.update_state(loss_0)
        self.loss1.update_state(loss_1)
        self.loss2.update_state(loss_2)

    def reset(self):
        """Clear all four running means (e.g. at the start of an epoch)."""
        for metric in (self.loss, self.loss0, self.loss1, self.loss2):
            self._reset_metric(metric)

    def _results(self):
        """Return ``[loss, loss_0, loss_1, loss_2]`` means, converted via ``.numpy()``."""
        return [
            self.loss.result().numpy(),
            self.loss0.result().numpy(),
            self.loss1.result().numpy(),
            self.loss2.result().numpy(),
        ]

    def score(self):
        """Return the current mean of the total loss.

        The previous implementation took the last element of ``_results()``,
        which was F1 back when precision/recall were tracked. With those
        removed it silently became ``loss_2``. It also raised ``AttributeError``
        by calling ``.numpy()`` on an already-converted NumPy scalar. This now
        matches ``Recorder.score``.
        NOTE(review): confirm callers treat a lower score as better.
        """
        return self.loss.result().numpy()

    def log(self, epoch, num_step, prefix='', suffix=''):
        """Log the current means at INFO level via the root ``logging`` logger.

        Args:
            epoch: epoch number shown in the message.
            num_step: step number shown in the message.
            prefix: text prepended to the formatted message.
            suffix: text appended to the formatted message.
        """
        message = self.pattern.format(epoch, num_step, *self._results())
        logging.info(prefix + message + suffix)