learning_rate_on_plateau.py
import logging

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend


# Note: this extends `tf.keras.callbacks.ReduceLROnPlateau` by restoring the
# best weights seen so far at each learning-rate decay.
class CustomReduceLearningRateOnPlateauCallback(tf.keras.callbacks.Callback):
    """Reduce the learning rate when a metric has stopped improving.

    Models often benefit from reducing the learning rate by a factor of 2-10
    once learning stagnates. This callback monitors a quantity and, if no
    improvement is seen for a `patience` number of epochs, the learning rate
    is reduced and the current model weights are reset to the best weights
    seen so far (when `restore_best_weights` is enabled).

    Example:

    ```python
    reduce_lr = CustomReduceLearningRateOnPlateauCallback(
        monitor='val_loss', factor=0.2, patience=5, min_lr=0.001
    )
    model.fit(X_train, Y_train, callbacks=[reduce_lr])
    ```

    Arguments:
        monitor: quantity to be monitored.
        factor: factor by which the learning rate will be reduced.
            `new_lr = lr * factor`.
        patience: number of epochs with no improvement after which the
            learning rate will be reduced.
        verbose: int. 0: quiet, 1: update messages.
        mode: one of {'auto', 'min', 'max'}. In `min` mode, the learning rate
            is reduced when the monitored quantity has stopped decreasing; in
            `max` mode it is reduced when the monitored quantity has stopped
            increasing; in `auto` mode, the direction is inferred
            automatically from the name of the monitored quantity.
        min_delta: threshold for measuring the new optimum, to only focus on
            significant changes.
        cooldown: number of epochs to wait before resuming normal operation
            after the learning rate has been reduced.
        min_lr: lower bound on the learning rate.
        restore_best_weights: bool. True: go back to the best model seen so
            far at each reduction; False: keep the current model weights.
    """
    def __init__(
        self,
        monitor='val_loss',
        factor=0.1,
        patience=10,
        verbose=0,
        mode='auto',
        min_delta=1e-4,
        cooldown=0,
        min_lr=0,
        restore_best_weights=True,
        **kwargs
    ):
        super().__init__()
        self.monitor = monitor
        if factor >= 1.0:
            raise ValueError('ReduceLROnPlateau does not support a factor >= 1.0.')
        if 'epsilon' in kwargs:
            min_delta = kwargs.pop('epsilon')
            logging.warning('`epsilon` argument is deprecated and '
                            'will be removed, use `min_delta` instead.')
        self.factor = factor
        self.min_lr = min_lr
        self.min_delta = min_delta
        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0  # Cooldown counter.
        self.wait = 0
        self.best = 0
        self.mode = mode
        self.monitor_op = None
        self.restore_best_weights = restore_best_weights
        self.best_weights = None
        self._reset()
    def _reset(self):
        """Resets the wait counter and cooldown counter."""
        if self.mode not in ['auto', 'min', 'max']:
            logging.warning('Learning Rate Plateau Reducing mode %s is unknown, '
                            'fallback to auto mode.', self.mode)
            self.mode = 'auto'
        if self.mode == 'min' or (self.mode == 'auto' and 'acc' not in self.monitor):
            self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
            self.best = np.inf
        else:
            self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
            self.best = -np.inf
        self.cooldown_counter = 0
        self.wait = 0

    def on_train_begin(self, logs=None):
        self._reset()
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        logs['lr'] = backend.get_value(self.model.optimizer.lr)
        current = logs.get(self.monitor)
        if current is None:
            logging.warning('Reduce LR on plateau conditioned on metric `%s` '
                            'which is not available. Available metrics are: %s',
                            self.monitor, ','.join(list(logs.keys())))
        else:
            if self.in_cooldown():
                self.cooldown_counter -= 1
                self.wait = 0
            if self.monitor_op(current, self.best):
                # New best value: remember it and, optionally, snapshot the weights.
                self.best = current
                self.wait = 0
                if self.restore_best_weights:
                    self.best_weights = self.model.get_weights()
            elif not self.in_cooldown():
                self.wait += 1
                if self.wait >= self.patience:
                    old_lr = float(backend.get_value(self.model.optimizer.lr))
                    if old_lr > self.min_lr:
                        new_lr = old_lr * self.factor
                        new_lr = max(new_lr, self.min_lr)
                        backend.set_value(self.model.optimizer.lr, new_lr)
                        if self.verbose > 0:
                            print('\nEpoch %05d: ReduceLROnPlateau reducing learning '
                                  'rate to %s.' % (epoch + 1, new_lr))
                        if self.restore_best_weights and self.best_weights is not None:
                            # `best_weights` can still be None if the monitored value
                            # was never finite (e.g. NaN loss), so guard the restore.
                            if self.verbose > 0:
                                print('Restoring model weights from the end of the best epoch.')
                            self.model.set_weights(self.best_weights)
                        self.cooldown_counter = self.cooldown
                        self.wait = 0

    def in_cooldown(self):
        return self.cooldown_counter > 0
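

# Usage sketch (assumptions: a TF 2.x-style optimizer that exposes
# `optimizer.lr`; the toy data, model architecture, and hyperparameters below
# are hypothetical placeholders chosen only to show how the callback plugs
# into `model.fit`).
if __name__ == '__main__':
    # Hypothetical regression data: 256 samples with 10 features each.
    x_train = np.random.rand(256, 10).astype('float32')
    y_train = np.random.rand(256, 1).astype('float32')

    # A small placeholder model.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(32, activation='relu', input_shape=(10,)),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss='mse')

    # Halve the learning rate after 5 stagnant epochs and roll the model back
    # to the best weights seen so far at each decay.
    reduce_lr = CustomReduceLearningRateOnPlateauCallback(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-6,
        restore_best_weights=True,
        verbose=1,
    )
    model.fit(
        x_train,
        y_train,
        validation_split=0.2,
        epochs=20,
        callbacks=[reduce_lr],
    )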