-
Notifications
You must be signed in to change notification settings - Fork 0
/
generator.py
73 lines (64 loc) · 2.45 KB
/
generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from __future__ import absolute_import
from __future__ import print_function

import os

import tensorflow as tf

import constants as c
from data_reader import DataReader
from model import Model
class LyricsGenerator():
    """Train the lyrics model, or generate lyrics from a saved checkpoint.

    Construction builds the data reader, the model graph, a checkpoint saver
    and a TensorFlow session, then initializes all global variables.
    ``run()`` either trains forever (``test=False``) or restores the latest
    checkpoint and generates (``test=True``).
    """

    def __init__(self, model_name='model', test=False):
        """Build the data pipeline, model, saver and session.

        :param model_name: prefix of the checkpoint files, which are written
            as ``<c.MODEL_SAVE_DIR><model_name>.ckpt-<global_step>``.
        :param test: when True, ``run()`` generates instead of training; the
            flag is also passed through to ``Model``.
        """
        self.session = tf.Session()
        print('Process data...')
        self.data_reader = DataReader()
        self.vocab = self.data_reader.get_vocab()
        print('Init model...')
        self.model = Model(self.session, self.vocab, c.BATCH_SIZE, c.SEQ_LEN,
                           c.CELL_SIZE, c.NUM_LAYERS, test)
        print('Init variables...')
        self.test = test
        # max_to_keep=None: old checkpoints are never deleted from disk.
        self.saver = tf.train.Saver(max_to_keep=None)
        self.session.run(tf.global_variables_initializer())
        self.model_name = model_name

    def _model_path(self):
        """Return the checkpoint prefix shared by saving and restoring."""
        # Plain concatenation: c.MODEL_SAVE_DIR is expected to carry its own
        # trailing separator.
        return '{}{}.ckpt'.format(c.MODEL_SAVE_DIR, self.model_name)

    def train(self):
        """Run an endless training loop on the model.

        Every step fetches a batch, runs one optimizer step and prints the
        loss; every ``c.MODEL_SAVE_FREQ`` global steps a checkpoint is saved
        under ``self._model_path()``. The loop has no exit condition.
        """
        print('Training model...')
        model_path = self._model_path()
        while True:
            inputs, targets = self.data_reader.get_train_batch(
                c.BATCH_SIZE, c.SEQ_LEN)
            feed_dict = {
                self.model.inputs: inputs,
                self.model.targets: targets,
            }
            global_step, loss, _ = self.session.run(
                [self.model.global_step, self.model.loss, self.model.train_op],
                feed_dict=feed_dict)
            print('Step: %d | loss: %f' % (global_step, loss))
            if global_step % c.MODEL_SAVE_FREQ == 0:
                print('Saving model...')
                self.saver.save(self.session,
                                model_path,
                                global_step=global_step)

    def generate(self):
        """Generate lyrics with the model's current weights.

        :returns: whatever ``Model.generate()`` returns.
        """
        return self.model.generate()

    def run(self):
        """Train (``test=False``) or generate (``test=True``).

        In test mode the newest checkpoint is restored first, so generation
        uses trained weights. If none can be restored, a warning is printed
        and generation proceeds with the freshly initialized variables.
        """
        if not self.test:
            return self.train()
        if not self._load_saved_model():
            print('Warning: no checkpoint restored; generating with '
                  'untrained weights.')
        return self.generate()

    def _load_saved_model(self):
        """Restore the most recent checkpoint written by ``train()``.

        :returns: True if a checkpoint was restored, False otherwise.
        """
        print("model loading ...")
        prefix = self._model_path()
        ckpt_dir = os.path.dirname(prefix) or '.'
        # Checkpoints are saved as '<prefix>-<global_step>'; ask TF for the
        # newest one instead of assuming a fixed step number.
        model_path = tf.train.latest_checkpoint(ckpt_dir)
        # The directory's newest checkpoint may belong to another model name.
        own_prefix = os.path.basename(prefix) + '-'
        if (model_path is None
                or not os.path.basename(model_path).startswith(own_prefix)):
            print("No checkpoint found for {}".format(prefix))
            return False
        try:
            self.saver.restore(self.session, model_path)
        except (ValueError, tf.errors.NotFoundError):
            print("Failed to restore {}".format(model_path))
            return False
        print("Done!")
        return True