# encoder.py -- forked from clvrai/BicycleGAN-Tensorflow.
import tensorflow as tf
from utils import logger
import ops
class Encoder(object):
    """Encode an image into a sampled latent code plus its Gaussian parameters.

    The encoder body is either a plain stack of strided conv blocks
    (``_convnet``) or a conv stem followed by residual stages (``_resnet``).
    Both end in two fully connected heads that predict ``mu`` and
    ``log_sigma``. A sample is drawn as ``z = mu + eps * exp(log_sigma)``
    with ``eps ~ N(0, 1)``.

    Variables live under ``tf.variable_scope(name)``. The first call creates
    them; later calls reuse them. After every call, ``self.var_list`` holds
    the trainable variables of that scope.

    Args:
        name: Variable scope name. It is also used to collect ``var_list``.
        is_train: Training flag forwarded to the ``ops`` layer helpers.
        norm: Normalization type forwarded to ``ops`` (e.g. 'instance').
        activation: Stored as ``self._activation`` but not read by the layers
            below. They use 'leaky' in the conv blocks and ReLU at the end of
            the resnet path.
        image_size: Input spatial size, assumed to be 128 or 256. A value of
            256 adds one extra downsampling stage.
        latent_dim: Size of the latent code (the output width of both heads).
        use_resnet: If True, ``__call__`` uses ``_resnet``; otherwise it uses
            ``_convnet``.
    """

    def __init__(self, name, is_train, norm='instance', activation='leaky',
                 image_size=128, latent_dim=8, use_resnet=True):
        logger.info('Init Encoder %s', name)
        self.name = name
        self._is_train = is_train
        self._norm = norm
        # NOTE(review): currently unused by _convnet/_resnet; kept for API compat.
        self._activation = activation
        # False until the first call has created the variables.
        self._reuse = False
        self._image_size = image_size
        self._latent_dim = latent_dim
        self._use_resnet = use_resnet

    def __call__(self, input):
        """Encode ``input`` and return ``(z, mu, log_sigma)``.

        ``input`` is presumably a 4-D NHWC image batch. The pooling in
        ``_resnet`` uses NHWC kernel layout. TODO confirm with callers.
        """
        if self._use_resnet:
            return self._resnet(input)
        return self._convnet(input)

    def _latent_head(self, features):
        """Map flattened features to ``(z, mu, log_sigma)`` and mark variables as built.

        Must be called inside this encoder's variable scope, after the body
        layers have been built.
        """
        mu = ops.mlp(features, self._latent_dim, 'FC8_mu', self._is_train,
                     self._reuse, norm=None, activation=None)
        log_sigma = ops.mlp(features, self._latent_dim, 'FC8_sigma',
                            self._is_train, self._reuse,
                            norm=None, activation=None)
        # The noise must have mu's full shape (batch, latent_dim) so that
        # every latent unit of every sample gets an independent draw.
        # tf.shape(self._latent_dim) would be the shape of a Python scalar,
        # i.e. [], giving a single shared value.
        eps = tf.random_normal(shape=tf.shape(mu))
        z = mu + eps * tf.exp(log_sigma)
        # All variables now exist; later calls must reuse them.
        self._reuse = True
        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          self.name)
        return z, mu, log_sigma

    def _convnet(self, input):
        """Apply a stack of conv blocks (4, 2 args), then the latent heads.

        The positional ``4, 2`` passed to ``ops.conv_block`` are presumably
        the kernel size and stride. Confirm against ``ops.conv_block``.
        """
        with tf.variable_scope(self.name, reuse=self._reuse):
            num_filters = [64, 128, 256, 512, 512, 512, 512]
            if self._image_size == 256:
                # One extra stage for the larger input.
                num_filters.append(512)
            E = input
            for i, n in enumerate(num_filters):
                # The first block (i == 0) is built without normalization.
                E = ops.conv_block(E, n, 'C{}_{}'.format(n, i), 4, 2,
                                   self._is_train, self._reuse,
                                   norm=self._norm if i else None,
                                   activation='leaky')
            return self._latent_head(ops.flatten(E))

    def _resnet(self, input):
        """Apply a conv stem, residual stages with 2x2 average pooling, then the latent heads."""
        with tf.variable_scope(self.name, reuse=self._reuse):
            num_filters = [128, 256, 512, 512]
            if self._image_size == 256:
                # One extra residual stage for the larger input.
                num_filters.append(512)
            # Stem: 64-filter conv block without normalization.
            E = ops.conv_block(input, 64, 'C{}_{}'.format(64, 0), 4, 2,
                               self._is_train, self._reuse, norm=None,
                               activation='leaky', bias=True)
            for i, n in enumerate(num_filters):
                E = ops.residual(E, n, 'res{}_{}'.format(n, i + 1),
                                 self._is_train, self._reuse,
                                 norm=self._norm, bias=True)
                # 2x2 / stride-2 average pooling halves H and W after each stage.
                E = tf.nn.avg_pool(E, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')
            E = tf.nn.relu(E)
            # Final 8x8 / stride-8 average pool ('SAME' padding) before the heads.
            E = tf.nn.avg_pool(E, [1, 8, 8, 1], [1, 8, 8, 1], 'SAME')
            return self._latent_head(ops.flatten(E))