funcs.py
import logging

import numpy as np
from tensorflow import keras
import tensorflow.keras.backend as K


class LR_Warmup(keras.callbacks.Callback):
    """Linear learning-rate warmup for `warmup_epochs`, then per-batch exponential decay."""

    def __init__(self, lr_base=0.01, min_lr=0.0001, decay=0, warmup_epochs=0):
        self.num_passed_batchs = 0
        self.warmup_epochs = warmup_epochs
        self.lr = lr_base
        self.min_lr = min_lr
        self.decay = decay
        self.steps_per_epoch = 0

    def on_batch_begin(self, batch, logs=None):
        # Infer steps per epoch once from the training params.
        if self.steps_per_epoch == 0:
            if self.params['steps'] is None:
                self.steps_per_epoch = np.ceil(1. * self.params['samples'] / self.params['batch_size'])
            else:
                self.steps_per_epoch = self.params['steps']
        if self.num_passed_batchs < self.steps_per_epoch * self.warmup_epochs:
            # Warmup phase: ramp the learning rate linearly from 0 up to lr_base.
            K.set_value(self.model.optimizer.lr,
                        self.lr * (self.num_passed_batchs + 1) / self.steps_per_epoch / self.warmup_epochs)
        else:
            # Decay phase: shrink by a factor of (1 - decay) each batch, floored at min_lr.
            K.set_value(self.model.optimizer.lr,
                        max(self.min_lr, self.lr * ((1 - self.decay) ** (self.num_passed_batchs - self.steps_per_epoch * self.warmup_epochs))))
        self.num_passed_batchs += 1

    def on_epoch_begin(self, epoch, logs=None):
        print("learning_rate:", K.get_value(self.model.optimizer.lr))
class EarlyStopping(keras.callbacks.Callback):
    """Early stopping that ignores epochs before `start_epoch`; otherwise mirrors keras.callbacks.EarlyStopping."""

    def __init__(self,
                 monitor='val_loss',
                 min_delta=0,
                 patience=0,
                 verbose=0,
                 start_epoch=0,
                 mode='auto',
                 baseline=None,
                 restore_best_weights=False):
        super(EarlyStopping, self).__init__()
        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.baseline = baseline
        self.min_delta = abs(min_delta)
        self.wait = 0
        self.stopped_epoch = 0
        self.start_epoch = start_epoch
        self.restore_best_weights = restore_best_weights
        self.best_weights = None
        if mode not in ['auto', 'min', 'max']:
            logging.warning('EarlyStopping mode %s is unknown, '
                            'fallback to auto mode.', mode)
            mode = 'auto'
        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            # Auto mode: maximize accuracy-like metrics, minimize everything else.
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less
        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1

    def on_train_begin(self, logs=None):
        # Allow instances to be re-used.
        self.wait = 0
        self.stopped_epoch = 0
        self.best = np.inf if self.monitor_op == np.less else -np.inf
        self.best_weights = None

    def on_epoch_end(self, epoch, logs=None):
        current = self.get_monitor_value(logs)
        # Skip epochs before start_epoch so early instability is ignored.
        if current is None or epoch < self.start_epoch:
            return
        if self.restore_best_weights and self.best_weights is None:
            # Restore the weights after first epoch if no progress is ever made.
            self.best_weights = self.model.get_weights()
        self.wait += 1
        if self._is_improvement(current, self.best):
            self.best = current
            if self.restore_best_weights:
                self.best_weights = self.model.get_weights()
            # Only restart wait if we beat both the baseline and our previous best.
            if self.baseline is None or self._is_improvement(current, self.baseline):
                self.wait = 0
        # Stop once patience is exhausted.
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            self.model.stop_training = True
            if self.restore_best_weights and self.best_weights is not None:
                if self.verbose > 0:
                    print('Restoring model weights from the end of the best epoch.')
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0 and self.verbose > 0:
            print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))

    def get_monitor_value(self, logs):
        logs = logs or {}
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            logging.warning('Early stopping conditioned on metric `%s` '
                            'which is not available. Available metrics are: %s',
                            self.monitor, ','.join(list(logs.keys())))
        return monitor_value

    def _is_improvement(self, monitor_value, reference_value):
        return self.monitor_op(monitor_value - self.min_delta, reference_value)
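

# --- Usage sketch for the custom EarlyStopping (illustrative; not part of the original module) ---
# Shows the start_epoch parameter, which delays the stopping checks so that noisy
# early epochs are ignored. The model and random data are hypothetical placeholders.
if __name__ == '__main__':
    model = keras.Sequential([
        keras.layers.Dense(32, activation='relu', input_shape=(16,)),
        keras.layers.Dense(1),
    ])
    model.compile(optimizer='adam', loss='mse')
    x = np.random.rand(256, 16).astype('float32')
    y = np.random.rand(256, 1).astype('float32')
    stopper = EarlyStopping(monitor='val_loss', patience=3, start_epoch=5,
                            restore_best_weights=True, verbose=1)
    model.fit(x, y, validation_split=0.2, epochs=20, batch_size=32,
              callbacks=[stopper])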