- """
- Extensions called during training to generate samples and diagnostic plots and printouts.
- """
-
- import matplotlib
- matplotlib.use('agg')
- import matplotlib.pyplot as plt
- import numpy as np
- import os
- import theano.tensor as T
- import theano
-
- from blocks.extensions import SimpleExtension
-
- import viz
- import sampler
-
-
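# A minimal sketch of how these extensions might be attached to a Blocks
# MainLoop. The names `algorithm`, `blocks_model`, `train_X`, `train_stream`,
# and `path` are assumptions standing in for objects built by the training
# script:
#
#   from blocks.main_loop import MainLoop
#   extensions = [
#       PlotSamples(model, algorithm, train_X, path, every_n_batches=1000),
#       PlotMonitors(path, burn_in_iters=100, every_n_batches=1000),
#   ]
#   main_loop = MainLoop(algorithm=algorithm, data_stream=train_stream,
#                        model=blocks_model, extensions=extensions)
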
class PlotSamples(SimpleExtension):
    def __init__(self, model, algorithm, X, path, n_samples=49, **kwargs):
        """
        Generate samples from the model. The do() function is called as an
        extension during training, and generates three types of samples:
        - samples from the generative model
        - samples from the image denoising posterior distribution
          (default signal-to-noise ratio of 1)
        - samples from the image inpainting posterior distribution
          (inpaint the left half of each image)
        """

        super(PlotSamples, self).__init__(**kwargs)
        self.model = model
        self.path = path
        n_samples = np.min([n_samples, X.shape[0]])
        self.X = X[:n_samples].reshape(
            (n_samples, model.n_colors, model.spatial_width, model.spatial_width))
        self.n_samples = n_samples
        X_noisy = T.tensor4('X noisy samp', dtype=theano.config.floatX)
        t = T.matrix('t samp', dtype=theano.config.floatX)
        self.get_mu_sigma = theano.function([X_noisy, t], model.get_mu_sigma(X_noisy, t),
            allow_input_downcast=True)

    def do(self, callback_name, *args):
        # Sampling builds a very deep Theano graph, which can exceed the
        # default Python recursion limit.
        sys.setrecursionlimit(10000000)

        print "generating samples"
        base_fname_part1 = self.path + '/samples-'
        base_fname_part2 = '_batch%06d'%self.main_loop.status['iterations_done']
        # Unconditional samples from the generative model.
        sampler.generate_samples(self.model, self.get_mu_sigma,
            n_samples=self.n_samples, inpaint=False, denoise_sigma=None, X_true=None,
            base_fname_part1=base_fname_part1, base_fname_part2=base_fname_part2)
        # Samples from the inpainting posterior.
        sampler.generate_samples(self.model, self.get_mu_sigma,
            n_samples=self.n_samples, inpaint=True, denoise_sigma=None, X_true=self.X,
            base_fname_part1=base_fname_part1, base_fname_part2=base_fname_part2)
        # Samples from the denoising posterior.
        sampler.generate_samples(self.model, self.get_mu_sigma,
            n_samples=self.n_samples, inpaint=False, denoise_sigma=1, X_true=self.X,
            base_fname_part1=base_fname_part1, base_fname_part2=base_fname_part2)

class PlotParameters(SimpleExtension):
    def __init__(self, model, blocks_model, path, **kwargs):
        super(PlotParameters, self).__init__(**kwargs)
        self.path = path
        self.model = model
        self.blocks_model = blocks_model

    def do(self, callback_name, *args):
        sys.setrecursionlimit(10000000)

        print "plotting parameters"
        for param in self.blocks_model.parameters:
            param_name = param.name
            # Drop the brick hierarchy prefix and spaces to build a safe filename.
            filename_safe_name = '-'.join(param_name.split('/')[2:]).replace(' ', '_')
            base_fname_part1 = self.path + '/params-' + filename_safe_name
            base_fname_part2 = '_batch%06d'%self.main_loop.status['iterations_done']
            viz.plot_parameter(param.get_value(), base_fname_part1, base_fname_part2,
                title=param_name, n_colors=self.model.n_colors)

class PlotGradients(SimpleExtension):
    def __init__(self, model, blocks_model, algorithm, X, path, **kwargs):
        super(PlotGradients, self).__init__(**kwargs)
        self.path = path
        self.X = X
        self.model = model
        self.blocks_model = blocks_model
        gradients = []
        for param_name in sorted(self.blocks_model.parameters.keys()):
            gradients.append(algorithm.gradients[self.blocks_model.parameters[param_name]])
        self.grad_f = theano.function(algorithm.inputs, gradients, allow_input_downcast=True)

    def do(self, callback_name, *args):
        print "plotting gradients"
        # Evaluate gradients on a fixed batch so plots are comparable across
        # training iterations.
        grad_vals = self.grad_f(self.X)
        keynames = sorted(self.blocks_model.parameters.keys())
        for ii in xrange(len(keynames)):
            param_name = keynames[ii]
            val = grad_vals[ii]
            filename_safe_name = '-'.join(param_name.split('/')[2:]).replace(' ', '_')
            base_fname_part1 = self.path + '/grads-' + filename_safe_name
            base_fname_part2 = '_batch%06d'%self.main_loop.status['iterations_done']
            viz.plot_parameter(val, base_fname_part1, base_fname_part2,
                title="grad " + param_name, n_colors=self.model.n_colors)

class PlotInternalState(SimpleExtension):
    def __init__(self, model, blocks_model, state, features, X, path, **kwargs):
        super(PlotInternalState, self).__init__(**kwargs)
        self.path = path
        self.X = X
        self.model = model
        self.blocks_model = blocks_model
        self.internal_state_f = theano.function([features], state, allow_input_downcast=True)
        self.internal_state_names = [var.name for var in state]

    def do(self, callback_name, *args):
        print "plotting internal state of network"
        state = self.internal_state_f(self.X)
        for ii in xrange(len(state)):
            param_name = self.internal_state_names[ii]
            val = state[ii]
            filename_safe_name = param_name.replace(' ', '_').replace('/', '-')
            base_fname_part1 = self.path + '/state-' + filename_safe_name
            base_fname_part2 = '_batch%06d'%self.main_loop.status['iterations_done']
            viz.plot_parameter(val, base_fname_part1, base_fname_part2,
                title="state " + param_name, n_colors=self.model.n_colors)

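# A sketch of how the `state` list and `features` input expected by
# PlotInternalState might be pulled from the training cost graph. This assumes
# Blocks' ComputationGraph API; `cost` is a placeholder for the training cost
# variable:
#
#   from blocks.graph import ComputationGraph
#   cg = ComputationGraph(cost)
#   features = [v for v in cg.inputs if v.name == 'features'][0]
#   state = [v for v in cg.variables if v.name is not None]
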
class PlotMonitors(SimpleExtension):
    def __init__(self, path, burn_in_iters=0, **kwargs):
        super(PlotMonitors, self).__init__(**kwargs)
        self.path = path
        self.burn_in_iters = burn_in_iters

    def do(self, callback_name, *args):
        print "plotting monitors"
        try:
            df = self.main_loop.log.to_dataframe()
        except AttributeError:
            # This started breaking after a Blocks update.
            print "Failed to generate monitoring plots due to Blocks interface change."
            return
        # Throw out the first burn_in_iters values, as the objective is often
        # much larger during that period.
        iter_number = df.index[-1]
        if iter_number > self.burn_in_iters:
            df = df.loc[self.burn_in_iters:]
        cols = [col for col in df.columns if col.startswith(('cost', 'train', 'test'))]
        df = df[cols].interpolate(method='linear')

        # If no rows with monitored values remain, don't plot.
        if len(df) == 0:
            return
        try:
            axs = df.interpolate(method='linear').plot(
                subplots=True, legend=False, figsize=(5, len(cols)*2))
        except TypeError:
            # This started breaking after a different Blocks update.
            print "Failed to generate monitoring plots due to Blocks interface change."
            return

        for ax, cname in zip(axs, cols):
            ax.set_title(cname)
        fn = os.path.join(self.path,
            'monitors_subplots_batch%06d.png' % self.main_loop.status['iterations_done'])
        plt.savefig(fn, bbox_inches='tight')

        plt.clf()
        df.plot(subplots=False, figsize=(15,10))
        plt.gcf().tight_layout()
        fn = os.path.join(self.path,
            'monitors_batch%06d.png' % self.main_loop.status['iterations_done'])
        plt.savefig(fn, bbox_inches='tight')
        plt.close('all')

class LogLikelihood(SimpleExtension):
    def __init__(self, model, test_stream, rescale, num_eval_batches=10000, **kwargs):
        """
        Compute and print the log likelihood lower bound on the test dataset.
        The do() function is called as an extension during training.
        """
        super(LogLikelihood, self).__init__(**kwargs)
        self.model = model
        self.test_stream = test_stream
        self.rescale = rescale
        self.num_eval_batches = num_eval_batches

        features = T.matrix('features', dtype=theano.config.floatX)
        cost = self.model.cost(features)

        self.L_gap_func = theano.function([features], cost,
            allow_input_downcast=True)

    def print_stats(self, L_gap):
        larr = np.array(L_gap)
        mn = np.mean(larr)
        sd = np.std(larr, ddof=1)
        stderr = sd / np.sqrt(len(L_gap))

        # The log likelihood lower bound, K, is reported for the data after
        # Z-scoring it. The Z-score rescale is the multiplicative factor by
        # which the data was rescaled to give it standard deviation 1.
        print "eval batch=%05d (K-L_null)=%g bits/pix standard error=%g bits/pix Z-score rescale %g"%(
            len(L_gap), mn, stderr, self.rescale)

    def do(self, callback_name, *args):
        L_gap = []

        Xiter = None
        for kk in xrange(self.num_eval_batches):
            try:
                X = next(Xiter)[0]
            except (StopIteration, TypeError):
                # (Re)start the epoch iterator when it is exhausted, or has
                # not yet been created.
                Xiter = self.test_stream.get_epoch_iterator()
                X = next(Xiter)[0]

            lg = -self.L_gap_func(X)
            L_gap.append(lg)

            # Print running statistics every 1000 evaluation batches.
            if np.mod(kk, 1000) == 999:
                self.print_stats(L_gap)
        self.print_stats(L_gap)

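# A sketch of attaching the evaluator; every_n_epochs is a standard Blocks
# SimpleExtension trigger, but the chosen interval is an assumption:
#
#   extensions.append(LogLikelihood(model, test_stream, rescale,
#       every_n_epochs=10))
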
def decay_learning_rate(iteration, old_value):
    # TODO the numbers in this function should not be hard coded

    # This is called every epoch. The per-epoch decay rate is chosen so the
    # learning rate falls by a factor of 10 every 1000 epochs, since
    # exp(log(0.1)/1000)**1000 = 0.1. The rate is floored at min_value.
    min_value = 1e-4

    decay_rate = np.exp(np.log(0.1)/1000.)
    new_value = decay_rate*old_value
    if new_value < min_value:
        new_value = min_value
    print "learning rate %g"%new_value
    return np.float32(new_value)
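
# A sketch of wiring the decay into training with Blocks'
# SharedVariableModifier; `step_rule.learning_rate` is an assumed name for the
# step rule's learning rate shared variable:
#
#   from blocks.extensions.training import SharedVariableModifier
#   extensions.append(SharedVariableModifier(
#       step_rule.learning_rate, decay_learning_rate, every_n_epochs=1))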