-
Notifications
You must be signed in to change notification settings - Fork 8
/
trainTextToFace.py
28 lines (25 loc) · 1 KB
/
trainTextToFace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os

import numpy as np
from bert_serving.client import BertClient

from model import textTodlatentsModel
def _readPairs(dataPath):
    """Yield ``(imageName, text)`` pairs from a tab-separated training file.

    Each line is stripped and split on tab characters. Lines that do not
    yield exactly two fields (blank lines, missing or extra tabs) are skipped.
    """
    with open(dataPath, encoding="utf-8") as fh:
        for line in fh:
            fields = line.strip().split("\t")
            if len(fields) == 2:
                yield fields[0], fields[1]


def loadData(dataPath, latentDir="./latent_representations"):
    """Load paired (sentence embedding, dlatent) training data.

    Each line of ``dataPath`` is expected to be ``<imageName><TAB><text>``.
    Lines that do not split into exactly two tab-separated fields are skipped.
    Every kept text is encoded with a bert-as-service ``BertClient``. The
    matching dlatent is loaded from
    ``<latentDir>/<imageName without its final extension>.npy``.

    Args:
        dataPath: path to the UTF-8, tab-separated training file.
        latentDir: directory holding the ``.npy`` dlatent files.

    Returns:
        ``(textEmbedding, imageDlatents)`` as NumPy arrays with one entry per
        kept line. ``textEmbedding`` has shape ``(N, 1, ...)``: each sample
        keeps a leading axis of size 1, as if encoded one sentence at a time.
        When no line is usable, both arrays are empty and no connection to
        the BERT server is made.

    Raises:
        FileNotFoundError: if ``dataPath`` or a referenced latent file is
            missing.
    """
    pairs = list(_readPairs(dataPath))
    if not pairs:
        # Nothing to encode: return empty arrays without contacting the server.
        return np.array([]), np.array([])

    imageNames, texts = zip(*pairs)
    # Load every latent before contacting the server, so a missing .npy file
    # fails fast instead of after encoding work has been done.
    # splitext drops only the final extension, so "a.b.png" -> "a.b".
    imageDlatents = np.array([
        np.load(os.path.join(latentDir, os.path.splitext(name)[0] + ".npy"))
        for name in imageNames
    ])

    # One batched request instead of one network round trip per sentence.
    # close() releases the client's sockets even if encoding fails.
    bc = BertClient()
    try:
        embeddings = bc.encode(list(texts))
    finally:
        bc.close()
    # encode([text]) on a single sentence returns a leading axis of size 1.
    # Re-insert that axis so the layout is (N, 1, ...) per sample, as before.
    textEmbedding = np.expand_dims(embeddings, axis=1)
    return textEmbedding, imageDlatents
def train(dataPath="./data/train.txt",
          modelPath="./model/textEmbeddingDlatents.h5",
          batchSize=5,
          epochs=1000):
    """Train the text-embedding -> dlatent model and save it to disk.

    Args:
        dataPath: tab-separated training file, passed to ``loadData``.
        modelPath: file the trained model is saved to. Its parent directory
            is created if missing.
        batchSize: mini-batch size passed to ``model.fit``.
        epochs: number of training epochs passed to ``model.fit``.

    Raises:
        ValueError: if ``dataPath`` yields no usable training pairs.
    """
    # Load data first so a bad or empty data file fails before the model is built.
    sentenceEmbedding, imageDlatents = loadData(dataPath)
    if len(sentenceEmbedding) == 0:
        raise ValueError("no usable training pairs found in %r" % (dataPath,))

    # NOTE(review): assumes textTodlatentsModel().build() returns a Keras model
    # (fit/save to .h5) -- confirm. Keras 2+ names the epoch count `epochs`;
    # the legacy `nb_epoch` keyword is deprecated there and rejected by tf.keras.
    model = textTodlatentsModel().build()
    model.fit(sentenceEmbedding, imageDlatents, batch_size=batchSize, epochs=epochs)

    # Saving fails if the target directory does not exist yet.
    modelDir = os.path.dirname(modelPath)
    if modelDir:
        os.makedirs(modelDir, exist_ok=True)
    model.save(modelPath)
if __name__ == "__main__":
    # Script entry point: run training only when executed directly, not on import.
    train()