From 600ee5cc1aef7b6b35a1ea702828da4870056ccf Mon Sep 17 00:00:00 2001 From: Ian Goodfellow Date: Mon, 9 Jun 2014 11:28:59 -0400 Subject: [PATCH] Copy the code and hyperparameters from galatea --- README.md | 44 +- __init__.py | 1285 ++++++++++++++++++++++++++++++ cifar10_convolutional.yaml | 174 ++++ cifar10_fully_connected.yaml | 114 +++ deconv.py | 384 +++++++++ mnist.yaml | 113 +++ parzen_ll.py | 158 ++++ sgd.py | 1137 ++++++++++++++++++++++++++ sgd_alt.py | 1151 ++++++++++++++++++++++++++ show_gen_weights.py | 59 ++ show_inpaint_samples.py | 32 + show_samples.py | 50 ++ show_samples_cifar_conv_paper.py | 44 + show_samples_cifar_full_paper.py | 41 + show_samples_inpaint.py | 57 ++ show_samples_mnist_paper.py | 39 + show_samples_tfd.py | 19 + show_samples_tfd_paper.py | 39 + test_deconv.py | 52 ++ tfd_pretrain/pretrain.yaml | 115 +++ tfd_pretrain/train.yaml | 112 +++ 21 files changed, 5216 insertions(+), 3 deletions(-) create mode 100644 __init__.py create mode 100644 cifar10_convolutional.yaml create mode 100644 cifar10_fully_connected.yaml create mode 100644 deconv.py create mode 100644 mnist.yaml create mode 100644 parzen_ll.py create mode 100644 sgd.py create mode 100644 sgd_alt.py create mode 100644 show_gen_weights.py create mode 100644 show_inpaint_samples.py create mode 100644 show_samples.py create mode 100644 show_samples_cifar_conv_paper.py create mode 100644 show_samples_cifar_full_paper.py create mode 100644 show_samples_inpaint.py create mode 100644 show_samples_mnist_paper.py create mode 100644 show_samples_tfd.py create mode 100644 show_samples_tfd_paper.py create mode 100644 test_deconv.py create mode 100644 tfd_pretrain/pretrain.yaml create mode 100644 tfd_pretrain/train.yaml diff --git a/README.md b/README.md index f307dbe..29400e6 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,42 @@ -adversarial -=========== +Generative Adversarial Networks +=============================== -Code and hyperparameters for the paper "Generative Adversarial 
Networks" +This repository contains the code and hyperparameters for the paper: + +"Generative Adversarial Networks." Ian J. Goodfellow, Jean Pouget-Abadie, +Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, +Yoshua Bengio. ArXiv 2014. + +Please cite this paper if you use the code in this repository as part of +a published research project. + +We are an academic lab, not a software company, and have no personnel +devoted to documenting and maintaining this research code. +Therefore this code is offered with absolutely no support. +Exact reproduction of the numbers in the paper depends on exact +reproduction of many factors, +including the version of all software dependencies and the choice of +underlying hardware (GPU model, etc.). We used NVIDIA GeForce GTX 580 +graphics cards; other hardware will use different tree structures for +summation and incur different rounding error. If you do not reproduce our +setup exactly you should expect to need to re-tune your hyperparameters +slightly for your new setup. + +Moreover, we have not integrated any unit tests for this code into Theano +or Pylearn2 so subsequent changes to those libraries may break the code +in this repository. If you encounter problems with this code, you should +make sure that you are using the development branch of Pylearn2 and Theano, +and use "git checkout" to go to a commit from approximately June 9, 2014. + +This code itself requires no installation besides making sure that the +"adversarial" directory is in a directory in your PYTHONPATH. If +installed correctly, 'python -c "import adversarial"' will work. You +must also install Pylearn2 and Pylearn2's dependencies (Theano, numpy, +etc.) + +parzen_ll.py is the script used to estimate the log likelihood of the +model using the Parzen density technique. + +Call pylearn2/scripts/train.py on the various yaml files in this repository +to train the model for each dataset reported in the paper. 
The names of +*.yaml are fairly self-explanatory. diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..74984a4 --- /dev/null +++ b/__init__.py @@ -0,0 +1,1285 @@ +""" +Code for "Generative Adversarial Networks". Please cite the ArXiv paper in +any published research work making use of this code. +""" +import functools +wraps = functools.wraps +import itertools +import numpy +np = numpy +import theano +import warnings + +from theano.compat import OrderedDict +from theano.sandbox.rng_mrg import MRG_RandomStreams +from theano import tensor as T + +from pylearn2.space import VectorSpace +from pylearn2.costs.cost import Cost +from pylearn2.costs.cost import DefaultDataSpecsMixin +from pylearn2.models.mlp import Layer +from pylearn2.models.mlp import Linear +from pylearn2.models import Model +from pylearn2.space import CompositeSpace +from pylearn2.train_extensions import TrainExtension +from pylearn2.utils import block_gradient +from pylearn2.utils import safe_zip +from pylearn2.utils import serial +from pylearn2.utils import sharedX + +class AdversaryPair(Model): + + def __init__(self, generator, discriminator, inferer=None, + inference_monitoring_batch_size=128, + monitor_generator=True, + monitor_discriminator=True, + monitor_inference=True, + shrink_d = 0.): + Model.__init__(self) + self.__dict__.update(locals()) + del self.self + + def __setstate__(self, state): + self.__dict__.update(state) + if 'inferer' not in state: + self.inferer = None + if 'inference_monitoring_batch_size' not in state: + self.inference_monitoring_batch_size = 128 # TODO: HACK + if 'monitor_generator' not in state: + self.monitor_generator = True + if 'monitor_discriminator' not in state: + self.monitor_discriminator = True + if 'monitor_inference' not in state: + self.monitor_inference = True + + def get_params(self): + p = self.generator.get_params() + self.discriminator.get_params() + if hasattr(self, 'inferer') and self.inferer is not None: + p += 
self.inferer.get_params() + return p + + def get_input_space(self): + return self.discriminator.get_input_space() + + def get_weights_topo(self): + return self.discriminator.get_weights_topo() + + def get_weights(self): + return self.discriminator.get_weights() + + def get_weights_format(self): + return self.discriminator.get_weights_format() + + def get_weights_view_shape(self): + return self.discriminator.get_weights_view_shape() + + def get_monitoring_channels(self, data): + rval = OrderedDict() + + g_ch = self.generator.get_monitoring_channels(data) + d_ch = self.discriminator.get_monitoring_channels((data, None)) + samples = self.generator.sample(100) + d_samp_ch = self.discriminator.get_monitoring_channels((samples, None)) + + i_ch = OrderedDict() + if self.inferer is not None: + batch_size = self.inference_monitoring_batch_size + sample, noise, _ = self.generator.sample_and_noise(batch_size) + i_ch.update(self.inferer.get_monitoring_channels((sample, noise))) + + if self.monitor_generator: + for key in g_ch: + rval['gen_' + key] = g_ch[key] + if self.monitor_discriminator: + for key in d_ch: + rval['dis_on_data_' + key] = d_samp_ch[key] + for key in d_ch: + rval['dis_on_samp_' + key] = d_ch[key] + if self.monitor_inference: + for key in i_ch: + rval['inf_' + key] = i_ch[key] + return rval + + def get_monitoring_data_specs(self): + + space = self.discriminator.get_input_space() + source = self.discriminator.get_input_source() + return (space, source) + + def _modify_updates(self, updates): + self.generator.modify_updates(updates) + self.discriminator.modify_updates(updates) + if self.shrink_d != 0.: + for param in self.discriminator.get_params(): + if param in updates: + updates[param] = self.shrink_d * updates[param] + if self.inferer is not None: + self.inferer.modify_updates(updates) + + def get_lr_scalers(self): + + rval = self.generator.get_lr_scalers() + rval.update(self.discriminator.get_lr_scalers()) + return rval + +def add_layers(mlp, pretrained, 
start_layer=0): + model = serial.load(pretrained) + pretrained_layers = model.generator.mlp.layers + assert pretrained_layers[start_layer].get_input_space() == mlp.layers[-1].get_output_space() + mlp.layers.extend(pretrained_layers[start_layer:]) + return mlp + + + +class Generator(Model): + + def __init__(self, mlp, noise = "gaussian", monitor_ll = False, ll_n_samples = 100, ll_sigma = 0.2): + Model.__init__(self) + self.__dict__.update(locals()) + del self.self + self.theano_rng = MRG_RandomStreams(2014 * 5 + 27) + + + def sample_and_noise(self, num_samples, default_input_include_prob=1., default_input_scale=1., all_g_layers=False): + n = self.mlp.get_input_space().get_total_dimension() + noise = self.get_noise((num_samples, n)) + formatted_noise = VectorSpace(n).format_as(noise, self.mlp.get_input_space()) + if all_g_layers: + rval = self.mlp.dropout_fprop(formatted_noise, default_input_include_prob=default_input_include_prob, default_input_scale=default_input_scale, return_all=all_g_layers) + other_layers, rval = rval[:-1], rval[-1] + else: + rval = self.mlp.dropout_fprop(formatted_noise, default_input_include_prob=default_input_include_prob, default_input_scale=default_input_scale) + other_layers = None + return rval, formatted_noise, other_layers + + def sample(self, num_samples, default_input_include_prob=1., default_input_scale=1.): + sample, _, _ = self.sample_and_noise(num_samples, default_input_include_prob, default_input_scale) + return sample + + def inpainting_sample_and_noise(self, X, default_input_include_prob=1., default_input_scale=1.): + # Very hacky! 
Specifically for inpainting right half of CIFAR-10 given left half + # assumes X is b01c + assert X.ndim == 4 + input_space = self.mlp.get_input_space() + n = input_space.get_total_dimension() + image_size = input_space.shape[0] + half_image = int(image_size / 2) + data_shape = (X.shape[0], image_size, half_image, input_space.num_channels) + + noise = self.theano_rng.normal(size=data_shape, dtype='float32') + Xg = T.set_subtensor(X[:,:,half_image:,:], noise) + sampled_part, noise = self.mlp.dropout_fprop(Xg, default_input_include_prob=default_input_include_prob, default_input_scale=default_input_scale), noise + sampled_part = sampled_part.reshape(data_shape) + rval = T.set_subtensor(X[:, :, half_image:, :], sampled_part) + return rval, noise + + + def get_monitoring_channels(self, data): + if data is None: + m = 100 + else: + m = data.shape[0] + n = self.mlp.get_input_space().get_total_dimension() + noise = self.get_noise((m, n)) + rval = OrderedDict() + + try: + rval.update(self.mlp.get_monitoring_channels((noise, None))) + except Exception: + warnings.warn("something went wrong with generator.mlp's monitoring channels") + + if self.monitor_ll: + rval['ll'] = T.cast(self.ll(data, self.ll_n_samples, self.ll_sigma), + theano.config.floatX).mean() + rval['nll'] = -rval['ll'] + return rval + + def get_noise(self, size): + + if not hasattr(self, 'noise'): + self.noise = "gaussian" + if self.noise == "uniform": + return self.theano_rng.uniform(low=-np.sqrt(3), high=np.sqrt(3), size=size, dtype='float32') + elif self.noise == "gaussian": + return self.theano_rng.normal(size=size, dtype='float32') + elif self.noise == "spherical": + noise = self.theano_rng.normal(size=size, dtype='float32') + noise = noise / T.maximum(1e-7, T.sqrt(T.sqr(noise).sum(axis=1))).dimshuffle(0, 'x') + return noise + else: + raise NotImplementedError("noise should be gaussian or uniform") + + + def get_params(self): + return self.mlp.get_params() + + def get_output_space(self): + return 
self.mlp.get_output_space() + + def ll(self, data, n_samples, sigma): + + samples = self.sample(n_samples) + output_space = self.mlp.get_output_space() + if 'Conv2D' in str(output_space): + samples = output_space.convert(samples, output_space.axes, ('b', 0, 1, 'c')) + samples = samples.flatten(2) + data = output_space.convert(data, output_space.axes, ('b', 0, 1, 'c')) + data = data.flatten(2) + parzen = theano_parzen(data, samples, sigma) + return parzen + + def _modify_updates(self, updates): + self.mlp.modify_updates(updates) + + def get_lr_scalers(self): + return self.mlp.get_lr_scalers() + + def __setstate__(self, state): + self.__dict__.update(state) + if 'monitor_ll' not in state: + self.monitor_ll = False + + +class IntrinsicDropoutGenerator(Generator): + def __init__(self, default_input_include_prob, default_input_scale, + input_include_probs=None, input_scales=None, **kwargs): + super(IntrinsicDropoutGenerator, self).__init__(**kwargs) + self.__dict__.update(locals()) + del self.self + + def sample_and_noise(self, num_samples, default_input_include_prob=1., default_input_scale=1., all_g_layers=False): + if all_g_layers: + raise NotImplementedError() + n = self.mlp.get_input_space().get_total_dimension() + noise = self.theano_rng.normal(size=(num_samples, n), dtype='float32') + formatted_noise = VectorSpace(n).format_as(noise, self.mlp.get_input_space()) + # ignores dropout args + default_input_include_prob = self.default_input_include_prob + default_input_scale = self.default_input_scale + input_include_probs = self.input_include_probs + input_scales = self.input_scales + return self.mlp.dropout_fprop(formatted_noise, + default_input_include_prob=default_input_include_prob, + default_input_scale=default_input_scale, + input_include_probs=input_include_probs, + input_scales=input_scales), formatted_noise, None + +class AdversaryCost2(DefaultDataSpecsMixin, Cost): + """ + """ + + # Supplies own labels, don't get them from the dataset + supervised = False + + 
def __init__(self, scale_grads=1, target_scale=.1, + discriminator_default_input_include_prob = 1., + discriminator_input_include_probs=None, + discriminator_default_input_scale=1., + discriminator_input_scales=None, + generator_default_input_include_prob = 1., + generator_default_input_scale=1., + inference_default_input_include_prob=None, + inference_input_include_probs=None, + inference_default_input_scale=1., + inference_input_scales=None, + init_now_train_generator=True, + ever_train_discriminator=True, + ever_train_generator=True, + ever_train_inference=True, + no_drop_in_d_for_g=False, + alternate_g = False, + infer_layer=None, + noise_both = 0., + blend_obj = False, + minimax_coeff = 1., + zurich_coeff = 1.): + self.__dict__.update(locals()) + del self.self + # These allow you to dynamically switch off training parts. + # If the corresponding ever_train_* is False, these have + # no effect. + self.now_train_generator = sharedX(init_now_train_generator) + self.now_train_discriminator = sharedX(numpy.array(1., dtype='float32')) + self.now_train_inference = sharedX(numpy.array(1., dtype='float32')) + + def expr(self, model, data, **kwargs): + S, d_obj, g_obj, i_obj = self.get_samples_and_objectives(model, data) + l = [] + # This stops stuff from ever getting computed if we're not training + # it. + if self.ever_train_discriminator: + l.append(d_obj) + if self.ever_train_generator: + l.append(g_obj) + if self.ever_train_inference: + l.append(i_obj) + return sum(l) + + def get_samples_and_objectives(self, model, data): + space, sources = self.get_data_specs(model) + space.validate(data) + assert isinstance(model, AdversaryPair) + g = model.generator + d = model.discriminator + + # Note: this assumes data is design matrix + X = data + m = data.shape[space.get_batch_axis()] + y1 = T.alloc(1, m, 1) + y0 = T.alloc(0, m, 1) + # NOTE: if this changes to optionally use dropout, change the inference + # code below to use a non-dropped-out version. 
+ S, z, other_layers = g.sample_and_noise(m, default_input_include_prob=self.generator_default_input_include_prob, default_input_scale=self.generator_default_input_scale, all_g_layers=(self.infer_layer is not None)) + + if self.noise_both != 0.: + rng = MRG_RandomStreams(2014 / 6 + 2) + S = S + rng.normal(size=S.shape, dtype=S.dtype) * self.noise_both + X = X + rng.normal(size=X.shape, dtype=S.dtype) * self.noise_both + + y_hat1 = d.dropout_fprop(X, self.discriminator_default_input_include_prob, + self.discriminator_input_include_probs, + self.discriminator_default_input_scale, + self.discriminator_input_scales) + y_hat0 = d.dropout_fprop(S, self.discriminator_default_input_include_prob, + self.discriminator_input_include_probs, + self.discriminator_default_input_scale, + self.discriminator_input_scales) + + d_obj = 0.5 * (d.layers[-1].cost(y1, y_hat1) + d.layers[-1].cost(y0, y_hat0)) + + if self.no_drop_in_d_for_g: + y_hat0_no_drop = d.dropout_fprop(S) + g_obj = d.layers[-1].cost(y1, y_hat0_no_drop) + else: + g_obj = d.layers[-1].cost(y1, y_hat0) + + if self.blend_obj: + g_obj = (self.zurich_coeff * g_obj - self.minimax_coeff * d_obj) / (self.zurich_coeff + self.minimax_coeff) + + if model.inferer is not None: + # Change this if we ever switch to using dropout in the + # construction of S. 
+ S_nograd = block_gradient(S) # Redundant as long as we have custom get_gradients + pred = model.inferer.dropout_fprop(S_nograd, self.inference_default_input_include_prob, + self.inference_input_include_probs, + self.inference_default_input_scale, + self.inference_input_scales) + if self.infer_layer is None: + target = z + else: + target = other_layers[self.infer_layer] + i_obj = model.inferer.layers[-1].cost(target, pred) + else: + i_obj = 0 + + return S, d_obj, g_obj, i_obj + + def get_gradients(self, model, data, **kwargs): + space, sources = self.get_data_specs(model) + space.validate(data) + assert isinstance(model, AdversaryPair) + g = model.generator + d = model.discriminator + + S, d_obj, g_obj, i_obj = self.get_samples_and_objectives(model, data) + + g_params = g.get_params() + d_params = d.get_params() + for param in g_params: + assert param not in d_params + for param in d_params: + assert param not in g_params + d_grads = T.grad(d_obj, d_params) + g_grads = T.grad(g_obj, g_params) + + if self.scale_grads: + S_grad = T.grad(g_obj, S) + scale = T.maximum(1., self.target_scale / T.sqrt(T.sqr(S_grad).sum())) + g_grads = [g_grad * scale for g_grad in g_grads] + + rval = OrderedDict() + zeros = itertools.repeat(theano.tensor.constant(0., dtype='float32')) + if self.ever_train_discriminator: + rval.update(OrderedDict(safe_zip(d_params, [self.now_train_discriminator * dg for dg in d_grads]))) + else: + rval.update(OrderedDict(zip(d_params, zeros))) + if self.ever_train_generator: + rval.update(OrderedDict(safe_zip(g_params, [self.now_train_generator * gg for gg in g_grads]))) + else: + rval.update(OrderedDict(zip(g_params, zeros))) + if self.ever_train_inference and model.inferer is not None: + i_params = model.inferer.get_params() + i_grads = T.grad(i_obj, i_params) + rval.update(OrderedDict(safe_zip(i_params, [self.now_train_inference * ig for ig in i_grads]))) + elif model.inferer is not None: + rval.update(OrderedDict(model.inferer.get_params(), zeros)) + 
+ updates = OrderedDict() + + # Two d steps for every g step + if self.alternate_g: + updates[self.now_train_generator] = 1. - self.now_train_generator + + return rval, updates + + def get_monitoring_channels(self, model, data, **kwargs): + + rval = OrderedDict() + + m = data.shape[0] + + g = model.generator + d = model.discriminator + + y_hat = d.fprop(data) + + rval['false_negatives'] = T.cast((y_hat < 0.5).mean(), 'float32') + + samples = g.sample(m) + y_hat = d.fprop(samples) + rval['false_positives'] = T.cast((y_hat > 0.5).mean(), 'float32') + # y = T.alloc(0., m, 1) + cost = d.cost_from_X((samples, y_hat)) + sample_grad = T.grad(-cost, samples) + rval['sample_grad_norm'] = T.sqrt(T.sqr(sample_grad).sum()) + _S, d_obj, g_obj, i_obj = self.get_samples_and_objectives(model, data) + if model.monitor_inference and i_obj != 0: + rval['objective_i'] = i_obj + if model.monitor_discriminator: + rval['objective_d'] = d_obj + if model.monitor_generator: + rval['objective_g'] = g_obj + + rval['now_train_generator'] = self.now_train_generator + return rval + +def recapitate_discriminator(pair_path, new_head): + pair = serial.load(pair_path) + d = pair.discriminator + del d.layers[-1] + d.add_layers([new_head]) + return d + +def theano_parzen(data, mu, sigma): + """ + Credit: Yann N. Dauphin + """ + x = data + + a = ( x.dimshuffle(0, 'x', 1) - mu.dimshuffle('x', 0, 1) ) / sigma + + E = log_mean_exp(-0.5*(a**2).sum(2)) + + Z = mu.shape[1] * T.log(sigma * numpy.sqrt(numpy.pi * 2)) + + #return theano.function([x], E - Z) + return E - Z + + +def log_mean_exp(a): + """ + Credit: Yann N. 
Dauphin + """ + + max_ = a.max(1) + + return max_ + T.log(T.exp(a - max_.dimshuffle(0, 'x')).mean(1)) + +class Sum(Layer): + """ + Monitoring channels are hardcoded for C01B batches + """ + + def __init__(self, layer_name): + Model.__init__(self) + self.__dict__.update(locals()) + del self.self + self._params = [] + + def set_input_space(self, space): + self.input_space = space + assert isinstance(space, CompositeSpace) + self.output_space = space.components[0] + + def fprop(self, state_below): + rval = state_below[0] + for i in xrange(1, len(state_below)): + rval = rval + state_below[i] + rval.came_from_sum = True + return rval + + @functools.wraps(Layer.get_layer_monitoring_channels) + def get_layer_monitoring_channels(self, state_below=None, + state=None, targets=None): + rval = OrderedDict() + + if state is None: + state = self.fprop(state_below) + vars_and_prefixes = [(state, '')] + + for var, prefix in vars_and_prefixes: + if not hasattr(var, 'ndim') or var.ndim != 4: + print "expected 4D tensor, got " + print var + print type(var) + if isinstance(var, tuple): + print "tuple length: ", len(var) + assert False + v_max = var.max(axis=(1, 2, 3)) + v_min = var.min(axis=(1, 2, 3)) + v_mean = var.mean(axis=(1, 2, 3)) + v_range = v_max - v_min + + # max_x.mean_u is "the mean over *u*nits of the max over + # e*x*amples" The x and u are included in the name because + # otherwise its hard to remember which axis is which when reading + # the monitor I use inner.outer rather than outer_of_inner or + # something like that because I want mean_x.* to appear next to + # each other in the alphabetical list, as these are commonly + # plotted together + for key, val in [('max_x.max_u', v_max.max()), + ('max_x.mean_u', v_max.mean()), + ('max_x.min_u', v_max.min()), + ('min_x.max_u', v_min.max()), + ('min_x.mean_u', v_min.mean()), + ('min_x.min_u', v_min.min()), + ('range_x.max_u', v_range.max()), + ('range_x.mean_u', v_range.mean()), + ('range_x.min_u', v_range.min()), + 
('mean_x.max_u', v_mean.max()), + ('mean_x.mean_u', v_mean.mean()), + ('mean_x.min_u', v_mean.min())]: + rval[prefix+key] = val + + return rval + +def marginals(dataset): + return dataset.X.mean(axis=0) + +class ActivateGenerator(TrainExtension): + def __init__(self, active_after, value=1.): + self.__dict__.update(locals()) + del self.self + self.cur_epoch = 0 + + def on_monitor(self, model, dataset, algorithm): + if self.cur_epoch == self.active_after: + algorithm.cost.now_train_generator.set_value(np.array(self.value, dtype='float32')) + self.cur_epoch += 1 + +class InpaintingAdversaryCost(DefaultDataSpecsMixin, Cost): + """ + """ + + # Supplies own labels, don't get them from the dataset + supervised = False + + def __init__(self, scale_grads=1, target_scale=.1, + discriminator_default_input_include_prob = 1., + discriminator_input_include_probs=None, + discriminator_default_input_scale=1., + discriminator_input_scales=None, + generator_default_input_include_prob = 1., + generator_default_input_scale=1., + inference_default_input_include_prob=None, + inference_input_include_probs=None, + inference_default_input_scale=1., + inference_input_scales=None, + init_now_train_generator=True, + ever_train_discriminator=True, + ever_train_generator=True, + ever_train_inference=True, + no_drop_in_d_for_g=False, + alternate_g = False): + self.__dict__.update(locals()) + del self.self + # These allow you to dynamically switch off training parts. + # If the corresponding ever_train_* is False, these have + # no effect. 
+ self.now_train_generator = sharedX(init_now_train_generator) + self.now_train_discriminator = sharedX(numpy.array(1., dtype='float32')) + self.now_train_inference = sharedX(numpy.array(1., dtype='float32')) + + def expr(self, model, data, **kwargs): + S, d_obj, g_obj, i_obj = self.get_samples_and_objectives(model, data) + return d_obj + g_obj + i_obj + + def get_samples_and_objectives(self, model, data): + space, sources = self.get_data_specs(model) + space.validate(data) + assert isinstance(model, AdversaryPair) + g = model.generator + d = model.discriminator + + # Note: this assumes data is b01c + X = data + assert X.ndim == 4 + m = data.shape[space.get_batch_axis()] + y1 = T.alloc(1, m, 1) + y0 = T.alloc(0, m, 1) + # NOTE: if this changes to optionally use dropout, change the inference + # code below to use a non-dropped-out version. + S, z = g.inpainting_sample_and_noise(X, default_input_include_prob=self.generator_default_input_include_prob, default_input_scale=self.generator_default_input_scale) + y_hat1 = d.dropout_fprop(X, self.discriminator_default_input_include_prob, + self.discriminator_input_include_probs, + self.discriminator_default_input_scale, + self.discriminator_input_scales) + y_hat0 = d.dropout_fprop(S, self.discriminator_default_input_include_prob, + self.discriminator_input_include_probs, + self.discriminator_default_input_scale, + self.discriminator_input_scales) + + d_obj = 0.5 * (d.layers[-1].cost(y1, y_hat1) + d.layers[-1].cost(y0, y_hat0)) + + if self.no_drop_in_d_for_g: + y_hat0_no_drop = d.dropout_fprop(S) + g_obj = d.layers[-1].cost(y1, y_hat0) + else: + g_obj = d.layers[-1].cost(y1, y_hat0) + + if model.inferer is not None: + # Change this if we ever switch to using dropout in the + # construction of S. 
+ S_nograd = block_gradient(S) # Redundant as long as we have custom get_gradients + z_hat = model.inferer.dropout_fprop(S_nograd, self.inference_default_input_include_prob, + self.inference_input_include_probs, + self.inference_default_input_scale, + self.inference_input_scales) + i_obj = model.inferer.layers[-1].cost(z, z_hat) + else: + i_obj = 0 + + return S, d_obj, g_obj, i_obj + + def get_gradients(self, model, data, **kwargs): + space, sources = self.get_data_specs(model) + space.validate(data) + assert isinstance(model, AdversaryPair) + g = model.generator + d = model.discriminator + + S, d_obj, g_obj, i_obj = self.get_samples_and_objectives(model, data) + + g_params = g.get_params() + d_params = d.get_params() + for param in g_params: + assert param not in d_params + for param in d_params: + assert param not in g_params + d_grads = T.grad(d_obj, d_params) + g_grads = T.grad(g_obj, g_params) + + if self.scale_grads: + S_grad = T.grad(g_obj, S) + scale = T.maximum(1., self.target_scale / T.sqrt(T.sqr(S_grad).sum())) + g_grads = [g_grad * scale for g_grad in g_grads] + + rval = OrderedDict() + if self.ever_train_discriminator: + rval.update(OrderedDict(safe_zip(d_params, [self.now_train_discriminator * dg for dg in d_grads]))) + else: + rval.update(OrderedDict(zip(d_params, itertools.repeat(theano.tensor.constant(0., dtype='float32'))))) + + if self.ever_train_generator: + rval.update(OrderedDict(safe_zip(g_params, [self.now_train_generator * gg for gg in g_grads]))) + else: + rval.update(OrderedDict(zip(g_params, itertools.repeat(theano.tensor.constant(0., dtype='float32'))))) + + if self.ever_train_inference and model.inferer is not None: + i_params = model.inferer.get_params() + i_grads = T.grad(i_obj, i_params) + rval.update(OrderedDict(safe_zip(i_params, [self.now_train_inference * ig for ig in i_grads]))) + + updates = OrderedDict() + + # Two d steps for every g step + if self.alternate_g: + updates[self.now_train_generator] = 1. 
- self.now_train_generator + + return rval, updates + + def get_monitoring_channels(self, model, data, **kwargs): + + rval = OrderedDict() + + m = data.shape[0] + + g = model.generator + d = model.discriminator + + y_hat = d.fprop(data) + + rval['false_negatives'] = T.cast((y_hat < 0.5).mean(), 'float32') + + samples, noise = g.inpainting_sample_and_noise(data) + y_hat = d.fprop(samples) + rval['false_positives'] = T.cast((y_hat > 0.5).mean(), 'float32') + # y = T.alloc(0., m, 1) + cost = d.cost_from_X((samples, y_hat)) + sample_grad = T.grad(-cost, samples) + rval['sample_grad_norm'] = T.sqrt(T.sqr(sample_grad).sum()) + _S, d_obj, g_obj, i_obj = self.get_samples_and_objectives(model, data) + if i_obj != 0: + rval['objective_i'] = i_obj + rval['objective_d'] = d_obj + rval['objective_g'] = g_obj + + rval['now_train_generator'] = self.now_train_generator + return rval + +class Cycler(object): + + def __init__(self, k): + self.__dict__.update(locals()) + del self.self + self.i = 0 + + def __call__(self, sgd): + self.i = (self.i + 1) % self.k + sgd.cost.now_train_generator.set_value(np.cast['float32'](self.i == 0)) + +class NoiseCat(Layer): + + def __init__(self, new_dim, std, layer_name): + Layer.__init__(self) + self.__dict__.update(locals()) + del self.self + self._params = [] + + def set_input_space(self, space): + assert isinstance(space, VectorSpace) + self.input_space = space + self.output_space = VectorSpace(space.dim + self.new_dim) + self.theano_rng = MRG_RandomStreams(self.mlp.rng.randint(2 ** 16)) + + def fprop(self, state): + noise = self.theano_rng.normal(std=self.std, avg=0., size=(state.shape[0], self.new_dim), + dtype=state.dtype) + return T.concatenate((state, noise), axis=1) + +class RectifiedLinear(Layer): + + def __init__(self, layer_name, left_slope=0.0, **kwargs): + super(RectifiedLinear, self).__init__(**kwargs) + self.__dict__.update(locals()) + del self.self + self._params = [] + + def set_input_space(self, space): + self.input_space = space 
+ self.output_space = space + + def fprop(self, state_below): + p = state_below + p = T.switch(p > 0., p, self.left_slope * p) + return p + +class Sigmoid(Layer): + + def __init__(self, layer_name, left_slope=0.0, **kwargs): + super(Sigmoid, self).__init__(**kwargs) + self.__dict__.update(locals()) + del self.self + self._params = [] + + def set_input_space(self, space): + self.input_space = space + self.output_space = space + + def fprop(self, state_below): + p = T.nnet.sigmoid(state_below) + return p + +class SubtractHalf(Layer): + + def __init__(self, layer_name, left_slope=0.0, **kwargs): + super(SubtractHalf, self).__init__(**kwargs) + self.__dict__.update(locals()) + del self.self + self._params = [] + + def set_input_space(self, space): + self.input_space = space + self.output_space = space + + def fprop(self, state_below): + return state_below - 0.5 + + def get_weights(self): + return self.mlp.layers[1].get_weights() + + def get_weights_format(self): + return self.mlp.layers[1].get_weights_format() + + def get_weights_view_shape(self): + return self.mlp.layers[1].get_weights_view_shape() + +class SubtractRealMean(Layer): + + def __init__(self, layer_name, dataset, also_sd = False, **kwargs): + super(SubtractRealMean, self).__init__(**kwargs) + self.__dict__.update(locals()) + del self.self + self._params = [] + self.mean = sharedX(dataset.X.mean(axis=0)) + if also_sd: + self.sd = sharedX(dataset.X.std(axis=0)) + del self.dataset + + def set_input_space(self, space): + self.input_space = space + self.output_space = space + + def fprop(self, state_below): + return (state_below - self.mean) / self.sd + + def get_weights(self): + return self.mlp.layers[1].get_weights() + + def get_weights_format(self): + return self.mlp.layers[1].get_weights_format() + + def get_weights_view_shape(self): + return self.mlp.layers[1].get_weights_view_shape() + + +class Clusterize(Layer): + + def __init__(self, scale, layer_name): + Layer.__init__(self) + 
class ThresholdedAdversaryCost(DefaultDataSpecsMixin, Cost):
    """
    Cost for training an `AdversaryPair` in which the generator objective is
    thresholded: only generated samples that the discriminator scores below
    0.5 contribute to the generator's cost.

    The discriminator objective is the average of its output layer's cost on
    the data batch (target 1) and on generated samples (target 0). If the
    model has an inferer, an additional objective trains it to predict the
    generator's noise (or the generator layer selected by `infer_layer`)
    from the samples.

    Every constructor argument is stored as an attribute of the same name.
    The `ever_train_*` flags decide at graph-construction time whether a
    part is trained at all; the `now_train_*` shared variables multiply the
    corresponding gradients so training can be toggled while running.
    """

    # Supplies own labels, don't get them from the dataset
    supervised = False

    def __init__(self, scale_grads=1, target_scale=.1,
            discriminator_default_input_include_prob = 1.,
            discriminator_input_include_probs=None,
            discriminator_default_input_scale=1.,
            discriminator_input_scales=None,
            generator_default_input_include_prob = 1.,
            generator_default_input_scale=1.,
            inference_default_input_include_prob=None,
            inference_input_include_probs=None,
            inference_default_input_scale=1.,
            inference_input_scales=None,
            init_now_train_generator=True,
            ever_train_discriminator=True,
            ever_train_generator=True,
            ever_train_inference=True,
            no_drop_in_d_for_g=False,
            alternate_g = False,
            infer_layer=None,
            noise_both = 0.):
        # Must run before any other local is created: it copies the
        # arguments (and only the arguments) onto the instance.
        self.__dict__.update(locals())
        del self.self
        # These allow you to dynamically switch off training parts.
        # If the corresponding ever_train_* is False, these have
        # no effect.
        self.now_train_generator = sharedX(init_now_train_generator)
        self.now_train_discriminator = sharedX(numpy.array(1., dtype='float32'))
        self.now_train_inference = sharedX(numpy.array(1., dtype='float32'))

    def expr(self, model, data, **kwargs):
        """
        Return the sum of the objectives of every part that is ever trained.
        """
        S, d_obj, g_obj, i_obj = self.get_samples_and_objectives(model, data)
        l = []
        # This stops stuff from ever getting computed if we're not training
        # it.
        if self.ever_train_discriminator:
            l.append(d_obj)
        if self.ever_train_generator:
            l.append(g_obj)
        if self.ever_train_inference:
            l.append(i_obj)
        return sum(l)

    def get_samples_and_objectives(self, model, data):
        """
        Build the symbolic samples and the three objectives.

        Parameters
        ----------
        model : AdversaryPair
        data : theano variable
            A batch validated against this cost's data specs.

        Returns
        -------
        S : generated samples (with noise added when `noise_both` != 0)
        d_obj : discriminator objective
        g_obj : thresholded generator objective
        i_obj : inferer objective, or the integer 0 if the model has no
            inferer
        """
        space, sources = self.get_data_specs(model)
        space.validate(data)
        assert isinstance(model, AdversaryPair)
        g = model.generator
        d = model.discriminator

        # Note: this assumes data is design matrix
        X = data
        m = data.shape[space.get_batch_axis()]
        # Targets: 1 for real data, 0 for generated samples.
        y1 = T.alloc(1, m, 1)
        y0 = T.alloc(0, m, 1)
        # NOTE: if this changes to optionally use dropout, change the inference
        # code below to use a non-dropped-out version.
        S, z, other_layers = g.sample_and_noise(
            m,
            default_input_include_prob=self.generator_default_input_include_prob,
            default_input_scale=self.generator_default_input_scale,
            all_g_layers=(self.infer_layer is not None))

        if self.noise_both != 0.:
            # Floor division keeps the seed an integer (337) regardless of
            # whether true division is in effect.
            rng = MRG_RandomStreams(2014 // 6 + 2)
            S = S + rng.normal(size=S.shape, dtype=S.dtype) * self.noise_both
            X = X + rng.normal(size=X.shape, dtype=S.dtype) * self.noise_both

        y_hat1 = d.dropout_fprop(X, self.discriminator_default_input_include_prob,
                                 self.discriminator_input_include_probs,
                                 self.discriminator_default_input_scale,
                                 self.discriminator_input_scales)
        y_hat0 = d.dropout_fprop(S, self.discriminator_default_input_include_prob,
                                 self.discriminator_input_include_probs,
                                 self.discriminator_default_input_scale,
                                 self.discriminator_input_scales)

        d_obj = 0.5 * (d.layers[-1].cost(y1, y_hat1) + d.layers[-1].cost(y0, y_hat0))

        # The generator is scored against target 1, i.e. it is rewarded for
        # samples the discriminator labels as real.
        if self.no_drop_in_d_for_g:
            y_hat0_no_drop = d.dropout_fprop(S)
            g_cost_mat = d.layers[-1].cost_matrix(y1, y_hat0_no_drop)
        else:
            g_cost_mat = d.layers[-1].cost_matrix(y1, y_hat0)
        assert g_cost_mat.ndim == 2
        assert y_hat0.ndim == 2

        # Thresholding: samples already scored >= 0.5 contribute zero cost.
        # The mask always uses the dropout prediction y_hat0.
        mask = y_hat0 < 0.5
        masked_cost = g_cost_mat * mask
        g_obj = masked_cost.mean()

        if model.inferer is not None:
            # Change this if we ever switch to using dropout in the
            # construction of S.
            S_nograd = block_gradient(S)  # Redundant as long as we have custom get_gradients
            pred = model.inferer.dropout_fprop(S_nograd, self.inference_default_input_include_prob,
                                               self.inference_input_include_probs,
                                               self.inference_default_input_scale,
                                               self.inference_input_scales)
            if self.infer_layer is None:
                target = z
            else:
                target = other_layers[self.infer_layer]
            i_obj = model.inferer.layers[-1].cost(target, pred)
        else:
            i_obj = 0

        return S, d_obj, g_obj, i_obj

    def get_gradients(self, model, data, **kwargs):
        """
        Return per-parameter gradients and extra shared-variable updates.

        Each sub-network only receives the gradient of its own objective.
        Parts whose `ever_train_*` flag is False get a constant zero
        gradient; the others get their gradient multiplied by the matching
        `now_train_*` shared variable.

        Returns
        -------
        rval : OrderedDict
            Maps each parameter to its gradient expression.
        updates : OrderedDict
            Contains the toggle of `now_train_generator` when `alternate_g`
            is set, otherwise empty.
        """
        space, sources = self.get_data_specs(model)
        space.validate(data)
        assert isinstance(model, AdversaryPair)
        g = model.generator
        d = model.discriminator

        S, d_obj, g_obj, i_obj = self.get_samples_and_objectives(model, data)

        g_params = g.get_params()
        d_params = d.get_params()
        # The two players must not share parameters.
        for param in g_params:
            assert param not in d_params
        for param in d_params:
            assert param not in g_params
        d_grads = T.grad(d_obj, d_params)
        g_grads = T.grad(g_obj, g_params)

        if self.scale_grads:
            # Rescale generator gradients by max(1, target_scale / ||dg_obj/dS||).
            S_grad = T.grad(g_obj, S)
            scale = T.maximum(1., self.target_scale / T.sqrt(T.sqr(S_grad).sum()))
            g_grads = [g_grad * scale for g_grad in g_grads]

        rval = OrderedDict()
        # Endless supply of zero constants; zip() stops at the parameter list.
        zeros = itertools.repeat(theano.tensor.constant(0., dtype='float32'))
        if self.ever_train_discriminator:
            rval.update(OrderedDict(safe_zip(d_params, [self.now_train_discriminator * dg for dg in d_grads])))
        else:
            rval.update(OrderedDict(zip(d_params, zeros)))
        if self.ever_train_generator:
            rval.update(OrderedDict(safe_zip(g_params, [self.now_train_generator * gg for gg in g_grads])))
        else:
            rval.update(OrderedDict(zip(g_params, zeros)))
        if self.ever_train_inference and model.inferer is not None:
            i_params = model.inferer.get_params()
            i_grads = T.grad(i_obj, i_params)
            rval.update(OrderedDict(safe_zip(i_params, [self.now_train_inference * ig for ig in i_grads])))
        elif model.inferer is not None:
            # Pair the parameters with zeros via zip(), like the branches
            # above; passing the two sequences straight to OrderedDict
            # raises TypeError.
            rval.update(OrderedDict(zip(model.inferer.get_params(), zeros)))

        updates = OrderedDict()

        # Two d steps for every g step
        if self.alternate_g:
            updates[self.now_train_generator] = 1. - self.now_train_generator

        return rval, updates

    def get_monitoring_channels(self, model, data, **kwargs):
        """
        Return monitoring channels: discriminator error rates on data and on
        fresh samples, the norm of the gradient of the discriminator cost
        with respect to the samples, the objectives the model asks to
        monitor, and the current value of `now_train_generator`.
        """
        rval = OrderedDict()

        m = data.shape[0]

        g = model.generator
        d = model.discriminator

        y_hat = d.fprop(data)

        # Real examples scored below 0.5.
        rval['false_negatives'] = T.cast((y_hat < 0.5).mean(), 'float32')

        samples = g.sample(m)
        y_hat = d.fprop(samples)
        # Generated examples scored above 0.5.
        rval['false_positives'] = T.cast((y_hat > 0.5).mean(), 'float32')
        # y = T.alloc(0., m, 1)
        cost = d.cost_from_X((samples, y_hat))
        sample_grad = T.grad(-cost, samples)
        rval['sample_grad_norm'] = T.sqrt(T.sqr(sample_grad).sum())
        _S, d_obj, g_obj, i_obj = self.get_samples_and_objectives(model, data)
        if model.monitor_inference and i_obj != 0:
            rval['objective_i'] = i_obj
        if model.monitor_discriminator:
            rval['objective_d'] = d_obj
        if model.monitor_generator:
            rval['objective_g'] = g_obj

        rval['now_train_generator'] = self.now_train_generator
        return rval
class LazyAdversaryCost(DefaultDataSpecsMixin, Cost):
    """
    Cost for training an `AdversaryPair` in which both players are "lazy":
    an example only contributes to a player's objective while the
    discriminator's prediction for it is on the wrong side of a margin.

    * Discriminator: real examples count while their score is below
      `0.5 + d_eps`; generated samples count while their score is above
      `0.5 - d_eps`.
    * Generator: generated samples count while their score is below
      `0.5 + g_eps`.

    If the model has an inferer, an additional objective trains it to
    predict the generator's noise (or the generator layer selected by
    `infer_layer`) from the samples.

    Every constructor argument is stored as an attribute of the same name.
    The `ever_train_*` flags decide at graph-construction time whether a
    part is trained at all; the `now_train_*` shared variables multiply the
    corresponding gradients so training can be toggled while running.
    """

    # Supplies own labels, don't get them from the dataset
    supervised = False

    def __init__(self, scale_grads=1, target_scale=.1,
            discriminator_default_input_include_prob = 1.,
            discriminator_input_include_probs=None,
            discriminator_default_input_scale=1.,
            discriminator_input_scales=None,
            generator_default_input_include_prob = 1.,
            generator_default_input_scale=1.,
            inference_default_input_include_prob=None,
            inference_input_include_probs=None,
            inference_default_input_scale=1.,
            inference_input_scales=None,
            init_now_train_generator=True,
            ever_train_discriminator=True,
            ever_train_generator=True,
            ever_train_inference=True,
            no_drop_in_d_for_g=False,
            alternate_g = False,
            infer_layer=None,
            noise_both = 0.,
            g_eps = 0.,
            d_eps =0.):
        # Must run before any other local is created: it copies the
        # arguments (and only the arguments) onto the instance.
        self.__dict__.update(locals())
        del self.self
        # These allow you to dynamically switch off training parts.
        # If the corresponding ever_train_* is False, these have
        # no effect.
        self.now_train_generator = sharedX(init_now_train_generator)
        self.now_train_discriminator = sharedX(numpy.array(1., dtype='float32'))
        self.now_train_inference = sharedX(numpy.array(1., dtype='float32'))

    def expr(self, model, data, **kwargs):
        """
        Return the sum of the objectives of every part that is ever trained.
        """
        S, d_obj, g_obj, i_obj = self.get_samples_and_objectives(model, data)
        l = []
        # This stops stuff from ever getting computed if we're not training
        # it.
        if self.ever_train_discriminator:
            l.append(d_obj)
        if self.ever_train_generator:
            l.append(g_obj)
        if self.ever_train_inference:
            l.append(i_obj)
        return sum(l)

    def get_samples_and_objectives(self, model, data):
        """
        Build the symbolic samples and the three objectives.

        Parameters
        ----------
        model : AdversaryPair
        data : theano variable
            A batch validated against this cost's data specs.

        Returns
        -------
        S : generated samples (with noise added when `noise_both` != 0)
        d_obj : margin-masked discriminator objective
        g_obj : margin-masked generator objective
        i_obj : inferer objective, or the integer 0 if the model has no
            inferer
        """
        space, sources = self.get_data_specs(model)
        space.validate(data)
        assert isinstance(model, AdversaryPair)
        g = model.generator
        d = model.discriminator

        # Note: this assumes data is design matrix
        X = data
        m = data.shape[space.get_batch_axis()]
        # Targets: 1 for real data, 0 for generated samples.
        y1 = T.alloc(1, m, 1)
        y0 = T.alloc(0, m, 1)
        # NOTE: if this changes to optionally use dropout, change the inference
        # code below to use a non-dropped-out version.
        S, z, other_layers = g.sample_and_noise(
            m,
            default_input_include_prob=self.generator_default_input_include_prob,
            default_input_scale=self.generator_default_input_scale,
            all_g_layers=(self.infer_layer is not None))

        if self.noise_both != 0.:
            # Floor division keeps the seed an integer (337) regardless of
            # whether true division is in effect.
            rng = MRG_RandomStreams(2014 // 6 + 2)
            S = S + rng.normal(size=S.shape, dtype=S.dtype) * self.noise_both
            X = X + rng.normal(size=X.shape, dtype=S.dtype) * self.noise_both

        y_hat1 = d.dropout_fprop(X, self.discriminator_default_input_include_prob,
                                 self.discriminator_input_include_probs,
                                 self.discriminator_default_input_scale,
                                 self.discriminator_input_scales)
        y_hat0 = d.dropout_fprop(S, self.discriminator_default_input_include_prob,
                                 self.discriminator_input_include_probs,
                                 self.discriminator_default_input_scale,
                                 self.discriminator_input_scales)

        # Only examples the discriminator has not yet classified with a
        # margin of d_eps contribute to its objective.
        pos_mask = y_hat1 < .5 + self.d_eps
        neg_mask = y_hat0 > .5 - self.d_eps

        pos_cost_matrix = d.layers[-1].cost_matrix(y1, y_hat1)
        neg_cost_matrix = d.layers[-1].cost_matrix(y0, y_hat0)

        pos_cost = (pos_mask * pos_cost_matrix).mean()
        neg_cost = (neg_mask * neg_cost_matrix).mean()

        d_obj = 0.5 * (pos_cost + neg_cost)

        # The generator is scored against target 1, i.e. it is rewarded for
        # samples the discriminator labels as real.
        if self.no_drop_in_d_for_g:
            y_hat0_no_drop = d.dropout_fprop(S)
            g_cost_mat = d.layers[-1].cost_matrix(y1, y_hat0_no_drop)
        else:
            g_cost_mat = d.layers[-1].cost_matrix(y1, y_hat0)
        assert g_cost_mat.ndim == 2
        assert y_hat0.ndim == 2

        # Samples already scored at least 0.5 + g_eps contribute zero cost.
        # The mask always uses the dropout prediction y_hat0.
        mask = y_hat0 < 0.5 + self.g_eps
        masked_cost = g_cost_mat * mask
        g_obj = masked_cost.mean()

        if model.inferer is not None:
            # Change this if we ever switch to using dropout in the
            # construction of S.
            S_nograd = block_gradient(S)  # Redundant as long as we have custom get_gradients
            pred = model.inferer.dropout_fprop(S_nograd, self.inference_default_input_include_prob,
                                               self.inference_input_include_probs,
                                               self.inference_default_input_scale,
                                               self.inference_input_scales)
            if self.infer_layer is None:
                target = z
            else:
                target = other_layers[self.infer_layer]
            i_obj = model.inferer.layers[-1].cost(target, pred)
        else:
            i_obj = 0

        return S, d_obj, g_obj, i_obj

    def get_gradients(self, model, data, **kwargs):
        """
        Return per-parameter gradients and extra shared-variable updates.

        Each sub-network only receives the gradient of its own objective.
        Parts whose `ever_train_*` flag is False get a constant zero
        gradient; the others get their gradient multiplied by the matching
        `now_train_*` shared variable.

        Returns
        -------
        rval : OrderedDict
            Maps each parameter to its gradient expression.
        updates : OrderedDict
            Contains the toggle of `now_train_generator` when `alternate_g`
            is set, otherwise empty.
        """
        space, sources = self.get_data_specs(model)
        space.validate(data)
        assert isinstance(model, AdversaryPair)
        g = model.generator
        d = model.discriminator

        S, d_obj, g_obj, i_obj = self.get_samples_and_objectives(model, data)

        g_params = g.get_params()
        d_params = d.get_params()
        # The two players must not share parameters.
        for param in g_params:
            assert param not in d_params
        for param in d_params:
            assert param not in g_params
        d_grads = T.grad(d_obj, d_params)
        g_grads = T.grad(g_obj, g_params)

        if self.scale_grads:
            # Rescale generator gradients by max(1, target_scale / ||dg_obj/dS||).
            S_grad = T.grad(g_obj, S)
            scale = T.maximum(1., self.target_scale / T.sqrt(T.sqr(S_grad).sum()))
            g_grads = [g_grad * scale for g_grad in g_grads]

        rval = OrderedDict()
        # Endless supply of zero constants; zip() stops at the parameter list.
        zeros = itertools.repeat(theano.tensor.constant(0., dtype='float32'))
        if self.ever_train_discriminator:
            rval.update(OrderedDict(safe_zip(d_params, [self.now_train_discriminator * dg for dg in d_grads])))
        else:
            rval.update(OrderedDict(zip(d_params, zeros)))
        if self.ever_train_generator:
            rval.update(OrderedDict(safe_zip(g_params, [self.now_train_generator * gg for gg in g_grads])))
        else:
            rval.update(OrderedDict(zip(g_params, zeros)))
        if self.ever_train_inference and model.inferer is not None:
            i_params = model.inferer.get_params()
            i_grads = T.grad(i_obj, i_params)
            rval.update(OrderedDict(safe_zip(i_params, [self.now_train_inference * ig for ig in i_grads])))
        elif model.inferer is not None:
            # Pair the parameters with zeros via zip(), like the branches
            # above; passing the two sequences straight to OrderedDict
            # raises TypeError.
            rval.update(OrderedDict(zip(model.inferer.get_params(), zeros)))

        updates = OrderedDict()

        # Two d steps for every g step
        if self.alternate_g:
            updates[self.now_train_generator] = 1. - self.now_train_generator

        return rval, updates

    def get_monitoring_channels(self, model, data, **kwargs):
        """
        Return monitoring channels: discriminator error rates on data and on
        fresh samples, the norm of the gradient of the discriminator cost
        with respect to the samples, the objectives the model asks to
        monitor, and the current value of `now_train_generator`.
        """
        rval = OrderedDict()

        m = data.shape[0]

        g = model.generator
        d = model.discriminator

        y_hat = d.fprop(data)

        # Real examples scored below 0.5.
        rval['false_negatives'] = T.cast((y_hat < 0.5).mean(), 'float32')

        samples = g.sample(m)
        y_hat = d.fprop(samples)
        # Generated examples scored above 0.5.
        rval['false_positives'] = T.cast((y_hat > 0.5).mean(), 'float32')
        # y = T.alloc(0., m, 1)
        cost = d.cost_from_X((samples, y_hat))
        sample_grad = T.grad(-cost, samples)
        rval['sample_grad_norm'] = T.sqrt(T.sqr(sample_grad).sum())
        _S, d_obj, g_obj, i_obj = self.get_samples_and_objectives(model, data)
        if model.monitor_inference and i_obj != 0:
            rval['objective_i'] = i_obj
        if model.monitor_discriminator:
            rval['objective_d'] = d_obj
        if model.monitor_generator:
            rval['objective_g'] = g_obj

        rval['now_train_generator'] = self.now_train_generator
        return rval
'gh0', + dim: 8000, + irange: .05, + #max_col_norm: 1.9365, + }, + !obj:pylearn2.models.mlp.Sigmoid { + layer_name: 'h1', + dim: 8000, + irange: .05, + #max_col_norm: 1.9365, + }, + !obj:pylearn2.models.mlp.SpaceConverter { + layer_name: 'converter', + output_space: !obj:pylearn2.space.Conv2DSpace { + shape: [10, 10], + num_channels: 80, + axes: ['c', 0, 1, 'b'], + }}, + !obj:adversarial.deconv.Deconv { + #W_lr_scale: .05, + #b_lr_scale: .05, + num_channels: 3, + output_stride: [3, 3], + kernel_shape: [5, 5], + pad_out: 0, + #max_kernel_norm: 1.9365, + # init_bias: !obj:pylearn2.models.dbm.init_sigmoid_bias_from_marginals { dataset: *train}, + layer_name: 'y', + irange: .05, + tied_b: 0 + }, + ], + nvis: 100, + }}, + discriminator: + !obj:pylearn2.models.mlp.MLP { + layers: [ + !obj:pylearn2.models.maxout.MaxoutConvC01B { + layer_name: 'dh0', + pad: 4, + tied_b: 1, + #W_lr_scale: .05, + #b_lr_scale: .05, + num_channels: 32, + num_pieces: 2, + kernel_shape: [8, 8], + pool_shape: [4, 4], + pool_stride: [2, 2], + irange: .005, + #max_kernel_norm: .9, + partial_sum: 33, + }, + !obj:pylearn2.models.maxout.MaxoutConvC01B { + layer_name: 'h1', + pad: 3, + tied_b: 1, + #W_lr_scale: .05, + #b_lr_scale: .05, + num_channels: 32, # 192 ran out of memory + num_pieces: 2, + kernel_shape: [8, 8], + pool_shape: [4, 4], + pool_stride: [2, 2], + irange: .005, + #max_kernel_norm: 1.9365, + partial_sum: 15, + }, + !obj:pylearn2.models.maxout.MaxoutConvC01B { + pad: 3, + layer_name: 'h2', + tied_b: 1, + #W_lr_scale: .05, + #b_lr_scale: .05, + num_channels: 192, + num_pieces: 2, + kernel_shape: [5, 5], + pool_shape: [2, 2], + pool_stride: [2, 2], + irange: .005, + #max_kernel_norm: 1.9365, + }, + !obj:pylearn2.models.maxout.Maxout { + layer_name: 'h3', + irange: .005, + num_units: 500, + num_pieces: 5, + #max_col_norm: 1.9 + }, + !obj:pylearn2.models.mlp.Sigmoid { + #W_lr_scale: .1, + #b_lr_scale: .1, + #max_col_norm: 1.9365, + layer_name: 'y', + dim: 1, + irange: .005 + } + ], + 
input_space: !obj:pylearn2.space.Conv2DSpace { + shape: [32, 32], + num_channels: 3, + axes: ['c', 0, 1, 'b'], + } + }, + }, + algorithm: !obj:pylearn2.training_algorithms.sgd.SGD { + batch_size: 128, + learning_rate: .004, + learning_rule: !obj:pylearn2.training_algorithms.learning_rule.Momentum { + init_momentum: .5, + }, + monitoring_dataset: + { + #'train' : *train, + 'valid' : !obj:pylearn2.datasets.cifar10.CIFAR10 { + axes: ['c', 0, 1, 'b'], + gcn: 55., + which_set: 'train', + start: 40000, + stop: 50000 + }, + #'test' : !obj:pylearn2.datasets.cifar10.CIFAR10 { + # which_set: 'test', + # gcn: 55., + # } + }, + cost: !obj:adversarial.AdversaryCost2 { + scale_grads: 0, + #target_scale: .1, + discriminator_default_input_include_prob: .5, + discriminator_input_include_probs: { + 'dh0': .8 + }, + discriminator_default_input_scale: 2., + discriminator_input_scales: { + 'dh0': 1.25 + } + }, + #termination_criterion: !obj:pylearn2.termination_criteria.MonitorBased { + # channel_name: "valid_y_misclass", + # prop_decrease: 0., + # N: 100 + #}, + update_callbacks: !obj:pylearn2.training_algorithms.sgd.ExponentialDecay { + decay_factor: 1.000004, + min_lr: .000001 + } + }, + extensions: [ + #!obj:pylearn2.train_extensions.best_params.MonitorBasedSaveBest { + # channel_name: 'valid_y_misclass', + # save_path: "${PYLEARN2_TRAIN_FILE_FULL_STEM}_best.pkl" + #}, + !obj:pylearn2.training_algorithms.learning_rule.MomentumAdjustor { + start: 1, + saturate: 250, + final_momentum: .7 + } + ], + save_path: "${PYLEARN2_TRAIN_FILE_FULL_STEM}.pkl", + save_freq: 1 +} diff --git a/cifar10_fully_connected.yaml b/cifar10_fully_connected.yaml new file mode 100644 index 0000000..30ef170 --- /dev/null +++ b/cifar10_fully_connected.yaml @@ -0,0 +1,114 @@ +!obj:pylearn2.train.Train { + dataset: &train !obj:pylearn2.datasets.cifar10.CIFAR10 { + gcn: 55., + which_set: 'train', + start: 0, + stop: 40000 + }, + model: !obj:adversarial.AdversaryPair { + generator: !obj:adversarial.Generator { + 
mlp: !obj:pylearn2.models.mlp.MLP { + layers: [ + !obj:pylearn2.models.mlp.RectifiedLinear { + layer_name: 'gh0', + dim: 8000, + irange: .05, + }, + !obj:pylearn2.models.mlp.Sigmoid { + layer_name: 'h1', + dim: 8000, + irange: .05, + }, + !obj:pylearn2.models.mlp.Linear { + # init_bias: !obj:pylearn2.models.dbm.init_sigmoid_bias_from_marginals { dataset: *train}, + layer_name: 'y', + irange: .5, + dim: 3072 + } + ], + nvis: 100, + }}, + discriminator: + !obj:pylearn2.models.mlp.MLP { + layers: [ + !obj:pylearn2.models.maxout.Maxout { + layer_name: 'dh0', + num_units: 1600, + num_pieces: 5, + irange: .005, + }, + !obj:pylearn2.models.maxout.Maxout { + layer_name: 'h1', + num_units: 1600, + num_pieces: 5, + irange: .005, + }, + !obj:pylearn2.models.mlp.Sigmoid { + layer_name: 'y', + dim: 1, + irange: .005 + } + ], + nvis: 3072, + }, + }, + algorithm: !obj:pylearn2.training_algorithms.sgd.SGD { + batch_size: 100, + learning_rate: .025, + learning_rule: !obj:pylearn2.training_algorithms.learning_rule.Momentum { + init_momentum: .5, + }, + monitoring_dataset: + { + #'train' : *train, + 'valid' : !obj:pylearn2.datasets.cifar10.CIFAR10 { + gcn: 55., + which_set: 'train', + start: 40000, + stop: 50000 + }, + #'test' : !obj:pylearn2.datasets.cifar10.CIFAR10 { + # which_set: 'test', + # gcn: 55., + # } + }, + cost: !obj:adversarial.AdversaryCost2 { + scale_grads: 0, + #target_scale: .1, + discriminator_default_input_include_prob: .5, + discriminator_input_include_probs: { + 'dh0': .8 + }, + discriminator_default_input_scale: 2., + discriminator_input_scales: { + 'dh0': 1.25 + } + }, + #!obj:pylearn2.costs.mlp.dropout.Dropout { + # input_include_probs: { 'h0' : .8 }, + # input_scales: { 'h0': 1. 
} + #}, + #termination_criterion: !obj:pylearn2.termination_criteria.MonitorBased { + # channel_name: "valid_y_misclass", + # prop_decrease: 0., + # N: 100 + #}, + update_callbacks: !obj:pylearn2.training_algorithms.sgd.ExponentialDecay { + decay_factor: 1.000004, + min_lr: .000001 + } + }, + extensions: [ + #!obj:pylearn2.train_extensions.best_params.MonitorBasedSaveBest { + # channel_name: 'valid_y_misclass', + # save_path: "${PYLEARN2_TRAIN_FILE_FULL_STEM}_best.pkl" + #}, + !obj:pylearn2.training_algorithms.learning_rule.MomentumAdjustor { + start: 1, + saturate: 250, + final_momentum: .7 + } + ], + save_path: "${PYLEARN2_TRAIN_FILE_FULL_STEM}.pkl", + save_freq: 1 +} diff --git a/deconv.py b/deconv.py new file mode 100644 index 0000000..cd051bb --- /dev/null +++ b/deconv.py @@ -0,0 +1,384 @@ +import functools +import logging +import numpy as np + +from theano.compat import OrderedDict +from theano import tensor as T + +from pylearn2.linear.conv2d_c01b import make_random_conv2D +from pylearn2.models import Model +from pylearn2.models.maxout import check_cuda # TODO: import from original path +from pylearn2.models.mlp import Layer +#from pylearn2.models.maxout import py_integer_types # TODO: import from orig path +from pylearn2.space import Conv2DSpace +from pylearn2.utils import sharedX + +logger = logging.getLogger(__name__) + +class Deconv(Layer): + def __init__(self, + num_channels, + kernel_shape, + layer_name, + irange=None, + init_bias=0., + W_lr_scale=None, + b_lr_scale=None, + pad_out=0, + fix_kernel_shape=False, + partial_sum=1, + tied_b=False, + max_kernel_norm=None, + output_stride=(1, 1)): + check_cuda(str(type(self))) + super(Deconv, self).__init__() + + detector_channels = num_channels + + self.__dict__.update(locals()) + del self.self + + @functools.wraps(Model.get_lr_scalers) + def get_lr_scalers(self): + + if not hasattr(self, 'W_lr_scale'): + self.W_lr_scale = None + + if not hasattr(self, 'b_lr_scale'): + self.b_lr_scale = None + + rval = 
def set_input_space(self, space):
    """
    Tells the layer to use the specified input space.

    This resets parameters! The kernel tensor is initialized with the
    size needed to receive input from this space.

    All of the work is delegated to `setup_deconv_detector_layer_c01b`,
    which sets `self.detector_space` (among other attributes); that space
    is then used directly as the layer's output space.

    Parameters
    ----------
    space : Space
        The Space that the input will lie in.
    """
    setup_deconv_detector_layer_c01b(layer=self,
                                     input_space=space,
                                     rng=self.mlp.rng)

    self.output_space = self.detector_space

    logger.info('Output space: {0}'.format(self.output_space.shape))
+ """ + + if self.max_kernel_norm is not None: + W, = self.transformer.get_params() + if W in updates: + updated_W = updates[W] + row_norms = T.sqrt(T.sum(T.sqr(updated_W), axis=(0, 1, 2))) + desired_norms = T.clip(row_norms, 0, self.max_kernel_norm) + scales = desired_norms / (1e-7 + row_norms) + updates[W] = (updated_W * scales.dimshuffle('x', 'x', 'x', 0)) + + @functools.wraps(Model.get_params) + def get_params(self): + assert self.b.name is not None + W, = self.transformer.get_params() + assert W.name is not None + rval = self.transformer.get_params() + assert not isinstance(rval, set) + rval = list(rval) + assert self.b not in rval + rval.append(self.b) + return rval + + @functools.wraps(Layer.get_weight_decay) + def get_weight_decay(self, coeff): + if isinstance(coeff, str): + coeff = float(coeff) + assert isinstance(coeff, float) or hasattr(coeff, 'dtype') + W, = self.transformer.get_params() + return coeff * T.sqr(W).sum() + + @functools.wraps(Layer.set_weights) + def set_weights(self, weights): + W, = self.transformer.get_params() + W.set_value(weights) + + @functools.wraps(Layer.set_biases) + def set_biases(self, biases): + self.b.set_value(biases) + + @functools.wraps(Layer.get_biases) + def get_biases(self): + return self.b.get_value() + + @functools.wraps(Model.get_weights_topo) + def get_weights_topo(self): + return self.transformer.get_weights_topo() + + @functools.wraps(Layer.get_monitoring_channels) + def get_layer_monitoring_channels(self, state_below=None, state=None, targets=None): + + W, = self.transformer.get_params() + + assert W.ndim == 4 + + sq_W = T.sqr(W) + + row_norms = T.sqrt(sq_W.sum(axis=(0, 1, 2))) + + P = state + + rval = OrderedDict() + + vars_and_prefixes = [(P, '')] + + for var, prefix in vars_and_prefixes: + if not hasattr(var, 'ndim') or var.ndim != 4: + print "expected 4D tensor, got " + print var + print type(var) + if isinstance(var, tuple): + print "tuple length: ", len(var) + assert False + v_max = var.max(axis=(1, 2, 3)) 
def setup_deconv_detector_layer_c01b(layer, input_space, rng, irange="not specified"):
    """
    Set up the detector part of a deconvolution layer in c01b format.

    Does the following:

    * calls `check_cuda` for the layer's type
    * validates `input_space` and sets layer.input_space to it
    * sets layer.detector_space to a Conv2DSpace with
      `layer.detector_channels` channels, axes ('c', 0, 1, 'b') and spatial
      shape `stride * (input - 1) - 2 * pad_out + kernel` per dimension
    * sets layer.transformer to the result of `make_random_conv2D`, built
      as a convolution *from* the detector space *to* the input space
    * sets layer.b to the initial bias and names both W and b after the
      layer

    Parameters
    ----------
    layer : object
        Any python object that allows the modifications described above and
        has the following attributes:

        * layer_name : str, used in messages and parameter names
        * detector_channels : number of channels in the detector space;
          must be <= 3 or a multiple of 4
        * kernel_shape : 2-element tuple or list, spatial shape of kernel
        * output_stride : 2-element tuple or list, used as kernel stride
        * pad_out : int, padding passed to `make_random_conv2D`
        * irange : passed to `make_random_conv2D`
        * init_bias : numeric constant added to a tensor of zeros to
          initialize the bias
        * tied_b : if true, one bias per channel; otherwise the bias has
          the shape of `detector_space.get_origin()`
        * partial_sum : optional, defaults to 1
    input_space : Conv2DSpace
        The space the layer receives its input in. Must have at least 16
        channels.
    rng : object
        Random number generator forwarded to `make_random_conv2D`.
    irange : str
        Disabled; must be left at its default.

    Raises
    ------
    AssertionError
        If `irange` is passed, or the layer has no `output_stride`.
    TypeError
        If `input_space` is not a Conv2DSpace.
    ValueError
        If the layer has no `detector_channels` attribute or the input has
        fewer than 16 channels.
    NotImplementedError
        If `detector_channels` would need dummy channels, or if the layer
        requests sparse initialization.
    """

    if irange != "not specified":
        raise AssertionError(
            "There was a bug in setup_detector_layer_c01b."
            "It uses layer.irange instead of the irange parameter to the "
            "function. The irange parameter is now disabled by this "
            "AssertionError, so that this error message can alert you that "
            "the bug affected your code and explain why the interface is "
            "changing. The irange parameter to the function and this "
            "error message may be removed after April 21, 2014."
        )

    # Use "self" to refer to layer from now on, so we can pretend we're
    # just running in the set_input_space method of the layer
    self = layer

    # Make sure cuda is available
    check_cuda(str(type(self)))

    # Validate input. Report the type of the offending argument:
    # self.input_space is not assigned until after this check.
    if not isinstance(input_space, Conv2DSpace):
        raise TypeError("The input to a convolutional layer should be a "
                        "Conv2DSpace, but layer " + self.layer_name + " got " +
                        str(type(input_space)))

    if not hasattr(self, 'detector_channels'):
        raise ValueError("layer argument must have a 'detector_channels' "
                         "attribute specifying how many channels to put in "
                         "the convolution kernel stack.")

    # Store the input space
    self.input_space = input_space

    # Make sure number of channels is supported by cuda-convnet
    # (multiple of 4 or <= 3). Padding the output with dummy channels is
    # not implemented.
    ch = self.detector_channels
    rem = ch % 4
    if ch > 3 and rem != 0:
        raise NotImplementedError("Need to do dummy channels on the output")

    if hasattr(self, 'output_stride'):
        kernel_stride = self.output_stride
    else:
        assert False  # not sure if I got the name right, remove this assert if I did
        kernel_stride = [1, 1]

    # Inverse of the forward convolution shape formula
    #   o_sh = ceil((i_sh + 2 * pad - k_sh) / k_st) + 1
    # solved for i_sh, with the roles of input and output swapped.
    output_shape = \
        [k_st * (i_sh - 1) - 2 * self.pad_out + k_sh
         for i_sh, k_sh, k_st in zip(self.input_space.shape,
                                     self.kernel_shape, kernel_stride)]

    if self.input_space.num_channels < 16:
        raise ValueError("Cuda-convnet requires the input to lmul_T to have "
                         "at least 16 channels.")

    self.detector_space = Conv2DSpace(shape=output_shape,
                                      num_channels=self.detector_channels,
                                      axes=('c', 0, 1, 'b'))

    partial_sum = getattr(self, 'partial_sum', 1)

    if getattr(self, 'sparse_init', None) is not None:
        # The sparse branch relied on checked_call and
        # make_sparse_random_conv2D, which this module does not import, and
        # on a `pad` attribute this layer does not define, so it could
        # never run. Fail with an explicit message instead of a NameError.
        raise NotImplementedError("sparse_init is not supported by "
                                  "setup_deconv_detector_layer_c01b")

    # The transformer maps detector space -> input space; the layer applies
    # it transposed, hence the swapped input/output arguments.
    self.transformer = make_random_conv2D(
        irange=self.irange,
        input_axes=self.detector_space.axes,
        output_axes=self.input_space.axes,
        input_channels=self.detector_space.num_channels,
        output_channels=self.input_space.num_channels,
        kernel_shape=self.kernel_shape,
        pad=self.pad_out,
        partial_sum=partial_sum,
        kernel_stride=kernel_stride,
        rng=rng,
        input_shape=self.detector_space.shape
    )

    W, = self.transformer.get_params()
    W.name = self.layer_name + '_W'

    if self.tied_b:
        self.b = sharedX(np.zeros(self.detector_space.num_channels) +
                         self.init_bias)
    else:
        self.b = sharedX(self.detector_space.get_origin() + self.init_bias)
    self.b.name = self.layer_name + '_b'

    logger.info('Input shape: {0}'.format(self.input_space.shape))
    # Single parenthesized argument: prints the same text whether print is
    # a statement or a function.
    print(layer.layer_name + ' detector space: {0}'.format(self.detector_space.shape))
index 0000000..2e35ea7 --- /dev/null +++ b/mnist.yaml @@ -0,0 +1,113 @@ +!obj:pylearn2.train.Train { + dataset: &train !obj:pylearn2.datasets.mnist.MNIST { + which_set: 'train', + start: 0, + stop: 50000 + }, + model: !obj:adversarial.AdversaryPair { + generator: !obj:adversarial.Generator { + monitor_ll: 1, + noise: "uniform", + mlp: !obj:pylearn2.models.mlp.MLP { + layers: [ + !obj:pylearn2.models.mlp.RectifiedLinear { + layer_name: 'h0', + dim: 1200, + irange: .05, + }, + !obj:pylearn2.models.mlp.RectifiedLinear { + layer_name: 'h1', + dim: 1200, + irange: .05, + }, + !obj:pylearn2.models.mlp.Sigmoid { + init_bias: !obj:pylearn2.models.dbm.init_sigmoid_bias_from_marginals { dataset: *train}, + layer_name: 'y', + irange: .05, + dim: 784 + } + ], + nvis: 100, + }}, + discriminator: + !obj:pylearn2.models.mlp.MLP { + layers: [ + !obj:pylearn2.models.maxout.Maxout { + layer_name: 'h0', + num_units: 240, + num_pieces: 5, + irange: .005, + }, + !obj:pylearn2.models.maxout.Maxout { + layer_name: 'h1', + num_units: 240, + num_pieces: 5, + irange: .005, + }, + !obj:pylearn2.models.mlp.Sigmoid { + layer_name: 'y', + dim: 1, + irange: .005 + } + ], + nvis: 784, + }, + }, + algorithm: !obj:pylearn2.training_algorithms.sgd.SGD { + batch_size: 100, + learning_rate: .1, + learning_rule: !obj:pylearn2.training_algorithms.learning_rule.Momentum { + init_momentum: .5, + }, + monitoring_dataset: + { + #'train' : *train, + 'valid' : !obj:pylearn2.datasets.mnist.MNIST { + which_set: 'train', + start: 50000, + stop: 60000 + }, + #'test' : !obj:pylearn2.datasets.mnist.MNIST { + # which_set: 'test', + # } + }, + cost: !obj:adversarial.AdversaryCost2 { + scale_grads: 0, + #target_scale: 1., + discriminator_default_input_include_prob: .5, + discriminator_input_include_probs: { + 'h0': .8 + }, + discriminator_default_input_scale: 2., + discriminator_input_scales: { + 'h0': 1.25 + } + }, + #!obj:pylearn2.costs.mlp.dropout.Dropout { + # input_include_probs: { 'h0' : .8 }, + # input_scales: 
def get_nll(x, parzen, batch_size=10):
    """
    Evaluate a compiled Parzen estimator on `x`, one mini-batch at a time.

    Credit: Yann N. Dauphin

    Despite the name, the values collected here are whatever `parzen`
    returns; for estimators built by `theano_parzen` that is the
    per-example log-likelihood (not its negation).

    Parameters
    ----------
    x : numpy.ndarray
        Examples to score, one per row.
    parzen : callable
        Maps a batch of rows to a sequence holding one value per row.
    batch_size : int, optional
        Target number of rows per call to `parzen`.

    Returns
    -------
    numpy.ndarray
        One value per row of `x`, in evaluation order. Batches are strided
        (batch i holds rows i, i + n_batches, ...), so this is not the
        original row order.
    """

    inds = numpy.arange(x.shape[0])
    n_batches = int(numpy.ceil(float(len(inds)) / batch_size))

    times = []
    nlls = []
    for i in range(n_batches):
        begin = time.time()
        nll = parzen(x[inds[i::n_batches]])
        end = time.time()
        times.append(end - begin)
        nlls.extend(nll)

        if i % 10 == 0:
            # Progress report: batch index, mean seconds per batch, running
            # mean of the values seen so far.
            print("{} {} {}".format(i, numpy.mean(times), numpy.mean(nlls)))

    return numpy.array(nlls)


def log_mean_exp(a):
    """
    Symbolic, numerically stable log(mean(exp(a))) along axis 1.

    Credit: Yann N. Dauphin

    The row-wise maximum is subtracted before exponentiating and added
    back afterwards so that exp() cannot overflow.
    """

    max_ = a.max(1)

    return max_ + T.log(T.exp(a - max_.dimshuffle(0, 'x')).mean(1))


def theano_parzen(mu, sigma):
    """
    Compile a Parzen-window log-density estimator with Gaussian kernels.

    Credit: Yann N. Dauphin

    Parameters
    ----------
    mu : numpy.ndarray
        Kernel centres, one per row (the samples drawn from the model).
    sigma : float
        Standard deviation shared by every isotropic Gaussian kernel.

    Returns
    -------
    theano function
        Takes a matrix of examples (one per row) and returns the log
        density of each example under the kernel mixture.
    """

    x = T.matrix()
    mu = theano.shared(mu)
    # Broadcast to (examples, centres, features) standardised differences.
    a = (x.dimshuffle(0, 'x', 1) - mu.dimshuffle('x', 0, 1)) / sigma
    E = log_mean_exp(-0.5 * (a ** 2).sum(2))
    # Log normaliser of a d-dimensional isotropic Gaussian.
    Z = mu.shape[1] * T.log(sigma * numpy.sqrt(numpy.pi * 2))

    return theano.function([x], E - Z)


def cross_validate_sigma(samples, data, sigmas, batch_size):
    """
    Return the entry of `sigmas` that maximises the mean log-likelihood
    of `data` under a Parzen estimator centred on `samples`.
    """

    lls = []
    for sigma in sigmas:
        print(sigma)
        parzen = theano_parzen(samples, sigma)
        tmp = get_nll(data, parzen, batch_size=batch_size)
        lls.append(numpy.asarray(tmp).mean())
        # Drop the compiled function before building the next one so the
        # shared copy of `samples` can be reclaimed.
        del parzen
        gc.collect()

    ind = numpy.argmax(lls)
    return sigmas[ind]


def get_valid(ds, limit_size=-1, fold=0):
    """
    Load the design matrix of the validation set for dataset `ds`.

    Parameters
    ----------
    ds : str
        Either 'mnist' or 'tfd'.
    limit_size : int, optional
        For MNIST, keep only the first `limit_size` examples. A negative
        value (the default) or None keeps every example; previously the
        default of -1 was used directly as a slice bound and silently
        dropped the last example.
        NOTE(review): the TFD branch has never applied this limit; that
        is left unchanged here -- confirm whether it is intended.
    fold : int, optional
        TFD fold to load. Ignored for MNIST.

    Raises
    ------
    ValueError
        If `ds` is not a known dataset name.
    """
    if ds == 'mnist':
        data = MNIST('train', start=50000, stop=60000)
        if limit_size is None or limit_size < 0:
            return data.X
        return data.X[:limit_size]
    elif ds == 'tfd':
        data = TFD('valid', fold=fold, scale=True)
        return data.X
    else:
        # Was `args.dataet`: `args` does not exist in this scope (and the
        # attribute was misspelled), so a NameError hid the real problem.
        raise ValueError("Unknow dataset: {}".format(ds))


def get_test(ds, test, fold=0):
    """
    Return the test split of the dataset object `test`.

    Parameters
    ----------
    ds : str
        Either 'mnist' or 'tfd'.
    test : object
        Dataset rebuilt from the model's yaml source; must provide
        `get_test_set`.
    fold : int, optional
        TFD fold to load. Ignored for MNIST.

    Raises
    ------
    ValueError
        If `ds` is not a known dataset name.
    """
    if ds == 'mnist':
        return test.get_test_set()
    elif ds == 'tfd':
        return test.get_test_set(fold=fold)
    else:
        raise ValueError("Unknow dataset: {}".format(ds))


def main():
    """
    Command-line entry point: estimate the test-set log-likelihood of a
    saved model with a Parzen window fitted to samples from its generator.
    """
    parser = argparse.ArgumentParser(description='Parzen window, log-likelihood estimator')
    parser.add_argument('-p', '--path', help='model path')
    parser.add_argument('-s', '--sigma', default=None)
    parser.add_argument('-d', '--dataset', choices=['mnist', 'tfd'])
    parser.add_argument('-f', '--fold', default=0, type=int)
    parser.add_argument('-v', '--valid', default=False, action='store_true')
    parser.add_argument('-n', '--num_samples', default=10000, type=int)
    parser.add_argument('-l', '--limit_size', default=1000, type=int)
    parser.add_argument('-b', '--batch_size', default=100, type=int)
    parser.add_argument('-c', '--cross_val', default=10, type=int,
                        help="Number of cross valiation folds")
    parser.add_argument('--sigma_start', default=-1, type=float)
    parser.add_argument('--sigma_end', default=0., type=float)
    args = parser.parse_args()

    # load model
    # NOTE: serial.load unpickles the file; only use it on trusted models.
    model = serial.load(args.path)
    src = model.dataset_yaml_src
    batch_size = args.batch_size
    model.set_batch_size(batch_size)

    # load test set
    test = yaml_parse.load(src)
    test = get_test(args.dataset, test, args.fold)

    # generate samples
    samples = model.generator.sample(args.num_samples).eval()
    output_space = model.generator.mlp.get_output_space()
    if 'Conv2D' in str(output_space):
        # Convolutional generators emit image batches; flatten them to one
        # row per sample so they can serve as kernel centres.
        samples = output_space.convert(samples, output_space.axes, ('b', 0, 1, 'c'))
        samples = samples.reshape((samples.shape[0], numpy.prod(samples.shape[1:])))
    del model
    gc.collect()

    # cross validate sigma
    if args.sigma is None:
        valid = get_valid(args.dataset, limit_size=args.limit_size, fold=args.fold)
        sigma_range = numpy.logspace(args.sigma_start, args.sigma_end, num=args.cross_val)
        sigma = cross_validate_sigma(samples, valid, sigma_range, batch_size)
    else:
        sigma = float(args.sigma)

    print("Using Sigma: {}".format(sigma))
    gc.collect()

    # fit and evaulate
    parzen = theano_parzen(samples, sigma)
    ll = get_nll(test.X, parzen, batch_size=batch_size)
    se = ll.std() / numpy.sqrt(test.X.shape[0])

    print("Log-Likelihood of test set = {}, se: {}".format(ll.mean(), se))

    # valid
    if args.valid:
        valid = get_valid(args.dataset)
        ll = get_nll(valid, parzen, batch_size=batch_size)
        # Was `val.shape[0]`, an undefined name.
        se = ll.std() / numpy.sqrt(valid.shape[0])
        print("Log-Likelihood of valid set = {}, se: {}".format(ll.mean(), se))


if __name__ == "__main__":
    main()
class SGD(TrainingAlgorithm):
    """
    SGD = (Minibatch) Stochastic Gradient Descent.
    A TrainingAlgorithm that does stochastic gradient descent on minibatches
    of training examples.

    This copy compiles two update functions, one that only moves the
    discriminator's parameters and one that only moves the generator's,
    and alternates between them during `train`.

    For theoretical background on this algorithm, see Yoshua Bengio's machine
    learning course notes on the subject:

    http://www.iro.umontreal.ca/~pift6266/H10/notes/gradient.html

    Parameters
    ----------
    learning_rate : float
        The learning rate to use. Train object callbacks can change the
        learning rate after each epoch. SGD update_callbacks can change
        it after each minibatch.
    cost : pylearn2.costs.cost.Cost, optional
        Cost object specifying the objective function to be minimized.
        Optionally, may be None. In this case, SGD will call the model's
        get_default_cost method to obtain the objective function.
    batch_size : int, optional
        The size of the batch to be used.
        If not specified, the model will be asked for the batch size, so
        you must have specified the batch size there.
        (Some models are rigidly defined to only work with one batch size)
    monitoring_batch_size : int, optional
        The size of the monitoring batches.
    monitoring_batches : int, optional
        At the start of each epoch, we run "monitoring", to evaluate
        quantities such as the validation set error.
        monitoring_batches, if specified, determines the number of batches
        to draw from the iterator for each monitoring dataset.
        Unnecessary if not using monitoring or if `monitor_iteration_mode`
        is 'sequential' and `batch_size` is specified (number of
        batches will be calculated based on full dataset size).
        TODO: make it possible to specify different monitoring_batches
        for each monitoring dataset. The Monitor itself already supports
        this.
    monitoring_dataset : Dataset or dictionary, optional
        If not specified, no monitoring is used.
        If specified to be a Dataset, monitor on that Dataset.
        If specified to be dictionary, the keys should be string names
        of datasets, and the values should be Datasets. All monitoring
        channels will be computed for all monitoring Datasets and will
        have the dataset name and an underscore prepended to them.
    monitor_iteration_mode : str, optional
        The iteration mode used to iterate over the examples in all
        monitoring datasets. If not specified, defaults to 'sequential'.
        TODO: make it possible to specify different modes for different
        datasets.
    termination_criterion : instance of \
        pylearn2.termination_criteria.TerminationCriterion, optional

        Used to determine when the algorithm should stop running.
        If not specified, runs forever--or more realistically, until
        external factors halt the python process (Kansas 1977).
    update_callbacks : list, optional
        If specified, each member of the list should be a callable that
        accepts an SGD instance as its only argument.
        All callbacks will be called with this SGD instance after each
        SGD step.
    learning_rule : training_algorithms.learning_rule.LearningRule, optional
        A learning rule computes the new parameter values given old
        parameters and first-order gradients. If learning_rule is None,
        sgd.SGD will update parameters according to the standard SGD
        learning rule:

        .. code-block:: none

            param := param - learning_rate * d cost / d param

        This argument allows more sophisticated learning rules, such
        as SGD with momentum.
    init_momentum : float, **DEPRECATED** option
        Use learning_rule instead.
        If None, does not use momentum otherwise, use momentum and
        initialize the momentum coefficient to init_momentum. Callbacks
        can change this over time just like the learning rate. If the
        gradient is the same on every step, then the update taken by the
        SGD algorithm is scaled by a factor of 1/(1-momentum). See
        section 9 of Geoffrey Hinton's "A Practical Guide to Training
        Restricted Boltzmann Machines" for details.
    set_batch_size : bool, optional
        Defaults to False.
        If True, and batch_size conflicts with model.force_batch_size,
        will call model.set_batch_size(batch_size) in an attempt to
        change model.force_batch_size
    train_iteration_mode : str, optional
        Defaults to 'shuffled_sequential'.
        The iteration mode to use for iterating through training examples.
    batches_per_iter : int, optional
        The number of batches to draw from the iterator over training
        examples.
        If iteration mode is 'sequential' or 'shuffled_sequential', this
        is unnecessary; when unspecified we will iterate over all examples.
    theano_function_mode : a valid argument to theano.function's \
        'mode' parameter, optional

        The theano mode to compile the updates function with. Note that
        pylearn2 includes some wraplinker modes that are not bundled with
        theano. See pylearn2.devtools. These extra modes let you do
        things like check for NaNs at every step, or record md5 digests
        of all computations performed by the update function to help
        isolate problems with nondeterminism.
    monitoring_costs : list, optional
        a list of Cost instances. The Monitor will also include all
        channels defined by these Costs, even though we don't train
        using them.
    seed : valid argument to np.random.RandomState, optional
        The seed used for the random number generate to be passed to the
        training dataset iterator (if any)
    discriminator_steps : int, optional
        Number of consecutive minibatches on which only the discriminator
        is updated before a single minibatch is spent updating the
        generator. Defaults to 1, i.e. strict alternation.
    """
    def __init__(self, learning_rate, cost=None, batch_size=None,
                 monitoring_batch_size=None, monitoring_batches=None,
                 monitoring_dataset=None, monitor_iteration_mode='sequential',
                 termination_criterion=None, update_callbacks=None,
                 learning_rule = None, init_momentum = None,
                 set_batch_size = False,
                 train_iteration_mode = None, batches_per_iter=None,
                 theano_function_mode = None, monitoring_costs=None,
                 seed=[2012, 10, 5], discriminator_steps=1):
        self.discriminator_steps = discriminator_steps

        if isinstance(cost, (list, tuple, set)):
            raise TypeError("SGD no longer supports using collections of " +
                            "Costs to represent a sum of Costs. Use " +
                            "pylearn2.costs.cost.SumOfCosts instead.")

        if init_momentum:
            warnings.warn("init_momentum interface is deprecated and will "
            "become officially unsuported as of May 9, 2014. Please use the "
            "`learning_rule` parameter instead, providing an object of type "
            "`pylearn2.training_algorithms.learning_rule.Momentum` instead")
            # Convert to new interface under the hood.
            self.learning_rule = Momentum(init_momentum)
        else:
            self.learning_rule = learning_rule

        self.learning_rate = sharedX(learning_rate, 'learning_rate')
        self.cost = cost
        self.batch_size = batch_size
        self.set_batch_size = set_batch_size
        self.batches_per_iter = batches_per_iter
        self._set_monitoring_dataset(monitoring_dataset)
        self.monitoring_batch_size = monitoring_batch_size
        self.monitoring_batches = monitoring_batches
        self.monitor_iteration_mode = monitor_iteration_mode
        if monitoring_dataset is None:
            if monitoring_batch_size is not None:
                raise ValueError("Specified a monitoring batch size " +
                                 "but not a monitoring dataset.")
            if monitoring_batches is not None:
                raise ValueError("Specified an amount of monitoring batches " +
                                 "but not a monitoring dataset.")
        self.termination_criterion = termination_criterion
        self._register_update_callbacks(update_callbacks)
        if train_iteration_mode is None:
            train_iteration_mode = 'shuffled_sequential'
        self.train_iteration_mode = train_iteration_mode
        self.first = True
        self.rng = make_np_rng(seed, which_method=["randn","randint"])
        self.theano_function_mode = theano_function_mode
        self.monitoring_costs = monitoring_costs

    def setup(self, model, dataset):
        """
        Compiles the theano functions needed for the train method.

        Two functions are compiled: `self.d_func`, which updates only
        `model.discriminator`'s parameters, and `self.g_func`, which
        updates only `model.generator`'s parameters.

        Parameters
        ----------
        model : a Model instance
            Must expose `discriminator` and `generator` attributes, each
            with a `get_params` method.
        dataset : Dataset
        """
        # Position within the discriminator/generator alternation; carried
        # across epochs by `train`.
        self.i = 0
        if self.cost is None:
            self.cost = model.get_default_cost()

        inf_params = [param for param in model.get_params()
                      if np.any(np.isinf(param.get_value()))]
        if len(inf_params) > 0:
            raise ValueError("These params are Inf: "+str(inf_params))
        if any([np.any(np.isnan(param.get_value()))
                for param in model.get_params()]):
            nan_params = [param for param in model.get_params()
                          if np.any(np.isnan(param.get_value()))]
            raise ValueError("These params are NaN: "+str(nan_params))
        self.model = model

        self._synchronize_batch_size(model)
        model._test_batch_size = self.batch_size
        self.monitor = Monitor.get_monitor(model)
        self.monitor._sanity_check()

        # test if force batch size and batch size
        # (skipped when there is no monitoring dataset: the original
        # dereferenced `self.monitoring_dataset.values()` unconditionally)
        if getattr(model, "force_batch_size", False) and \
           self.monitoring_dataset is not None and \
           any(dataset.get_design_matrix().shape[0] % self.batch_size != 0 for
               dataset in self.monitoring_dataset.values()) and \
           not has_uniform_batch_size(self.monitor_iteration_mode):

            raise ValueError("Dataset size is not a multiple of batch size."
                             "You should set monitor_iteration_mode to "
                             "even_sequential, even_shuffled_sequential or "
                             "even_batchwise_shuffled_sequential")

        data_specs = self.cost.get_data_specs(self.model)
        mapping = DataSpecsMapping(data_specs)
        space_tuple = mapping.flatten(data_specs[0], return_tuple=True)
        source_tuple = mapping.flatten(data_specs[1], return_tuple=True)

        # Build a flat tuple of Theano Variables, one for each space.
        # We want that so that if the same space/source is specified
        # more than once in data_specs, only one Theano Variable
        # is generated for it, and the corresponding value is passed
        # only once to the compiled Theano function.
        theano_args = []
        for space, source in safe_zip(space_tuple, source_tuple):
            name = '%s[%s]' % (self.__class__.__name__, source)
            arg = space.make_theano_batch(name=name,
                                          batch_size=self.batch_size)
            theano_args.append(arg)
        theano_args = tuple(theano_args)

        # Methods of `self.cost` need args to be passed in a format compatible
        # with data_specs
        nested_args = mapping.nest(theano_args)
        fixed_var_descr = self.cost.get_fixed_var_descr(model, nested_args)
        self.on_load_batch = fixed_var_descr.on_load_batch

        cost_value = self.cost.expr(model, nested_args,
                                    ** fixed_var_descr.fixed_vars)

        if cost_value is not None and cost_value.name is None:
            # Concatenate the name of all tensors in theano_args !?
            cost_value.name = 'objective'

        # Set up monitor to model the objective value, learning rate,
        # momentum (if applicable), and extra channels defined by
        # the cost
        learning_rate = self.learning_rate
        if self.monitoring_dataset is not None:
            if (self.monitoring_batch_size is None and
                    self.monitoring_batches is None):
                self.monitoring_batch_size = self.batch_size
                self.monitoring_batches = self.batches_per_iter
            self.monitor.setup(dataset=self.monitoring_dataset,
                               cost=self.cost,
                               batch_size=self.monitoring_batch_size,
                               num_batches=self.monitoring_batches,
                               extra_costs=self.monitoring_costs,
                               mode=self.monitor_iteration_mode)
            dataset_name = self.monitoring_dataset.keys()[0]
            monitoring_dataset = self.monitoring_dataset[dataset_name]
            #TODO: have Monitor support non-data-dependent channels
            self.monitor.add_channel(name='learning_rate',
                                     ipt=None,
                                     val=learning_rate,
                                     data_specs=(NullSpace(), ''),
                                     dataset=monitoring_dataset)

            if self.learning_rule:
                self.learning_rule.add_channels_to_monitor(
                        self.monitor,
                        monitoring_dataset)

        params = list(model.get_params())
        assert len(params) > 0
        for i, param in enumerate(params):
            if param.name is None:
                param.name = 'sgd_params[%d]' % i
        self.params = params

        grads, updates = self.cost.get_gradients(model, nested_args,
                                                 ** fixed_var_descr.fixed_vars)
        if not isinstance(grads, OrderedDict):
            raise TypeError(str(type(self.cost)) + ".get_gradients returned " +
                            "something with" + str(type(grads)) + "as its " +
                            "first member. Expected OrderedDict.")

        for param in grads:
            assert param in params
        for param in params:
            assert param in grads

        lr_scalers = model.get_lr_scalers()

        for key in lr_scalers:
            if key not in params:
                raise ValueError("Tried to scale the learning rate on " +\
                        str(key)+" which is not an optimization parameter.")

        # The cost is not allowed to contribute updates of its own here.
        assert len(updates.keys()) == 0

        def get_func(learn_discriminator, learn_generator):
            """
            Compile the update function for exactly one of the two players.
            """

            updates = OrderedDict()

            assert (learn_discriminator or learn_generator) and not (learn_discriminator and learn_generator)

            if learn_discriminator:
                cur_params = model.discriminator.get_params()
            else:
                cur_params = model.generator.get_params()

            # Restrict gradients to the player being trained.
            cur_grads = OrderedDict()
            for param in cur_params:
                cur_grads[param] = grads[param]

            for param in grads:
                if grads[param].name is None and cost_value is not None:
                    grads[param].name = ('grad(%(costname)s, %(paramname)s)' %
                                         {'costname': cost_value.name,
                                          'paramname': param.name})
                assert grads[param].dtype == param.dtype

            cur_lr_scalers = OrderedDict()
            for param in cur_params:
                if param in lr_scalers:
                    lr_scaler = lr_scalers[param]
                    cur_lr_scalers[param] = lr_scaler

            log.info('Parameter and initial learning rate summary:')
            for param in cur_params:
                param_name = param.name
                if param_name is None:
                    param_name = 'anon_param'
                lr = learning_rate.get_value() * cur_lr_scalers.get(param,1.)
                log.info('\t' + param_name + ': ' + str(lr))

            if self.learning_rule:
                updates.update(self.learning_rule.get_updates(
                    learning_rate, cur_grads, cur_lr_scalers))
            else:
                # Use standard SGD updates with fixed learning rate.
                # Only the current player's parameters are updated; the
                # original iterated over *all* params here, so without a
                # learning rule both compiled functions moved both players.
                updates.update(OrderedDict(safe_zip(
                    cur_params,
                    [param - learning_rate *
                     cur_lr_scalers.get(param, 1.) * cur_grads[param]
                     for param in cur_params])))

            for param in cur_params:
                if updates[param].name is None:
                    updates[param].name = 'sgd_update(' + param.name + ')'
            model.modify_updates(updates)
            for param in cur_params:
                update = updates[param]
                if update.name is None:
                    update.name = 'censor(sgd_update(' + param.name + '))'
                for update_val in get_debug_values(update):
                    if np.any(np.isinf(update_val)):
                        raise ValueError("debug value of %s contains infs" %
                                update.name)
                    if np.any(np.isnan(update_val)):
                        raise ValueError("debug value of %s contains nans" %
                                update.name)

            with log_timing(log, 'Compiling sgd_update'):
                return function(theano_args,
                                updates=updates,
                                name='sgd_update',
                                on_unused_input='ignore',
                                mode=self.theano_function_mode)
        self.d_func = get_func(1, 0)
        self.g_func = get_func(0, 1)

    def train(self, dataset):
        """
        Runs one epoch of SGD training on the specified dataset.

        Each minibatch updates either the discriminator or the generator:
        `discriminator_steps` discriminator updates are followed by one
        generator update, and the position in that cycle persists across
        calls.

        Parameters
        ----------
        dataset : Dataset
        """
        if not hasattr(self, 'd_func'):
            raise Exception("train called without first calling setup")

        # Make sure none of the parameters have bad values
        for param in self.params:
            value = param.get_value(borrow=True)
            if np.any(np.isnan(value)) or np.any(np.isinf(value)):
                raise Exception("NaN in " + param.name)

        self.first = False
        rng = self.rng
        if not is_stochastic(self.train_iteration_mode):
            rng = None

        data_specs = self.cost.get_data_specs(self.model)

        # The iterator should be built from flat data specs, so it returns
        # flat, non-redundent tuples of data.
        mapping = DataSpecsMapping(data_specs)
        space_tuple = mapping.flatten(data_specs[0], return_tuple=True)
        source_tuple = mapping.flatten(data_specs[1], return_tuple=True)
        if len(space_tuple) == 0:
            # No data will be returned by the iterator, and it is impossible
            # to know the size of the actual batch.
            # It is not decided yet what the right thing to do should be.
            raise NotImplementedError("Unable to train with SGD, because "
                    "the cost does not actually use data from the data set. "
                    "data_specs: %s" % str(data_specs))
        flat_data_specs = (CompositeSpace(space_tuple), source_tuple)

        iterator = dataset.iterator(mode=self.train_iteration_mode,
                batch_size=self.batch_size,
                data_specs=flat_data_specs, return_tuple=True,
                rng = rng, num_batches = self.batches_per_iter)

        on_load_batch = self.on_load_batch
        i = self.i
        for batch in iterator:
            for callback in on_load_batch:
                callback(*batch)
            if i == self.discriminator_steps:
                # The discriminator has had its share of steps; spend this
                # batch on the generator and restart the cycle.
                self.g_func(*batch)
                i = 0
            else:
                self.d_func(*batch)
                i += 1
            # iterator might return a smaller batch if dataset size
            # isn't divisible by batch_size
            # Note: if data_specs[0] is a NullSpace, there is no way to know
            # how many examples would actually have been in the batch,
            # since it was empty, so actual_batch_size would be reported as 0.
            actual_batch_size = flat_data_specs[0].np_batch_size(batch)
            self.monitor.report_batch(actual_batch_size)
            for callback in self.update_callbacks:
                callback(self)

        # Make sure none of the parameters have bad values
        for param in self.params:
            value = param.get_value(borrow=True)
            if np.any(np.isnan(value)) or np.any(np.isinf(value)):
                raise Exception("NaN in " + param.name)
        self.i = i

    def continue_learning(self, model):
        """
        Returns True if the algorithm should continue running, or False
        if it has reached convergence / started overfitting and should
        stop.

        Parameters
        ----------
        model : a Model instance
            Unused; the model stored by `setup` is consulted instead.
        """
        if self.termination_criterion is None:
            return True
        else:
            return self.termination_criterion.continue_learning(self.model)
class MonitorBasedLRAdjuster(TrainExtension):
    """
    A TrainExtension that uses the on_monitor callback to adjust
    the learning rate on each epoch. It pulls out a channel
    from the model's monitor and adjusts the learning rate
    based on what happened to the monitoring channel on the last
    epoch. If the channel is greater than high_trigger times
    its previous value, the learning rate will be scaled by
    shrink_amt (which should be < 1 for this scheme to make
    sense). The idea is that in this case the learning algorithm
    is overshooting the bottom of the objective function.

    If the objective is less than high_trigger but
    greater than low_trigger times its previous value, the
    learning rate will be scaled by grow_amt (which should be > 1
    for this scheme to make sense). The idea is that the learning
    algorithm is making progress but at too slow of a rate.

    Parameters
    ----------
    high_trigger : float, optional
        See class-level docstring
    shrink_amt : float, optional
        See class-level docstring
    low_trigger : float, optional
        See class-level docstring
    grow_amt : float, optional
        See class-level docstring
    min_lr : float, optional
        All updates to the learning rate are clipped to be at least
        this value.
    max_lr : float, optional
        All updates to the learning rate are clipped to be at most
        this value.
    dataset_name : str, optional
        If specified, use dataset_name + "_objective" as the channel
        to guide the learning rate adaptation.
    channel_name : str, optional
        If specified, use channel_name as the channel to guide the
        learning rate adaptation. Conflicts with dataset_name.
        If neither dataset_name nor channel_name is specified, uses
        "objective"
    """

    def __init__(self, high_trigger=1., shrink_amt=.99,
                 low_trigger=.99, grow_amt=1.01,
                 min_lr = 1e-7, max_lr = 1.,
                 dataset_name=None, channel_name=None):
        self.high_trigger = high_trigger
        self.shrink_amt = shrink_amt
        self.low_trigger = low_trigger
        self.grow_amt = grow_amt
        self.min_lr = min_lr
        self.max_lr = max_lr
        self.dataset_name = None
        # An explicit channel_name wins; dataset_name is only recorded when
        # it is what determined the channel.
        if channel_name is not None:
            self.channel_name = channel_name
        else:
            if dataset_name is not None:
                self.channel_name = dataset_name + '_objective'
                self.dataset_name = dataset_name
            else:
                self.channel_name = None

    def on_monitor(self, model, dataset, algorithm):
        """
        Adjusts the learning rate based on the contents of model.monitor

        Parameters
        ----------
        model : a Model instance
            Ignored; `algorithm.model` is used instead.
        dataset : Dataset
            Unused.
        algorithm : training algorithm
            Must expose `model`, a `learning_rate` shared variable and,
            when the channel has to be guessed, `monitoring_dataset`.

        Raises
        ------
        ValueError
            If no usable monitoring channel can be determined, or the
            selected channel has no recorded values.
        """
        model = algorithm.model
        lr = algorithm.learning_rate
        current_learning_rate = lr.get_value()
        assert hasattr(model, 'monitor'), ("no monitor associated with "
                + str(model))
        monitor = model.monitor
        monitor_channel_specified = True

        if self.channel_name is None:
            # Nothing was configured: fall back to the unique channel whose
            # name ends in "objective", if there is exactly one.
            monitor_channel_specified = False
            channels = [elem for elem in monitor.channels
                        if elem.endswith("objective")]
            if len(channels) < 1:
                raise ValueError("There are no monitoring channels that end "
                                 "with \"objective\". Please specify either "
                                 "channel_name or dataset_name.")
            elif len(channels) > 1:
                datasets = algorithm.monitoring_dataset.keys()
                # "that " previously lacked its trailing space, producing
                # "thatend" in the message.
                raise ValueError("There are multiple monitoring channels that "
                                 "end with \"_objective\". The list of available "
                                 "datasets are: " +
                                 str(datasets) + " . Please specify either "
                                 "channel_name or dataset_name in the "
                                 "MonitorBasedLRAdjuster constructor to "
                                 'disambiguate.')
            else:
                self.channel_name = channels[0]
                warnings.warn('The channel that has been chosen for '
                              'monitoring is: ' +
                              str(self.channel_name) + '.')

        try:
            v = monitor.channels[self.channel_name].val_record
        except KeyError:
            err_input = ''
            if monitor_channel_specified:
                if self.dataset_name:
                    err_input = 'The dataset_name \'' + str(
                            self.dataset_name) + '\' is not valid.'
                else:
                    err_input = 'The channel_name \'' + str(
                            self.channel_name) + '\' is not valid.'
            err_message = 'There is no monitoring channel named \'' + \
                    str(self.channel_name) + '\'. You probably need to ' + \
                    'specify a valid monitoring channel by using either ' + \
                    'dataset_name or channel_name in the ' + \
                    'MonitorBasedLRAdjuster constructor. ' + err_input
            raise ValueError(err_message)

        if len(v) < 1:
            if monitor.dataset is None:
                assert len(v) == 0
                raise ValueError("You're trying to use a monitor-based "
                                 "learning rate adjustor but the monitor has no "
                                 "entries because you didn't specify a "
                                 "monitoring dataset.")

            # "entries " previously lacked its trailing space, producing
            # "entriesyet" in the message.
            raise ValueError("For some reason there are no monitor entries "
                             "yet the MonitorBasedLRAdjuster has been "
                             "called. This should never happen. The Train"
                             " object should call the monitor once on "
                             "initialization, then call the callbacks. "
                             "It seems you are either calling the "
                             "callback manually rather than as part of a "
                             "training algorithm, or there is a problem "
                             "with the Train object.")
        if len(v) == 1:
            #only the initial monitoring has happened
            #no learning has happened, so we can't adjust the learning rate yet
            #just do nothing
            return

        rval = current_learning_rate

        log.info("monitoring channel is {0}".format(self.channel_name))

        if v[-1] > self.high_trigger * v[-2]:
            rval *= self.shrink_amt
            log.info("shrinking learning rate to %f" % rval)
        elif v[-1] > self.low_trigger * v[-2]:
            rval *= self.grow_amt
            log.info("growing learning rate to %f" % rval)

        # Clip to [min_lr, max_lr].
        rval = max(self.min_lr, rval)
        rval = min(self.max_lr, rval)

        lr.set_value(np.cast[lr.dtype](rval))
class PatienceBasedTermCrit(object):
    """
    A monitor-based termination criterion using a geometrically increasing
    amount of patience. If the selected channel has decreased by a certain
    proportion when comparing to the lowest value seen yet, the patience is
    set to a factor of the number of examples seen, which by default
    (patience_increase=2.) ensures the model has seen as many examples as the
    number of examples that lead to the lowest value before concluding a local
    optima has been reached.

    Note: Technically, the patience corresponds to a number of epochs to be
    independent of the size of the dataset, so be aware of that when choosing
    initial_patience.

    Parameters
    ----------
    prop_decrease : float
        The factor X in the (1 - X) * best_value threshold
    initial_patience : int
        Minimal number of epochs the model has to run before it can stop
    patience_increase : float, optional
        The factor X in the patience = X * n_iter update.
    channel_name : string, optional
        Name of the channel to examine. If None and the monitor
        has only one channel, this channel will be used; otherwise, an
        error will be raised.
    """
    def __init__(self, prop_decrease, initial_patience,
                 patience_increase=2., channel_name=None):
        self._channel_name = channel_name
        self.prop_decrease = prop_decrease
        self.patience = initial_patience
        self.best_value = np.inf
        self.patience_increase = patience_increase

    def __call__(self, model):
        """
        Decide whether training should keep going.

        Training continues for as long as the number of recorded values of
        the channel is smaller than the current patience.

        Parameters
        ----------
        model : Model
            The model used in the experiment and from which the monitor used
            in the termination criterion will be extracted.

        Returns
        -------
        bool
            True if the optimization should continue, False if it should
            stop.

        Raises
        ------
        ValueError
            If no channel name was given and the monitor does not have
            exactly one channel.
        """
        monitor = model.monitor
        # In the case the monitor has only one channel, the channel_name can
        # be omitted and the criterion will examine the only channel
        # available. However, if the monitor has multiple channels, leaving
        # the channel_name unspecified will raise an error.
        if self._channel_name is None:
            if len(monitor.channels) != 1:
                raise ValueError("Only single-channel monitors are supported "
                                 "for channel_name == None")
            # list() so this also works when values() returns a view rather
            # than a list.
            v = list(monitor.channels.values())[0].val_record
        else:
            v = monitor.channels[self._channel_name].val_record
        # If the channel value decrease is higher than the threshold, we
        # update the best value to this value and we update the patience.
        if v[-1] < self.best_value * (1. - self.prop_decrease):
            # Using the max between actual patience and updated patience
            # ensures that the model will run for at least the initial
            # patience and that it would behave correctly if the user
            # chooses a dumb value (i.e. less than 1)
            self.patience = max(self.patience, len(v) * self.patience_increase)
            self.best_value = v[-1]

        return len(v) < self.patience
less than 1) + self.patience = max(self.patience, len(v) * self.patience_increase) + self.best_value = v[-1] + + return len(v) < self.patience + + +class AnnealedLearningRate(object): + """ + This is a callback for the SGD algorithm rather than the Train object. + This anneals the learning rate to decrease as 1/t where t is the number + of gradient descent updates done so far. Use OneOverEpoch as Train object + callback if you would prefer 1/t where t is epochs. + + Parameters + ---------- + anneal_start : int + The epoch on which to begin annealing + """ + def __init__(self, anneal_start): + self._initialized = False + self._count = 0 + self._anneal_start = anneal_start + + def __call__(self, algorithm): + """ + Updates the learning rate according to the annealing schedule. + + Parameters + ---------- + algorithm : WRITEME + """ + if not self._initialized: + self._base = algorithm.learning_rate.get_value() + self._count += 1 + algorithm.learning_rate.set_value(self.current_learning_rate()) + + def current_learning_rate(self): + """ + Returns the current desired learning rate according to the + annealing schedule. + """ + return self._base * min(1, self._anneal_start / self._count) + +class ExponentialDecay(object): + """ + This is a callback for the `SGD` algorithm rather than the `Train` object. + This anneals the learning rate by dividing by decay_factor after each + gradient descent step. It will not shrink the learning rate beyond + `min_lr`. 
+ + Parameters + ---------- + decay_factor : float + The learning rate at step t is given by + `init_learning_rate / (decay_factor ** t)` + min_lr : float + The learning rate will be clipped to be at least this value + """ + + def __init__(self, decay_factor, min_lr): + if isinstance(decay_factor, str): + decay_factor = float(decay_factor) + if isinstance(min_lr, str): + min_lr = float(min_lr) + assert isinstance(decay_factor, float) + assert isinstance(min_lr, float) + self.__dict__.update(locals()) + del self.self + self._count = 0 + self._min_reached = False + + def __call__(self, algorithm): + """ + Updates the learning rate according to the exponential decay schedule. + + Parameters + ---------- + algorithm : SGD + The SGD instance whose `learning_rate` field should be modified. + """ + if self._count == 0: + self._base_lr = algorithm.learning_rate.get_value() + self._count += 1 + + if not self._min_reached: + # If we keep on executing the exponentiation on each mini-batch, + # we will eventually get an OverflowError. So make sure we + # only do the computation until min_lr is reached. + new_lr = self._base_lr / (self.decay_factor ** self._count) + if new_lr <= self.min_lr: + self._min_reached = True + new_lr = self.min_lr + else: + new_lr = self.min_lr + + new_lr = np.cast[config.floatX](new_lr) + algorithm.learning_rate.set_value(new_lr) + +class LinearDecay(object): + """ + This is a callback for the SGD algorithm rather than the Train object. + This anneals the learning rate to decay_factor times of the initial value + during time start till saturate. 
+ + Parameters + ---------- + start : int + The step at which to start decreasing the learning rate + saturate : int + The step at which to stop decreating the learning rate + decay_factor : float + `final learning rate = decay_factor * initial learning rate` + """ + + def __init__(self, start, saturate, decay_factor): + if isinstance(decay_factor, str): + decay_factor = float(decay_factor) + if isinstance(start, str): + start = float(start) + if isinstance(saturate, str): + saturate = float(saturate) + assert isinstance(decay_factor, float) + assert isinstance(start, (py_integer_types, py_float_types)) + assert isinstance(saturate, (py_integer_types, py_float_types)) + assert saturate > start + assert start > 0 + self.__dict__.update(locals()) + del self.self + self._count = 0 + + def __call__(self, algorithm): + """ + Adjusts the learning rate according to the linear decay schedule + + Parameters + ---------- + algorithm : WRITEME + """ + if self._count == 0: + self._base_lr = algorithm.learning_rate.get_value() + self._step = ((self._base_lr - self._base_lr * self.decay_factor) / + (self.saturate - self.start + 1)) + self._count += 1 + if self._count >= self.start: + if self._count < self.saturate: + new_lr = self._base_lr - self._step * (self._count + - self.start + 1) + else: + new_lr = self._base_lr * self.decay_factor + else: + new_lr = self._base_lr + assert new_lr > 0 + new_lr = np.cast[config.floatX](new_lr) + algorithm.learning_rate.set_value(new_lr) + + +def MomentumAdjustor(final_momentum, start, saturate): + """ + Deprecated class used with the deprecated init_momentum argument. + Use learning_rule.MomentumAdjustor instead. + + Parameters + ---------- + final_momentum : WRITEME + start : WRITEME + saturate : WRITEME + """ + warnings.warn("sgd.MomentumAdjustor interface is deprecated and will " + "become officially unsupported as of May 9, 2014. 
Please use " + "`learning_rule.MomentumAdjustor` instead.") + return LRMomentumAdjustor(final_momentum, start, saturate) + + +class OneOverEpoch(TrainExtension): + """ + Scales the learning rate like one over # epochs + + Parameters + ---------- + start : int + The epoch on which to start shrinking the learning rate + half_life : int, optional + How many epochs after start it will take for the learning rate to lose + half its value for the first time (to lose the next half of its value + will take twice as long) + min_lr : float, optional + The minimum value the learning rate can take on + """ + def __init__(self, start, half_life = None, min_lr = 1e-6): + self.__dict__.update(locals()) + del self.self + self._initialized = False + self._count = 0 + assert start >= 0 + if half_life is None: + self.half_life = start + 1 + else: + assert half_life > 0 + + def on_monitor(self, model, dataset, algorithm): + """ + Adjusts the learning rate according to the decay schedule. + + Parameters + ---------- + model : a Model instance + dataset : Dataset + algorithm : WRITEME + """ + + if not self._initialized: + self._init_lr = algorithm.learning_rate.get_value() + if self._init_lr < self.min_lr: + raise ValueError("The initial learning rate is smaller than " + + "the minimum allowed learning rate.") + self._initialized = True + self._count += 1 + algorithm.learning_rate.set_value(np.cast[config.floatX]( + self.current_lr())) + + def current_lr(self): + """ + Returns the learning rate currently desired by the decay schedule. 
+ """ + if self._count < self.start: + scale = 1 + else: + scale = float(self.half_life) / float(self._count - self.start + + self.half_life) + lr = self._init_lr * scale + clipped = max(self.min_lr, lr) + return clipped + +class LinearDecayOverEpoch(TrainExtension): + """ + Scales the learning rate linearly on each epochs + + Parameters + ---------- + start : int + The epoch on which to start shrinking the learning rate + saturate : int + The epoch to saturate the shrinkage + decay_factor : float + The final value would be initial learning rate times decay_factor + """ + + def __init__(self, start, saturate, decay_factor): + self.__dict__.update(locals()) + del self.self + self._initialized = False + self._count = 0 + assert isinstance(decay_factor, float) + assert isinstance(start, (py_integer_types, py_float_types)) + assert isinstance(saturate, (py_integer_types, py_float_types)) + assert saturate > start + assert start >= 0 + assert saturate >= start + + def on_monitor(self, model, dataset, algorithm): + """ + Updates the learning rate based on the linear decay schedule. + + Parameters + ---------- + model : a Model instance + dataset : Dataset + algorithm : WRITEME + """ + if not self._initialized: + self._init_lr = algorithm.learning_rate.get_value() + self._step = ((self._init_lr - self._init_lr * self.decay_factor) / + (self.saturate - self.start + 1)) + self._initialized = True + self._count += 1 + algorithm.learning_rate.set_value(np.cast[config.floatX]( + self.current_lr())) + + def current_lr(self): + """ + Returns the learning rate currently desired by the decay schedule. + """ + if self._count >= self.start: + if self._count < self.saturate: + new_lr = self._init_lr - self._step * (self._count + - self.start + 1) + else: + new_lr = self._init_lr * self.decay_factor + else: + new_lr = self._init_lr + assert new_lr > 0 + return new_lr + +class _PolyakWorker(object): + """ + Only to be used by the PolyakAveraging TrainingCallback below. 
+ Do not use directly. + A callback for the SGD class. + + Parameters + ---------- + model : a Model + The model whose parameters we want to train with Polyak averaging + """ + + def __init__(self, model): + avg_updates = OrderedDict() + t = sharedX(1.) + self.param_to_mean = OrderedDict() + for param in model.get_params(): + mean = sharedX(param.get_value()) + assert type(mean) == type(param) + self.param_to_mean[param] = mean + avg_updates[mean] = mean - (mean - param) / t + avg_updates[t] = t + 1. + self.avg = function([], updates = avg_updates) + + def __call__(self, algorithm): + """ + To be called after each SGD step. + Updates the Polyak averaged-parameters for this model + + Parameters + ---------- + algorithm : WRITEME + """ + self.avg() + +class PolyakAveraging(TrainExtension): + """ + See "A Tutorial on Stochastic Approximation Algorithms + for Training Restricted Boltzmann Machines and + Deep Belief Nets" by Kevin Swersky et al + + This functionality is still a work in progress. Currently, + your model needs to implement "add_polyak_channels" to + use it. + + The problem is that Polyak averaging shouldn't modify + the model parameters. It should keep a second copy + that it averages in the background. This second copy + doesn't get to come back in and affect the learning process + though. + + (IG tried having the second copy get pushed back into + the model once per epoch, but this turned out to be + harmful, at least in limited tests) + + So we need a cleaner interface for monitoring the + averaged copy of the parameters, and we need to make + sure the saved model at the end uses the averaged + parameters, not the parameters used for computing + the gradients during training. 
+ + TODO: make use of the new on_save callback instead + of duplicating Train's save_freq flag + + Parameters + ---------- + start : int + The epoch after which to start averaging (0 = start averaging + immediately) + save_path : str, optional + WRITEME + save_freq : int, optional + WRITEME + + Notes + ----- + This is usually used with a fixed, rather than annealed learning + rate. It may be used in conjunction with momentum. + """ + + def __init__(self, start, save_path=None, save_freq=1): + self.__dict__.update(locals()) + del self.self + self._count = 0 + assert isinstance(start, py_integer_types) + assert start >= 0 + + def on_monitor(self, model, dataset, algorithm): + """ + Make sure Polyak-averaged model gets monitored. + Save the model if necessary. + + Parameters + ---------- + model : a Model instance + dataset : Dataset + algorithm : WRITEME + """ + if self._count == self.start: + self._worker = _PolyakWorker(model) + algorithm.update_callbacks.append(self._worker) + #HACK + try: + model.add_polyak_channels(self._worker.param_to_mean, + algorithm.monitoring_dataset) + except AttributeError: + pass + elif self.save_path is not None and self._count > self.start and \ + self._count % self.save_freq == 0: + saved_params = OrderedDict() + for param in model.get_params(): + saved_params[param] = param.get_value() + param.set_value(self._worker.param_to_mean[param].get_value()) + serial.save(self.save_path, model) + for param in model.get_params(): + param.set_value(saved_params[param]) + self._count += 1 diff --git a/sgd_alt.py b/sgd_alt.py new file mode 100644 index 0000000..bb6ffef --- /dev/null +++ b/sgd_alt.py @@ -0,0 +1,1151 @@ +""" +Copy of pylearn2's sgd.py, hacked to support alternating between +epochs of updating only the discriminator and epochs of updating +both discriminator and generator. Ideally this would +be accomplished using pylearn2's FixedVarDescr implementation, +but it is currently not very well supported. 
+""" +from __future__ import division + +__authors__ = "Ian Goodfellow" +__copyright__ = "Copyright 2010-2012, Universite de Montreal" +__credits__ = ["Ian Goodfellow, David Warde-Farley"] +__license__ = "3-clause BSD" +__maintainer__ = "David Warde-Farley" +__email__ = "pylearn-dev@googlegroups" + +import logging +import warnings +import numpy as np + +from theano import config +from theano import function +from theano.compat.python2x import OrderedDict +from theano.gof.op import get_debug_values + +from pylearn2.monitor import Monitor +from pylearn2.space import CompositeSpace, NullSpace +from pylearn2.train_extensions import TrainExtension +from pylearn2.training_algorithms.training_algorithm import TrainingAlgorithm +from pylearn2.training_algorithms.learning_rule import Momentum +from pylearn2.training_algorithms.learning_rule import MomentumAdjustor \ + as LRMomentumAdjustor +from pylearn2.utils.iteration import is_stochastic, has_uniform_batch_size +from pylearn2.utils import py_integer_types, py_float_types +from pylearn2.utils import safe_zip +from pylearn2.utils import serial +from pylearn2.utils import sharedX +from pylearn2.utils.data_specs import DataSpecsMapping +from pylearn2.utils.timing import log_timing +from pylearn2.utils.rng import make_np_rng + + +log = logging.getLogger(__name__) + + +class SGD(TrainingAlgorithm): + """ + SGD = (Minibatch) Stochastic Gradient Descent. + A TrainingAlgorithm that does stochastic gradient descent on minibatches + of training examples. + + For theoretical background on this algorithm, see Yoshua Bengio's machine + learning course notes on the subject: + + http://www.iro.umontreal.ca/~pift6266/H10/notes/gradient.html + + Parameters + ---------- + learning_rate : float + The learning rate to use. Train object callbacks can change the + learning rate after each epoch. SGD update_callbacks can change + it after each minibatch. 
+ cost : pylearn2.costs.cost.Cost, optional + Cost object specifying the objective function to be minimized. + Optionally, may be None. In this case, SGD will call the model's + get_default_cost method to obtain the objective function. + batch_size : int, optional + The size of the batch to be used. + If not specified, the model will be asked for the batch size, so + you must have specified the batch size there. + (Some models are rigidly defined to only work with one batch size) + monitoring_batch_size : int, optional + The size of the monitoring batches. + monitoring_batches : int, optional + At the start of each epoch, we run "monitoring", to evaluate + quantities such as the validation set error. + monitoring_batches, if specified, determines the number of batches + to draw from the iterator for each monitoring dataset. + Unnecessary if not using monitoring or if `monitor_iteration_mode` + is 'sequential' and `batch_size` is specified (number of + batches will be calculated based on full dataset size). + TODO: make it possible to specify different monitoring_batches + for each monitoring dataset. The Monitor itself already supports + this. + monitoring_dataset : Dataset or dictionary, optional + If not specified, no monitoring is used. + If specified to be a Dataset, monitor on that Dataset. + If specified to be dictionary, the keys should be string names + of datasets, and the values should be Datasets. All monitoring + channels will be computed for all monitoring Datasets and will + have the dataset name and an underscore prepended to them. + monitor_iteration_mode : str, optional + The iteration mode used to iterate over the examples in all + monitoring datasets. If not specified, defaults to 'sequential'. + TODO: make it possible to specify different modes for different + datasets. + termination_criterion : instance of \ + pylearn2.termination_criteria.TerminationCriterion, optional + + Used to determine when the algorithm should stop running. 
+ If not specified, runs forever--or more realistically, until + external factors halt the python process (Kansas 1977). + update_callbacks : list, optional + If specified, each member of the list should be a callable that + accepts an SGD instance as its only argument. + All callbacks will be called with this SGD instance after each + SGD step. + learning_rule : training_algorithms.learning_rule.LearningRule, optional + A learning rule computes the new parameter values given old + parameters and first-order gradients. If learning_rule is None, + sgd.SGD will update parameters according to the standard SGD + learning rule: + + .. code-block:: none + + param := param - learning_rate * d cost / d param + + This argument allows more sophisticated learning rules, such + as SGD with momentum. + init_momentum : float, **DEPRECATED** option + Use learning_rule instead. + If None, does not use momentum otherwise, use momentum and + initialize the momentum coefficient to init_momentum. Callbacks + can change this over time just like the learning rate. If the + gradient is the same on every step, then the update taken by the + SGD algorithm is scaled by a factor of 1/(1-momentum). See + section 9 of Geoffrey Hinton's "A Practical Guide to Training + Restricted Boltzmann Machines" for details. + set_batch_size : bool, optional + Defaults to False. + If True, and batch_size conflicts with model.force_batch_size, + will call model.set_batch_size(batch_size) in an attempt to + change model.force_batch_size + train_iteration_mode : str, optional + Defaults to 'shuffled_sequential'. + The iteration mode to use for iterating through training examples. + batches_per_iter : int, optional + The number of batches to draw from the iterator over training + examples. + If iteration mode is 'sequential' or 'shuffled_sequential', this + is unnecessary; when unspecified we will iterate over all examples. 
+ theano_function_mode : a valid argument to theano.function's \ + 'mode' parameter, optional + + The theano mode to compile the updates function with. Note that + pylearn2 includes some wraplinker modes that are not bundled with + theano. See pylearn2.devtools. These extra modes let you do + things like check for NaNs at every step, or record md5 digests + of all computations performed by the update function to help + isolate problems with nondeterminism. + monitoring_costs : list, optional + a list of Cost instances. The Monitor will also include all + channels defined by these Costs, even though we don't train + using them. + seed : valid argument to np.random.RandomState, optional + The seed used for the random number generate to be passed to the + training dataset iterator (if any) + """ + def __init__(self, learning_rate, cost=None, batch_size=None, + monitoring_batch_size=None, monitoring_batches=None, + monitoring_dataset=None, monitor_iteration_mode='sequential', + termination_criterion=None, update_callbacks=None, + learning_rule = None, init_momentum = None, + set_batch_size = False, + train_iteration_mode = None, batches_per_iter=None, + theano_function_mode = None, monitoring_costs=None, + seed=[2012, 10, 5], discriminator_steps=1): + + self.discriminator_steps = discriminator_steps + self.train_generator = 0 + + if isinstance(cost, (list, tuple, set)): + raise TypeError("SGD no longer supports using collections of " + + "Costs to represent a sum of Costs. Use " + + "pylearn2.costs.cost.SumOfCosts instead.") + + if init_momentum: + warnings.warn("init_momentum interface is deprecated and will " + "become officially unsuported as of May 9, 2014. Please use the " + "`learning_rule` parameter instead, providing an object of type " + "`pylearn2.training_algorithms.learning_rule.Momentum` instead") + # Convert to new interface under the hood. 
+ self.learning_rule = Momentum(init_momentum) + else: + self.learning_rule = learning_rule + + self.learning_rate = sharedX(learning_rate, 'learning_rate') + self.cost = cost + self.batch_size = batch_size + self.set_batch_size = set_batch_size + self.batches_per_iter = batches_per_iter + self._set_monitoring_dataset(monitoring_dataset) + self.monitoring_batch_size = monitoring_batch_size + self.monitoring_batches = monitoring_batches + self.monitor_iteration_mode = monitor_iteration_mode + if monitoring_dataset is None: + if monitoring_batch_size is not None: + raise ValueError("Specified a monitoring batch size " + + "but not a monitoring dataset.") + if monitoring_batches is not None: + raise ValueError("Specified an amount of monitoring batches " + + "but not a monitoring dataset.") + self.termination_criterion = termination_criterion + self._register_update_callbacks(update_callbacks) + if train_iteration_mode is None: + train_iteration_mode = 'shuffled_sequential' + self.train_iteration_mode = train_iteration_mode + self.first = True + self.rng = make_np_rng(seed, which_method=["randn","randint"]) + self.theano_function_mode = theano_function_mode + self.monitoring_costs = monitoring_costs + + def setup(self, model, dataset): + """ + Compiles the theano functions needed for the train method. 
+ + Parameters + ---------- + model : a Model instance + dataset : Dataset + """ + if self.cost is None: + self.cost = model.get_default_cost() + + inf_params = [param for param in model.get_params() + if np.any(np.isinf(param.get_value()))] + if len(inf_params) > 0: + raise ValueError("These params are Inf: "+str(inf_params)) + if any([np.any(np.isnan(param.get_value())) + for param in model.get_params()]): + nan_params = [param for param in model.get_params() + if np.any(np.isnan(param.get_value()))] + raise ValueError("These params are NaN: "+str(nan_params)) + self.model = model + + self._synchronize_batch_size(model) + model._test_batch_size = self.batch_size + self.monitor = Monitor.get_monitor(model) + self.monitor._sanity_check() + + # test if force batch size and batch size + if getattr(model, "force_batch_size", False) and \ + any(dataset.get_design_matrix().shape[0] % self.batch_size != 0 for + dataset in self.monitoring_dataset.values()) and \ + not has_uniform_batch_size(self.monitor_iteration_mode): + + raise ValueError("Dataset size is not a multiple of batch size." + "You should set monitor_iteration_mode to " + "even_sequential, even_shuffled_sequential or " + "even_batchwise_shuffled_sequential") + + data_specs = self.cost.get_data_specs(self.model) + mapping = DataSpecsMapping(data_specs) + space_tuple = mapping.flatten(data_specs[0], return_tuple=True) + source_tuple = mapping.flatten(data_specs[1], return_tuple=True) + + # Build a flat tuple of Theano Variables, one for each space. + # We want that so that if the same space/source is specified + # more than once in data_specs, only one Theano Variable + # is generated for it, and the corresponding value is passed + # only once to the compiled Theano function. 
+ theano_args = [] + for space, source in safe_zip(space_tuple, source_tuple): + name = '%s[%s]' % (self.__class__.__name__, source) + arg = space.make_theano_batch(name=name, + batch_size=self.batch_size) + theano_args.append(arg) + theano_args = tuple(theano_args) + + # Methods of `self.cost` need args to be passed in a format compatible + # with data_specs + nested_args = mapping.nest(theano_args) + fixed_var_descr = self.cost.get_fixed_var_descr(model, nested_args) + self.on_load_batch = fixed_var_descr.on_load_batch + + cost_value = self.cost.expr(model, nested_args, + ** fixed_var_descr.fixed_vars) + + if cost_value is not None and cost_value.name is None: + # Concatenate the name of all tensors in theano_args !? + cost_value.name = 'objective' + + # Set up monitor to model the objective value, learning rate, + # momentum (if applicable), and extra channels defined by + # the cost + learning_rate = self.learning_rate + if self.monitoring_dataset is not None: + if (self.monitoring_batch_size is None and + self.monitoring_batches is None): + self.monitoring_batch_size = self.batch_size + self.monitoring_batches = self.batches_per_iter + self.monitor.setup(dataset=self.monitoring_dataset, + cost=self.cost, + batch_size=self.monitoring_batch_size, + num_batches=self.monitoring_batches, + extra_costs=self.monitoring_costs, + mode=self.monitor_iteration_mode) + dataset_name = self.monitoring_dataset.keys()[0] + monitoring_dataset = self.monitoring_dataset[dataset_name] + #TODO: have Monitor support non-data-dependent channels + self.monitor.add_channel(name='learning_rate', + ipt=None, + val=learning_rate, + data_specs=(NullSpace(), ''), + dataset=monitoring_dataset) + + if self.learning_rule: + self.learning_rule.add_channels_to_monitor( + self.monitor, + monitoring_dataset) + + params = list(model.get_params()) + assert len(params) > 0 + for i, param in enumerate(params): + if param.name is None: + param.name = 'sgd_params[%d]' % i + self.params = params + + + 
grads, updates = self.cost.get_gradients(model, nested_args, + ** fixed_var_descr.fixed_vars) + if not isinstance(grads, OrderedDict): + raise TypeError(str(type(self.cost)) + ".get_gradients returned " + + "something with" + str(type(grads)) + "as its " + + "first member. Expected OrderedDict.") + + for param in grads: + assert param in params + for param in params: + assert param in grads + + lr_scalers = model.get_lr_scalers() + + for key in lr_scalers: + if key not in params: + raise ValueError("Tried to scale the learning rate on " +\ + str(key)+" which is not an optimization parameter.") + + assert len(updates.keys()) == 0 + + def get_func(learn_discriminator, learn_generator, dont_you_fucking_dare_touch_the_generator=False): + + updates = OrderedDict() + + assert (learn_discriminator or learn_generator) and not (learn_discriminator and learn_generator) + + if learn_discriminator: + cur_params = model.discriminator.get_params() + else: + cur_params = model.generator.get_params() + + def check(): + for param in params: + if param not in cur_params: + assert param not in updates + + cur_grads = OrderedDict() + for param in cur_params: + cur_grads[param] = grads[param] + + for param in grads: + if grads[param].name is None and cost_value is not None: + grads[param].name = ('grad(%(costname)s, %(paramname)s)' % + {'costname': cost_value.name, + 'paramname': param.name}) + assert grads[param].dtype == param.dtype + + cur_lr_scalers = OrderedDict() + for param in cur_params: + if param in lr_scalers: + lr_scaler = lr_scalers[param] + cur_lr_scalers[param] = lr_scaler + + log.info('Parameter and initial learning rate summary:') + for param in cur_params: + param_name = param.name + if param_name is None: + param_name = 'anon_param' + lr = learning_rate.get_value() * cur_lr_scalers.get(param,1.) 
+ log.info('\t' + param_name + ': ' + str(lr)) + + updates.update(self.learning_rule.get_updates( + learning_rate, cur_grads, cur_lr_scalers)) + check() + + for param in cur_params: + if updates[param].name is None: + updates[param].name = 'sgd_update(' + param.name + ')' + check() + model.modify_updates(updates) + check() + for param in cur_params: + update = updates[param] + if update.name is None: + update.name = 'censor(sgd_update(' + param.name + '))' + for update_val in get_debug_values(update): + if np.any(np.isinf(update_val)): + raise ValueError("debug value of %s contains infs" % + update.name) + if np.any(np.isnan(update_val)): + raise ValueError("debug value of %s contains nans" % + update.name) + + check() + + if dont_you_fucking_dare_touch_the_generator: + for param in model.generator.get_params(): + assert param not in updates + + with log_timing(log, 'Compiling sgd_update'): + return function(theano_args, + updates=updates, + name='sgd_update', + on_unused_input='ignore', + mode=self.theano_function_mode) + self.d_func = get_func(1, 0, dont_you_fucking_dare_touch_the_generator=True) + self.g_func = get_func(0, 1) + + def train(self, dataset): + """ + Runs one epoch of SGD training on the specified dataset. + + Parameters + ---------- + dataset : Dataset + """ + + + if not hasattr(self, 'd_func'): + raise Exception("train called without first calling setup") + + # Make sure none of the parameters have bad values + for param in self.params: + value = param.get_value(borrow=True) + if np.any(np.isnan(value)) or np.any(np.isinf(value)): + raise Exception("NaN in " + param.name) + + self.first = False + rng = self.rng + if not is_stochastic(self.train_iteration_mode): + rng = None + + data_specs = self.cost.get_data_specs(self.model) + + # The iterator should be built from flat data specs, so it returns + # flat, non-redundent tuples of data. 
+ mapping = DataSpecsMapping(data_specs) + space_tuple = mapping.flatten(data_specs[0], return_tuple=True) + source_tuple = mapping.flatten(data_specs[1], return_tuple=True) + if len(space_tuple) == 0: + # No data will be returned by the iterator, and it is impossible + # to know the size of the actual batch. + # It is not decided yet what the right thing to do should be. + raise NotImplementedError("Unable to train with SGD, because " + "the cost does not actually use data from the data set. " + "data_specs: %s" % str(data_specs)) + flat_data_specs = (CompositeSpace(space_tuple), source_tuple) + + iterator = dataset.iterator(mode=self.train_iteration_mode, + batch_size=self.batch_size, + data_specs=flat_data_specs, return_tuple=True, + rng = rng, num_batches = self.batches_per_iter) + + + on_load_batch = self.on_load_batch + i = 0 + for batch in iterator: + for callback in on_load_batch: + callback(*batch) + if self.train_generator and i == self.discriminator_steps: + self.g_func(*batch) + i = 0 + else: + self.d_func(*batch) + i += 1 + # iterator might return a smaller batch if dataset size + # isn't divisible by batch_size + # Note: if data_specs[0] is a NullSpace, there is no way to know + # how many examples would actually have been in the batch, + # since it was empty, so actual_batch_size would be reported as 0. + actual_batch_size = flat_data_specs[0].np_batch_size(batch) + self.monitor.report_batch(actual_batch_size) + for callback in self.update_callbacks: + callback(self) + + + # Make sure none of the parameters have bad values + for param in self.params: + value = param.get_value(borrow=True) + if np.any(np.isnan(value)) or np.any(np.isinf(value)): + raise Exception("NaN in " + param.name) + + self.train_generator = not self.train_generator + + def continue_learning(self, model): + """ + Returns True if the algorithm should continue running, or False + if it has reached convergence / started overfitting and should + stop. 
class MonitorBasedLRAdjuster(TrainExtension):
    """
    A TrainExtension that uses the on_monitor callback to adjust
    the learning rate on each epoch. It pulls out a channel
    from the model's monitor and adjusts the learning rate
    based on what happened to the monitoring channel on the last
    epoch. If the channel is greater than high_trigger times
    its previous value, the learning rate will be scaled by
    shrink_amt (which should be < 1 for this scheme to make
    sense). The idea is that in this case the learning algorithm
    is overshooting the bottom of the objective function.

    If the objective is less than high_trigger but
    greater than low_trigger times its previous value, the
    learning rate will be scaled by grow_amt (which should be > 1
    for this scheme to make sense). The idea is that the learning
    algorithm is making progress but at too slow of a rate.

    Parameters
    ----------
    high_trigger : float, optional
        See class-level docstring
    shrink_amt : float, optional
        See class-level docstring
    low_trigger : float, optional
        See class-level docstring
    grow_amt : float, optional
        See class-level docstring
    min_lr : float, optional
        All updates to the learning rate are clipped to be at least
        this value.
    max_lr : float, optional
        All updates to the learning rate are clipped to be at most
        this value.
    dataset_name : str, optional
        If specified, use dataset_name + "_objective" as the channel
        to guide the learning rate adaptation.
    channel_name : str, optional
        If specified, use channel_name as the channel to guide the
        learning rate adaptation. Takes precedence over dataset_name
        when both are given.
        If neither dataset_name nor channel_name is specified, the
        single monitoring channel whose name ends with "objective"
        is used.
    """

    def __init__(self, high_trigger=1., shrink_amt=.99,
                 low_trigger=.99, grow_amt=1.01,
                 min_lr=1e-7, max_lr=1.,
                 dataset_name=None, channel_name=None):
        self.high_trigger = high_trigger
        self.shrink_amt = shrink_amt
        self.low_trigger = low_trigger
        self.grow_amt = grow_amt
        self.min_lr = min_lr
        self.max_lr = max_lr
        # dataset_name is only remembered when it was actually used to
        # build the channel name; it is reported in error messages.
        self.dataset_name = None
        if channel_name is not None:
            self.channel_name = channel_name
        else:
            if dataset_name is not None:
                self.channel_name = dataset_name + '_objective'
                self.dataset_name = dataset_name
            else:
                self.channel_name = None

    def on_monitor(self, model, dataset, algorithm):
        """
        Adjusts the learning rate based on the contents of model.monitor

        Parameters
        ----------
        model : a Model instance
            Ignored; the model is taken from `algorithm.model`.
        dataset : Dataset
            Unused.
        algorithm : object
            The training algorithm. This method reads its `model`,
            `learning_rate` (a shared variable providing `get_value`,
            `set_value` and `dtype`) and, when the channel is ambiguous,
            its `monitoring_dataset` attributes.

        Raises
        ------
        ValueError
            If no suitable monitoring channel can be determined, if the
            selected channel does not exist, or if the channel has no
            recorded values.
        """
        model = algorithm.model
        lr = algorithm.learning_rate
        current_learning_rate = lr.get_value()
        assert hasattr(model, 'monitor'), ("no monitor associated with "
                                           + str(model))
        monitor = model.monitor
        monitor_channel_specified = True

        if self.channel_name is None:
            # No channel was requested: auto-detect it, but only when the
            # choice is unambiguous.
            monitor_channel_specified = False
            channels = [elem for elem in monitor.channels
                        if elem.endswith("objective")]
            if len(channels) < 1:
                raise ValueError("There are no monitoring channels that end "
                                 "with \"objective\". Please specify either "
                                 "channel_name or dataset_name.")
            elif len(channels) > 1:
                datasets = algorithm.monitoring_dataset.keys()
                raise ValueError("There are multiple monitoring channels that "
                                 "end with \"_objective\". The list of "
                                 "available datasets are: " +
                                 str(datasets) + " . Please specify either "
                                 "channel_name or dataset_name in the "
                                 "MonitorBasedLRAdjuster constructor to "
                                 'disambiguate.')
            else:
                # The detected name is stored, so later calls reuse it.
                self.channel_name = channels[0]
                warnings.warn('The channel that has been chosen for '
                              'monitoring is: ' +
                              str(self.channel_name) + '.')

        try:
            v = monitor.channels[self.channel_name].val_record
        except KeyError:
            err_input = ''
            if monitor_channel_specified:
                if self.dataset_name:
                    err_input = 'The dataset_name \'' + str(
                        self.dataset_name) + '\' is not valid.'
                else:
                    err_input = 'The channel_name \'' + str(
                        self.channel_name) + '\' is not valid.'
            err_message = 'There is no monitoring channel named \'' + \
                str(self.channel_name) + '\'. You probably need to ' + \
                'specify a valid monitoring channel by using either ' + \
                'dataset_name or channel_name in the ' + \
                'MonitorBasedLRAdjuster constructor. ' + err_input
            raise ValueError(err_message)

        if len(v) < 1:
            if monitor.dataset is None:
                assert len(v) == 0
                raise ValueError("You're trying to use a monitor-based "
                                 "learning rate adjustor but the monitor has "
                                 "no entries because you didn't specify a "
                                 "monitoring dataset.")

            raise ValueError("For some reason there are no monitor entries "
                             "yet the MonitorBasedLRAdjuster has been "
                             "called. This should never happen. The Train"
                             " object should call the monitor once on "
                             "initialization, then call the callbacks. "
                             "It seems you are either calling the "
                             "callback manually rather than as part of a "
                             "training algorithm, or there is a problem "
                             "with the Train object.")
        if len(v) == 1:
            # only the initial monitoring has happened
            # no learning has happened, so we can't adjust the learning rate
            # yet; just do nothing
            return

        rval = current_learning_rate

        log.info("monitoring channel is {0}".format(self.channel_name))

        if v[-1] > self.high_trigger * v[-2]:
            rval *= self.shrink_amt
            log.info("shrinking learning rate to %f" % rval)
        elif v[-1] > self.low_trigger * v[-2]:
            rval *= self.grow_amt
            log.info("growing learning rate to %f" % rval)

        # Clip to [min_lr, max_lr]; max_lr wins if the bounds conflict.
        rval = max(self.min_lr, rval)
        rval = min(self.max_lr, rval)

        lr.set_value(np.cast[lr.dtype](rval))
+ """ + def __init__(self, prop_decrease, initial_patience, + patience_increase=2., channel_name=None): + self._channel_name = channel_name + self.prop_decrease = prop_decrease + self.patience = initial_patience + self.best_value = np.inf + self.patience_increase = patience_increase + + def __call__(self, model): + """ + Returns True or False depending on whether the optimization should + stop or not. The optimization should stop if it has run for a number + of epochs superior to the patience without any improvement. + + Parameters + ---------- + model : Model + The model used in the experiment and from which the monitor used + in the termination criterion will be extracted. + + Returns + ------- + bool + True or False, indicating if the optimization should stop or not. + """ + monitor = model.monitor + # In the case the monitor has only one channel, the channel_name can + # be omitted and the criterion will examine the only channel + # available. However, if the monitor has multiple channels, leaving + # the channel_name unspecified will raise an error. + if self._channel_name is None: + if len(monitor.channels) != 1: + raise ValueError("Only single-channel monitors are supported " + "for channel_name == None") + v = monitor.channels.values()[0].val_record + else: + v = monitor.channels[self._channel_name].val_record + # If the channel value decrease is higher than the threshold, we + # update the best value to this value and we update the patience. + if v[-1] < self.best_value * (1. - self.prop_decrease): + # Using the max between actual patience and updated patience + # ensures that the model will run for at least the initial + # patience and that it would behave correctly if the user + # chooses a dumb value (i.e. 
class AnnealedLearningRate(object):
    """
    This is a callback for the SGD algorithm rather than the Train object.
    This anneals the learning rate to decrease as 1/t where t is the number
    of gradient descent updates done so far. Use OneOverEpoch as Train object
    callback if you would prefer 1/t where t is epochs.

    Parameters
    ----------
    anneal_start : int
        The number of calls to this callback after which annealing
        begins. Until then the learning rate keeps its initial value.
    """
    def __init__(self, anneal_start):
        self._initialized = False
        self._count = 0
        self._anneal_start = anneal_start

    def __call__(self, algorithm):
        """
        Updates the learning rate according to the annealing schedule.

        Parameters
        ----------
        algorithm : object
            The training algorithm; its `learning_rate` attribute must
            provide `get_value` and `set_value`.
        """
        if not self._initialized:
            # Capture the base rate exactly once. Re-reading it on later
            # calls would pick up the already-annealed value and compound
            # the decay.
            self._base = algorithm.learning_rate.get_value()
            self._initialized = True
        self._count += 1
        algorithm.learning_rate.set_value(self.current_learning_rate())

    def current_learning_rate(self):
        """
        Returns the current desired learning rate according to the
        annealing schedule.
        """
        # float() forces true division; with two ints Python 2 would
        # truncate the ratio to 0 as soon as _count exceeds _anneal_start.
        return self._base * min(1, float(self._anneal_start) / self._count)
class LinearDecay(object):
    """
    Callback for the SGD algorithm (not for the Train object).

    Lowers the learning rate linearly from its initial value down to
    `decay_factor` times that value, between step `start` and step
    `saturate`. One step is one call to this callback.

    Parameters
    ----------
    start : int
        The step at which to start decreasing the learning rate
    saturate : int
        The step at which to stop decreasing the learning rate
    decay_factor : float
        `final learning rate = decay_factor * initial learning rate`
    """

    def __init__(self, start, saturate, decay_factor):
        # String arguments are converted to float before validation.
        if isinstance(decay_factor, str):
            decay_factor = float(decay_factor)
        if isinstance(start, str):
            start = float(start)
        if isinstance(saturate, str):
            saturate = float(saturate)

        numeric_types = (py_integer_types, py_float_types)
        assert isinstance(decay_factor, float)
        assert isinstance(start, numeric_types)
        assert isinstance(saturate, numeric_types)
        assert saturate > start
        assert start > 0

        self.start = start
        self.saturate = saturate
        self.decay_factor = decay_factor
        self._count = 0

    def __call__(self, algorithm):
        """
        Adjusts the learning rate according to the linear decay schedule

        Parameters
        ----------
        algorithm : object
            The training algorithm; its `learning_rate` attribute must
            provide `get_value` and `set_value`.
        """
        lr_var = algorithm.learning_rate

        if self._count == 0:
            # First call: remember the starting rate and the per-step
            # decrement.
            self._base_lr = lr_var.get_value()
            total_drop = self._base_lr - self._base_lr * self.decay_factor
            self._step = total_drop / (self.saturate - self.start + 1)
        self._count += 1

        if self._count < self.start:
            new_lr = self._base_lr
        elif self._count < self.saturate:
            steps_taken = self._count - self.start + 1
            new_lr = self._base_lr - self._step * steps_taken
        else:
            new_lr = self._base_lr * self.decay_factor

        assert new_lr > 0
        lr_var.set_value(np.cast[config.floatX](new_lr))
Please use " + "`learning_rule.MomentumAdjustor` instead.") + return LRMomentumAdjustor(final_momentum, start, saturate) + + +class OneOverEpoch(TrainExtension): + """ + Scales the learning rate like one over # epochs + + Parameters + ---------- + start : int + The epoch on which to start shrinking the learning rate + half_life : int, optional + How many epochs after start it will take for the learning rate to lose + half its value for the first time (to lose the next half of its value + will take twice as long) + min_lr : float, optional + The minimum value the learning rate can take on + """ + def __init__(self, start, half_life = None, min_lr = 1e-6): + self.__dict__.update(locals()) + del self.self + self._initialized = False + self._count = 0 + assert start >= 0 + if half_life is None: + self.half_life = start + 1 + else: + assert half_life > 0 + + def on_monitor(self, model, dataset, algorithm): + """ + Adjusts the learning rate according to the decay schedule. + + Parameters + ---------- + model : a Model instance + dataset : Dataset + algorithm : WRITEME + """ + + if not self._initialized: + self._init_lr = algorithm.learning_rate.get_value() + if self._init_lr < self.min_lr: + raise ValueError("The initial learning rate is smaller than " + + "the minimum allowed learning rate.") + self._initialized = True + self._count += 1 + algorithm.learning_rate.set_value(np.cast[config.floatX]( + self.current_lr())) + + def current_lr(self): + """ + Returns the learning rate currently desired by the decay schedule. 
class LinearDecayOverEpoch(TrainExtension):
    """
    Scales the learning rate linearly on each epoch

    Parameters
    ----------
    start : int
        The epoch on which to start shrinking the learning rate
    saturate : int
        The epoch to saturate the shrinkage
    decay_factor : float
        The final value would be initial learning rate times decay_factor
    """

    def __init__(self, start, saturate, decay_factor):
        # Validate before storing anything on the instance.
        assert isinstance(decay_factor, float)
        assert isinstance(start, (py_integer_types, py_float_types))
        assert isinstance(saturate, (py_integer_types, py_float_types))
        assert saturate > start
        assert start >= 0
        self.start = start
        self.saturate = saturate
        self.decay_factor = decay_factor
        self._initialized = False
        self._count = 0

    def on_monitor(self, model, dataset, algorithm):
        """
        Updates the learning rate based on the linear decay schedule.

        Parameters
        ----------
        model : a Model instance
            Unused.
        dataset : Dataset
            Unused.
        algorithm : object
            The training algorithm; its `learning_rate` attribute must
            provide `get_value` and `set_value`.
        """
        if not self._initialized:
            # First call: remember the starting rate and the per-call
            # decrement.
            self._init_lr = algorithm.learning_rate.get_value()
            self._step = ((self._init_lr - self._init_lr * self.decay_factor) /
                          (self.saturate - self.start + 1))
            self._initialized = True
        self._count += 1
        algorithm.learning_rate.set_value(np.cast[config.floatX](
            self.current_lr()))

    def current_lr(self):
        """
        Returns the learning rate currently desired by the decay schedule.
        """
        if self._count >= self.start:
            if self._count < self.saturate:
                new_lr = self._init_lr - self._step * (self._count
                                                       - self.start + 1)
            else:
                new_lr = self._init_lr * self.decay_factor
        else:
            new_lr = self._init_lr
        assert new_lr > 0
        return new_lr
class PolyakAveraging(TrainExtension):
    """
    See "A Tutorial on Stochastic Approximation Algorithms
    for Training Restricted Boltzmann Machines and
    Deep Belief Nets" by Kevin Swersky et al

    This functionality is still a work in progress. Currently,
    your model needs to implement "add_polyak_channels" to
    use it.

    The problem is that Polyak averaging shouldn't modify
    the model parameters. It should keep a second copy
    that it averages in the background. This second copy
    doesn't get to come back in and affect the learning process
    though.

    (IG tried having the second copy get pushed back into
    the model once per epoch, but this turned out to be
    harmful, at least in limited tests)

    So we need a cleaner interface for monitoring the
    averaged copy of the parameters, and we need to make
    sure the saved model at the end uses the averaged
    parameters, not the parameters used for computing
    the gradients during training.

    TODO: make use of the new on_save callback instead
    of duplicating Train's save_freq flag

    Parameters
    ----------
    start : int
        The epoch after which to start averaging (0 = start averaging
        immediately)
    save_path : str, optional
        If not None, the model is written to this path with the
        averaged parameters temporarily swapped in.
    save_freq : int, optional
        Once averaging has started, the model is saved on every
        `on_monitor` call whose running count is a multiple of
        `save_freq`.

    Notes
    -----
    This is usually used with a fixed, rather than annealed learning
    rate. It may be used in conjunction with momentum.
    """

    def __init__(self, start, save_path=None, save_freq=1):
        self.start = start
        self.save_path = save_path
        self.save_freq = save_freq
        self._count = 0
        assert isinstance(start, py_integer_types)
        assert start >= 0

    def on_monitor(self, model, dataset, algorithm):
        """
        Make sure Polyak-averaged model gets monitored.
        Save the model if necessary.

        Parameters
        ----------
        model : a Model instance
        dataset : Dataset
            Unused.
        algorithm : object
            The training algorithm; the averaging worker is appended to
            its `update_callbacks`, and its `monitoring_dataset` is
            passed to `model.add_polyak_channels`.
        """
        if self._count == self.start:
            self._worker = _PolyakWorker(model)
            algorithm.update_callbacks.append(self._worker)
            # HACK: models without add_polyak_channels are silently
            # tolerated (best-effort monitoring of the averaged copy).
            try:
                model.add_polyak_channels(self._worker.param_to_mean,
                                          algorithm.monitoring_dataset)
            except AttributeError:
                pass
        elif (self.save_path is not None and self._count > self.start
              and self._count % self.save_freq == 0):
            saved_params = OrderedDict()
            try:
                # Swap the averaged values in for the duration of the save.
                for param in model.get_params():
                    saved_params[param] = param.get_value()
                    param.set_value(
                        self._worker.param_to_mean[param].get_value())
                serial.save(self.save_path, model)
            finally:
                # Always restore the training parameters, even if the
                # save fails, so training never continues from the
                # averaged copy.
                for param, value in saved_params.items():
                    param.set_value(value)
        self._count += 1
", i + final = to_search.layers[i] + if 'Composite' in str(type(final)): + i = input("which") + elem = final.layers[i] + if hasattr(elem, 'layers'): + print "stepping into inner MLP" + i = -1 + to_search = elem + continue + else: + print "examining this element" + final = elem + + try: + print "Trying get_weights topo" + topo = final.get_weights_topo() + print "It worked" + success = True + except Exception: + pass + + if success: + print "Making the viewer and showing" + make_viewer(topo).show() + quit() + + try: + print "Trying get_weights" + weights = final.get_weights() + print "It worked" + success = True + except NotImplementedError: + i -= 1 # skip over SpaceConverter, etc. +print "Out of the while loop" + + +print "weights shape ", weights.shape +viewer = make_viewer(weights, is_color=weights.shape[1] % 3 == 0 and weights.shape[1] != 48*48) +print "image shape ", viewer.image.shape + +print "made viewer" + +viewer.show() + +print "executed show" diff --git a/show_inpaint_samples.py b/show_inpaint_samples.py new file mode 100644 index 0000000..a99a581 --- /dev/null +++ b/show_inpaint_samples.py @@ -0,0 +1,32 @@ +from pylearn2.utils import serial +import sys +_, model_path = sys.argv +model = serial.load(model_path) +from pylearn2.gui.patch_viewer import make_viewer +space = model.generator.get_output_space() +from pylearn2.config import yaml_parse +import numpy as np + +dataset = yaml_parse.load(model.dataset_yaml_src) +dataset = dataset.get_test_set() + +grid_shape = None + +from pylearn2.utils import sharedX +X = sharedX(dataset.get_batch_topo(100)) +samples, ignore = model.generator.inpainting_sample_and_noise(X) +samples = samples.eval() +total_dimension = space.get_total_dimension() +num_colors = 1 +if total_dimension % 3 == 0: + num_colors = 3 +w = int(np.sqrt(total_dimension / num_colors)) +from pylearn2.space import Conv2DSpace +desired_space = Conv2DSpace(shape=[w, w], num_channels=num_colors, axes=('b',0,1,'c')) +is_color = samples.shape[-1] == 3 
# Display generator samples from a saved model (Python 2 script).
# Usage: python show_samples.py <model_path>
import sys

import numpy as np

from pylearn2.config import yaml_parse
from pylearn2.gui.patch_viewer import make_viewer
from pylearn2.space import Conv2DSpace
from pylearn2.space import VectorSpace
from pylearn2.utils import serial

_, model_path = sys.argv
model = serial.load(model_path)
space = model.generator.get_output_space()

# When set, each sample is shown next to its nearest training example.
match_train = True
if match_train:
    dataset = yaml_parse.load(model.dataset_yaml_src)

grid_shape = None

if not isinstance(space, VectorSpace):
    total_dimension = space.get_total_dimension()
    num_colors = 3 if total_dimension % 3 == 0 else 1
    w = int(np.sqrt(total_dimension / num_colors))
    desired_space = Conv2DSpace(shape=[w, w], num_channels=num_colors,
                                axes=('b', 0, 1, 'c'))
    generated = model.generator.sample(100)
    samples = space.format_as(batch=generated, space=desired_space).eval()
    is_color = samples.shape[-1] == 3
else:
    # For some reason format_as from VectorSpace is not working right
    samples = model.generator.sample(100).eval()

    if match_train:
        grid_shape = (10, 20)
        num_samples, num_features = samples.shape[0], samples.shape[1]
        matched = np.zeros((num_samples * 2, num_features))
        X = dataset.X
        for i in xrange(num_samples):
            sample = samples[i, :]
            # Even rows hold the sample, odd rows its closest training
            # example by squared Euclidean distance.
            matched[2 * i, :] = sample.copy()
            j = np.argmin(np.square(X - sample).sum(axis=1))
            matched[2 * i + 1, :] = X[j, :]
        samples = matched

    last_dim = samples.shape[-1]
    is_color = last_dim % 3 == 0 and last_dim != 48 * 48

print (samples.min(), samples.mean(), samples.max())
# Hack for detecting MNIST [0, 1] values. Otherwise we assume centered images
if samples.min() > 0:
    samples = samples * 2.0 - 1.0
viewer = make_viewer(samples, grid_shape=grid_shape, is_color=is_color)
viewer.show()
# Display inpainting samples produced by a saved model's generator
# (Python 2 script). Usage: python show_samples_inpaint.py <model_path>
import theano
from pylearn2.utils import serial
import sys
from pylearn2.gui.patch_viewer import make_viewer
from pylearn2.space import VectorSpace
from pylearn2.config import yaml_parse
import numpy as np
import ipdb


# TODO, only works for CIFAR10 for now

grid_shape = None
repeat_samples = 1
num_samples = 5


_, model_path = sys.argv
model = serial.load(model_path)
rng = np.random.RandomState(20232)


def get_data_samples(dataset, n = num_samples):
    """
    Draw `n` examples (with replacement) for every distinct label.

    Parameters
    ----------
    dataset : object
        Must provide a label array `y` and `get_topological_view()`.
    n : int, optional
        Number of examples drawn per distinct value of `dataset.y`.

    Returns
    -------
    numpy.ndarray
        The selected topological examples, concatenated label by label
        in the order given by `np.unique(dataset.y)`.
    """
    # The topological view does not depend on the label, so build it once
    # rather than once per class.
    topo = dataset.get_topological_view()
    rval = []
    for y in np.unique(dataset.y):
        ind = np.where(dataset.y == y)[0]
        ind = ind[rng.randint(0, len(ind), n)]
        rval.append(topo[ind])

    return np.concatenate(rval)


dataset = yaml_parse.load(model.dataset_yaml_src)
dataset = dataset.get_test_set()
data = get_data_samples(dataset)

output_space = model.generator.get_output_space()
input_space = model.generator.mlp.input_space

X = input_space.get_theano_batch()
samples, _ = model.generator.inpainting_sample_and_noise(X)
f = theano.function([X], samples)

samples = []
for i in xrange(repeat_samples):
    samples.append(f(data))

samples = np.concatenate(samples)

is_color = True


print (samples.min(), samples.mean(), samples.max())
# Hack for detecting MNIST [0, 1] values. Otherwise we assume centered images
if samples.min() >0:
    samples = samples * 2.0 - 1.0
viewer = make_viewer(samples, grid_shape=grid_shape, is_color=is_color)
viewer.show()
# Display 100 generator samples from a saved model as single-channel
# square images (Python 2 script).
# Usage: python show_samples_tfd.py <model_path>
import sys

import numpy as np

from pylearn2.gui.patch_viewer import make_viewer
from pylearn2.space import Conv2DSpace
from pylearn2.utils import serial

_, model_path = sys.argv
model = serial.load(model_path)
space = model.generator.get_output_space()
total_dimension = space.get_total_dimension()

# Always treated as greyscale; colour detection is deliberately disabled.
num_colors = 1
#if total_dimension % 3 == 0:
#    num_colors = 3
w = int(np.sqrt(total_dimension / num_colors))
desired_space = Conv2DSpace(shape=[w, w], num_channels=num_colors,
                            axes=('b', 0, 1, 'c'))

generated = model.generator.sample(100)
samples = space.format_as(batch=generated, space=desired_space).eval()
print (samples.min(), samples.mean(), samples.max())
# Samples are rescaled with x * 2 - 1 before display.
viewer = make_viewer(samples * 2.0 - 1.0)
viewer.show()
"""
This script visually tests the deconv layer.
It constructs an MLP with a conv layer and a deconv layer,
sets their W to the same values and shows the original
input and the output of the MLP side by side.
They are supposed to look the same.
"""


import theano
from adversarial.deconv import Deconv
from pylearn2.datasets.mnist import MNIST
from pylearn2.space import Conv2DSpace
from pylearn2.models.mlp import MLP
from pylearn2.models.maxout import MaxoutConvC01B
from pylearn2.gui import patch_viewer
import ipdb


input_space = Conv2DSpace(shape = (28, 28), num_channels=1, axes = ('c', 0, 1, 'b'))
conv = MaxoutConvC01B(layer_name = 'conv',
                      num_channels = 16,
                      num_pieces = 1,
                      kernel_shape = (4, 4),
                      pool_shape = (1, 1),
                      pool_stride=(1, 1),
                      irange = 0.05)
deconv = Deconv(layer_name = 'deconv',
                num_channels = 1,
                kernel_shape = (4, 4),
                irange = 0.05)

mlp = MLP(input_space =input_space,
          layers = [conv, deconv])

# Tie the deconv filters to the conv filters.
mlp.layers[1].transformer._filters.set_value(mlp.layers[0].transformer._filters.get_value())

x = input_space.get_theano_batch()
out = mlp.fprop(x)
f = theano.function([x], out)

data = MNIST('test')
data_specs = (input_space, 'features')
# Named `batches` rather than `iter` so the builtin is not shadowed.
batches = data.iterator(mode = 'sequential', batch_size = 2, data_specs = data_specs)
pv = patch_viewer.PatchViewer((10, 10), (28, 28))
for item in batches:
    # Only the first batch is displayed.
    res = f(item)
    pv.add_patch(item[0,:,:,0])
    pv.add_patch(res[0,:,:,0])
    pv.show()
    break
a/tfd_pretrain/pretrain.yaml b/tfd_pretrain/pretrain.yaml
new file mode 100644
index 0000000..4f4ef2e
--- /dev/null
+++ b/tfd_pretrain/pretrain.yaml
@@ -0,0 +1,115 @@
+!obj:pylearn2.train.Train {  # NOTE(review): apparently stage one; train.yaml loads "./pretrain.pkl" -- confirm
+    dataset: &train !obj:pylearn2.datasets.tfd.TFD {  # anchored as &train and reused below via *train
+        which_set: 'unlabeled',
+        scale: True,
+    },
+    model: !obj:adversarial.AdversaryPair {
+        generator: !obj:adversarial.Generator {
+            monitor_ll: 1,
+            mlp: !obj:pylearn2.models.mlp.MLP {
+                layers: [
+                    !obj:pylearn2.models.mlp.RectifiedLinear {
+                        layer_name: 'h0',
+                        dim: 8000,
+                        irange: .05,
+                        max_col_norm: 1.9365,
+                    },
+                    !obj:pylearn2.models.mlp.Sigmoid {
+                        layer_name: 'h1',
+                        dim: 8000,
+                        irange: .05,
+                        max_col_norm: 1.9365,
+                        init_bias: -2.0,
+                    },
+                    !obj:pylearn2.models.mlp.Sigmoid {
+                        max_col_norm: 1.9365,
+                        init_bias: !obj:pylearn2.models.dbm.init_sigmoid_bias_from_marginals { dataset: *train},  # *train is the dataset anchored above
+                        layer_name: 'y',
+                        sparse_init: 100,
+                        dim: 2304  # 2304 = 48 * 48; NOTE(review): presumably the flattened TFD image size -- confirm
+                    }
+                ],
+                nvis: 100,  # input size of the generator MLP
+            }},
+        discriminator:
+            !obj:pylearn2.models.mlp.MLP {
+                layers: [
+                    !obj:pylearn2.models.maxout.Maxout {
+                        #W_lr_scale: .1,
+                        #b_lr_scale: .1,
+                        layer_name: 'h0',
+                        num_units: 1200,
+                        num_pieces: 5,
+                        irange: .005,
+                        max_col_norm: 1.9365,
+                    },
+                    !obj:pylearn2.models.maxout.Maxout {
+                        #W_lr_scale: .1,
+                        #b_lr_scale: .1,
+                        layer_name: 'h1',
+                        num_units: 1200,
+                        num_pieces: 5,
+                        irange: .005,
+                        max_col_norm: 1.9365,
+                    },
+                    !obj:pylearn2.models.mlp.Sigmoid {
+                        #W_lr_scale: .1,
+                        #b_lr_scale: .1,
+                        max_col_norm: 1.9365,
+                        layer_name: 'y',
+                        dim: 1,  # single sigmoid output unit
+                        irange: .005
+                    }
+                ],
+                nvis: 2304,  # same size as the generator's output layer 'y'
+            },
+    },
+    algorithm: !obj:pylearn2.training_algorithms.sgd.SGD {
+        batch_size: 100,
+        learning_rate: .05,
+        learning_rule: !obj:pylearn2.training_algorithms.learning_rule.Momentum {
+            init_momentum: .5,
+        },
+        monitoring_dataset:
+            {
+                # 'train' : *train,
+                'valid' : !obj:pylearn2.datasets.tfd.TFD {
+                    which_set: 'valid',
+                    scale: True,
+                },
+                # 'test' : !obj:pylearn2.datasets.tfd.TFD {
+                # which_set: 'test',
+                # scale: True,
+                # }
+            },
+        cost: !obj:adversarial.AdversaryCost2 {
+            scale_grads: 0,
+            #target_scale: 1.,
+            discriminator_default_input_include_prob: .5,
+            discriminator_input_include_probs: {
+                'h0': .8  # NOTE(review): presumably overrides the .5 default for layer 'h0' -- confirm
+            },
+            discriminator_default_input_scale: 2.,  # 2. = 1 / .5, the reciprocal of the default include prob
+            discriminator_input_scales: {
+                'h0': 1.25  # 1.25 = 1 / .8, the reciprocal of the 'h0' include prob
+            }
+        },
+        #!obj:pylearn2.costs.mlp.dropout.Dropout {
+        # input_include_probs: { 'h0' : .8 },
+        # input_scales: { 'h0': 1. }
+        #},
+        update_callbacks: !obj:pylearn2.training_algorithms.sgd.ExponentialDecay {
+            decay_factor: 1.000004,  # NOTE(review): presumably the learning rate is divided by this per update -- confirm in pylearn2
+            min_lr: .000001  # NOTE(review): presumably the floor for the decayed learning rate -- confirm
+        }
+    },
+    extensions: [
+        !obj:pylearn2.training_algorithms.learning_rule.MomentumAdjustor {
+            start: 1,
+            saturate: 250,
+            final_momentum: .7  # NOTE(review): presumably ramps from init_momentum (.5) to .7 between start and saturate -- confirm
+        }
+    ],
+    save_path: "${PYLEARN2_TRAIN_FILE_FULL_STEM}.pkl",
+    save_freq: 1
+}
diff --git a/tfd_pretrain/train.yaml b/tfd_pretrain/train.yaml
new file mode 100644
index 0000000..ad36006
--- /dev/null
+++ b/tfd_pretrain/train.yaml
@@ -0,0 +1,112 @@
+!obj:pylearn2.train.Train {  # NOTE(review): apparently stage two; depends on "./pretrain.pkl" existing -- confirm
+    dataset: &train !obj:pylearn2.datasets.tfd.TFD {
+        which_set: 'unlabeled',
+        scale: True,
+    },
+    model: !obj:adversarial.AdversaryPair {
+        generator: !obj:adversarial.Generator {
+            monitor_ll: 1,
+            mlp: !obj:adversarial.add_layers {  # NOTE(review): presumably combines the MLP below with the pretrained model -- confirm in adversarial/__init__.py
+                mlp: !obj:pylearn2.models.mlp.MLP {
+                    layers: [
+                        !obj:pylearn2.models.mlp.RectifiedLinear {
+                            layer_name: 'h0',
+                            dim: 8000,
+                            irange: .05,
+                            max_col_norm: 1.9365,
+                        },
+                        !obj:pylearn2.models.mlp.Sigmoid {
+                            layer_name: 'h1',
+                            dim: 100,  # equals nvis (100) of the generator MLP in pretrain.yaml
+                            irange: .05,
+                            max_col_norm: 1.9365,
+                            init_bias: -2.0,
+                        },
+                    ],
+                    nvis: 100,  # input size of this MLP
+                },
+                pretrained: "./pretrain.pkl",  # NOTE(review): relative path, presumably resolved against the working directory -- confirm
+            }
+        },
+        discriminator:
+            !obj:pylearn2.models.mlp.MLP {
+                layers: [
+                    !obj:pylearn2.models.maxout.Maxout {
+                        #W_lr_scale: .1,
+                        #b_lr_scale: .1,
+                        layer_name: 'h0',
+                        num_units: 1200,
+                        num_pieces: 5,
+                        irange: .005,
+                        max_col_norm: 1.9365,
+                    },
+                    !obj:pylearn2.models.maxout.Maxout {
+                        #W_lr_scale: .1,
+                        #b_lr_scale: .1,
+                        layer_name: 'h1',
+                        num_units: 1200,
+                        num_pieces: 5,
+                        irange: .005,
+                        max_col_norm: 1.9365,
+                    },
+                    !obj:pylearn2.models.mlp.Sigmoid {
+                        #W_lr_scale: .1,
+                        #b_lr_scale: .1,
+                        max_col_norm: 1.9365,
+                        layer_name: 'y',
+                        dim: 1,  # single sigmoid output unit
+                        irange: .005
+                    }
+                ],
+                nvis: 2304,  # 2304 = 48 * 48; NOTE(review): presumably the flattened TFD image size -- confirm
+            },
+    },
+    algorithm: !obj:pylearn2.training_algorithms.sgd.SGD {
+        batch_size: 100,
+        learning_rate: .05,
+        learning_rule: !obj:pylearn2.training_algorithms.learning_rule.Momentum {
+            init_momentum: .5,
+        },
+        monitoring_dataset:
+            {
+                # 'train' : *train,
+                'valid' : !obj:pylearn2.datasets.tfd.TFD {
+                    which_set: 'valid',
+                    scale: True,
+                },
+                # 'test' : !obj:pylearn2.datasets.tfd.TFD {
+                # which_set: 'test',
+                # scale: True,
+                # }
+            },
+        cost: !obj:adversarial.AdversaryCost2 {
+            scale_grads: 0,
+            #target_scale: 1.,
+            discriminator_default_input_include_prob: .5,
+            discriminator_input_include_probs: {
+                'h0': .8  # NOTE(review): presumably overrides the .5 default for layer 'h0' -- confirm
+            },
+            discriminator_default_input_scale: 2.,  # 2. = 1 / .5, the reciprocal of the default include prob
+            discriminator_input_scales: {
+                'h0': 1.25  # 1.25 = 1 / .8, the reciprocal of the 'h0' include prob
+            }
+        },
+        #!obj:pylearn2.costs.mlp.dropout.Dropout {
+        # input_include_probs: { 'h0' : .8 },
+        # input_scales: { 'h0': 1. }
+        #},
+        update_callbacks: !obj:pylearn2.training_algorithms.sgd.ExponentialDecay {
+            decay_factor: 1.000004,  # NOTE(review): presumably the learning rate is divided by this per update -- confirm in pylearn2
+            min_lr: .000001  # NOTE(review): presumably the floor for the decayed learning rate -- confirm
+        }
+    },
+    extensions: [
+        !obj:pylearn2.training_algorithms.learning_rule.MomentumAdjustor {
+            start: 1,
+            saturate: 250,
+            final_momentum: .7  # NOTE(review): presumably ramps from init_momentum (.5) to .7 between start and saturate -- confirm
+        }
+    ],
+    save_path: "${PYLEARN2_TRAIN_FILE_FULL_STEM}.pkl",
+    save_freq: 1
+}