# Source code for ava.models.vae

"""
A Variational Autoencoder (VAE) for spectrogram data.

VAE References
--------------
.. [1] Kingma, Diederik P., and Max Welling. "Auto-encoding variational bayes."
	arXiv preprint arXiv:1312.6114 (2013).

	`<https://arxiv.org/abs/1312.6114>`_


.. [2] Rezende, Danilo Jimenez, Shakir Mohamed, and Daan Wierstra. "Stochastic
	backpropagation and approximate inference in deep generative models." arXiv
	preprint arXiv:1401.4082 (2014).

	`<https://arxiv.org/abs/1401.4082>`_
"""
__date__ = "November 2018 - November 2019"


import numpy as np
import os
import torch
from torch.distributions import LowRankMultivariateNormal
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

from ava.models.vae_dataset import SyllableDataset
from ava.plotting.grid_plot import grid_plot


# Spectrogram size used throughout this module. ``VAE.forward`` reshapes
# reconstructions back to this shape.
X_SHAPE = (128,128)
"""Processed spectrogram shape: ``[freq_bins, time_bins]``"""
# Number of pixels in one flattened spectrogram (product of ``X_SHAPE``, as a
# NumPy integer). ``VAE.decode`` flattens its output to this many columns.
X_DIM = np.prod(X_SHAPE)
"""Processed spectrogram dimension: ``freq_bins * time_bins``"""



class VAE(nn.Module):
    r"""Variational Autoencoder class for single-channel images.

    Attributes
    ----------
    save_dir : str, optional
        Directory where the model is saved. Defaults to ``''``.
    lr : float, optional
        Model learning rate. Defaults to ``1e-3``.
    z_dim : int, optional
        Latent dimension. Defaults to ``32``.
    model_precision : float, optional
        Precision of the observation model. Defaults to ``10.0``.
    device_name : {'cpu', 'cuda', 'auto'}, optional
        Name of device to train the model on. When ``'auto'`` is passed,
        ``'cuda'`` is chosen if ``torch.cuda.is_available()``, otherwise
        ``'cpu'`` is chosen. Defaults to ``'auto'``.

    Notes
    -----
    The model is trained to maximize the standard ELBO objective:

    .. math:: \mathcal{L} = \mathbb{E}_{q(z|x)} \log p(x,z) + \mathbb{H}[q(z|x)]

    where :math:`p(x,z) = p(z)p(x|z)` and :math:`\mathbb{H}` is differential
    entropy. The prior :math:`p(z)` is a unit spherical normal distribution.
    The conditional distribution :math:`p(x|z)` is set as a spherical normal
    distribution to prevent overfitting. The variational distribution,
    :math:`q(z|x)` is an approximately rank-1 multivariate normal
    distribution. Here, :math:`q(z|x)` and :math:`p(x|z)` are parameterized by
    neural networks. Gradients are passed through stochastic layers via the
    reparameterization trick, implemented by the PyTorch `rsample` method.

    The dimensions of the network are hard-coded for use with 128 x 128
    spectrograms. Although a desired latent dimension can be passed to
    `__init__`, the dimensions of the network limit the practical range of
    values roughly 8 to 64 dimensions. Fiddling with the image dimensions will
    require updating the parameters of the layers defined in
    `_build_network`.
    """

    def __init__(self, save_dir='', lr=1e-3, z_dim=32, model_precision=10.0,
                 device_name="auto"):
        r"""Construct a VAE.

        Parameters
        ----------
        save_dir : str, optional
            Directory where the model is saved. Defaults to the current
            working directory.
        lr : float, optional
            Learning rate of the ADAM optimizer. Defaults to 1e-3.
        z_dim : int, optional
            Dimension of the latent space. Defaults to 32.
        model_precision : float, optional
            Precision of the noise model, p(x|z) = N(mu(z), \Lambda) where
            \Lambda = model_precision * I. Defaults to 10.0.
        device_name: str, optional
            Name of device to train the model on. Valid options are ["cpu",
            "cuda", "auto"]. "auto" will choose "cuda" if it is available.
            Defaults to "auto".

        Raises
        ------
        AssertionError
            If ``device_name == "cuda"`` but CUDA is not available.

        Note
        ----
        - The model is built before its parameters can be loaded from a file.
          This means `self.z_dim` must match `z_dim` of the model being
          loaded.
        """
        super().__init__()
        self.save_dir = save_dir
        self.lr = lr
        self.z_dim = z_dim
        self.model_precision = model_precision
        assert device_name != "cuda" or torch.cuda.is_available(), \
            "device_name='cuda' was requested but CUDA is not available"
        if device_name == "auto":
            device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)
        if self.save_dir != '':
            # exist_ok avoids the check-then-create race of a separate
            # os.path.exists() test.
            os.makedirs(self.save_dir, exist_ok=True)
        self._build_network()
        self.optimizer = Adam(self.parameters(), lr=self.lr)
        self.epoch = 0
        self.loss = {'train': {}, 'test': {}}
        self.to(self.device)

    def _build_network(self):
        """Define all the network layers.

        The creation order below is deliberate and must not be changed: it
        fixes the parameter order seen by the optimizer (and therefore stored
        in checkpoints) as well as the order in which random initial weights
        are drawn.
        """
        # Encoder
        self.conv1 = nn.Conv2d(1, 8, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(8, 8, 3, 2, padding=1)
        self.conv3 = nn.Conv2d(8, 16, 3, 1, padding=1)
        self.conv4 = nn.Conv2d(16, 16, 3, 2, padding=1)
        self.conv5 = nn.Conv2d(16, 24, 3, 1, padding=1)
        self.conv6 = nn.Conv2d(24, 24, 3, 2, padding=1)
        self.conv7 = nn.Conv2d(24, 32, 3, 1, padding=1)
        self.bn1 = nn.BatchNorm2d(1)
        self.bn2 = nn.BatchNorm2d(8)
        self.bn3 = nn.BatchNorm2d(8)
        self.bn4 = nn.BatchNorm2d(16)
        self.bn5 = nn.BatchNorm2d(16)
        self.bn6 = nn.BatchNorm2d(24)
        self.bn7 = nn.BatchNorm2d(24)
        # 8192 = 32 channels * 16 * 16 after three stride-2 convolutions.
        self.fc1 = nn.Linear(8192, 1024)
        self.fc2 = nn.Linear(1024, 256)
        # Three parallel heads: mean (x1), covariance factor (x2), diagonal (x3).
        self.fc31 = nn.Linear(256, 64)
        self.fc32 = nn.Linear(256, 64)
        self.fc33 = nn.Linear(256, 64)
        self.fc41 = nn.Linear(64, self.z_dim)
        self.fc42 = nn.Linear(64, self.z_dim)
        self.fc43 = nn.Linear(64, self.z_dim)
        # Decoder
        self.fc5 = nn.Linear(self.z_dim, 64)
        self.fc6 = nn.Linear(64, 256)
        self.fc7 = nn.Linear(256, 1024)
        self.fc8 = nn.Linear(1024, 8192)
        self.convt1 = nn.ConvTranspose2d(32, 24, 3, 1, padding=1)
        self.convt2 = nn.ConvTranspose2d(24, 24, 3, 2, padding=1,
                                         output_padding=1)
        self.convt3 = nn.ConvTranspose2d(24, 16, 3, 1, padding=1)
        self.convt4 = nn.ConvTranspose2d(16, 16, 3, 2, padding=1,
                                         output_padding=1)
        self.convt5 = nn.ConvTranspose2d(16, 8, 3, 1, padding=1)
        self.convt6 = nn.ConvTranspose2d(8, 8, 3, 2, padding=1,
                                         output_padding=1)
        self.convt7 = nn.ConvTranspose2d(8, 1, 3, 1, padding=1)
        self.bn8 = nn.BatchNorm2d(32)
        self.bn9 = nn.BatchNorm2d(24)
        self.bn10 = nn.BatchNorm2d(24)
        self.bn11 = nn.BatchNorm2d(16)
        self.bn12 = nn.BatchNorm2d(16)
        self.bn13 = nn.BatchNorm2d(8)
        self.bn14 = nn.BatchNorm2d(8)

    def _get_layers(self):
        """Return a dictionary mapping names to network layers.

        The keys are the attribute names assigned in `_build_network`; they
        are also the keys used in checkpoints written by `save_state`.
        """
        layer_names = (
            'fc1', 'fc2', 'fc31', 'fc32', 'fc33', 'fc41', 'fc42', 'fc43',
            'fc5', 'fc6', 'fc7', 'fc8',
            'bn1', 'bn2', 'bn3', 'bn4', 'bn5', 'bn6', 'bn7',
            'bn8', 'bn9', 'bn10', 'bn11', 'bn12', 'bn13', 'bn14',
            'conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'conv6', 'conv7',
            'convt1', 'convt2', 'convt3', 'convt4', 'convt5', 'convt6',
            'convt7',
        )
        return {name: getattr(self, name) for name in layer_names}

    def encode(self, x):
        r"""
        Compute :math:`q(z|x)`.

        .. math:: q(z|x) = \mathcal{N}(\mu, \Sigma)
        .. math:: \Sigma = u u^{T} + \mathtt{diag}(d)

        where :math:`\mu`, :math:`u`, and :math:`d` are deterministic
        functions of `x` and :math:`\Sigma` denotes a covariance matrix.

        Parameters
        ----------
        x : torch.Tensor
            The input images, with shape: ``[batch_size, height=128,
            width=128]``

        Returns
        -------
        mu : torch.Tensor
            Posterior mean, with shape ``[batch_size, self.z_dim]``
        u : torch.Tensor
            Posterior covariance factor, as defined above. Shape:
            ``[batch_size, self.z_dim, 1]`` (the trailing axis is the rank of
            the low-rank factor).
        d : torch.Tensor
            Posterior diagonal factor, as defined above. Shape:
            ``[batch_size, self.z_dim]``
        """
        x = x.unsqueeze(1)  # Add the single channel axis.
        x = F.relu(self.conv1(self.bn1(x)))
        x = F.relu(self.conv2(self.bn2(x)))
        x = F.relu(self.conv3(self.bn3(x)))
        x = F.relu(self.conv4(self.bn4(x)))
        x = F.relu(self.conv5(self.bn5(x)))
        x = F.relu(self.conv6(self.bn6(x)))
        x = F.relu(self.conv7(self.bn7(x)))
        x = x.view(-1, 8192)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mu = F.relu(self.fc31(x))
        mu = self.fc41(mu)
        u = F.relu(self.fc32(x))
        u = self.fc42(u).unsqueeze(-1)  # Last dimension is rank \Sigma = 1.
        d = F.relu(self.fc33(x))
        d = torch.exp(self.fc43(d))  # d must be positive.
        return mu, u, d

    def decode(self, z):
        r"""
        Compute :math:`p(x|z)`.

        .. math:: p(x|z) = \mathcal{N}(\mu, \Lambda)

        .. math:: \Lambda = \mathtt{model\_precision} \cdot I

        where :math:`\mu` is a deterministic function of `z`,
        :math:`\Lambda` is a precision matrix, and :math:`I` is the identity
        matrix.

        Parameters
        ----------
        z : torch.Tensor
            Batch of latent samples with shape ``[batch_size, self.z_dim]``

        Returns
        -------
        x : torch.Tensor
            Batch of means mu, described above. Shape: ``[batch_size,
            X_DIM=128*128]``
        """
        z = F.relu(self.fc5(z))
        z = F.relu(self.fc6(z))
        z = F.relu(self.fc7(z))
        z = F.relu(self.fc8(z))
        z = z.view(-1, 32, 16, 16)
        z = F.relu(self.convt1(self.bn8(z)))
        z = F.relu(self.convt2(self.bn9(z)))
        z = F.relu(self.convt3(self.bn10(z)))
        z = F.relu(self.convt4(self.bn11(z)))
        z = F.relu(self.convt5(self.bn12(z)))
        z = F.relu(self.convt6(self.bn13(z)))
        z = self.convt7(self.bn14(z))  # No ReLU: the output is a mean.
        return z.view(-1, X_DIM)

    def forward(self, x, return_latent_rec=False):
        r"""
        Send `x` round trip and compute a loss.

        In more detail: Given `x`, compute :math:`q(z|x)` and sample:
        :math:`\hat{z} \sim q(z|x)` . Then compute
        :math:`\log p(x|\hat{z})`, the log-likelihood of `x`, the input,
        given :math:`\hat{z}`, the latent sample. We will also need the
        likelihood of :math:`\hat{z}` under the model's prior:
        :math:`p(\hat{z})`, and the entropy of the latent conditional
        distribution, :math:`\mathbb{H}[q(z|x)]` . ELBO can then be estimated
        as:

        .. math:: \frac{1}{N} \sum_{i=1}^N \mathbb{E}_{\hat{z} \sim q(z|x_i)} \log p(x_i,\hat{z}) + \mathbb{H}[q(z|x_i)]

        where :math:`N` denotes the number of samples from the data
        distribution and the expectation is estimated using a single latent
        sample, :math:`\hat{z}`. In practice, the outer expectation is
        estimated using minibatches.

        Parameters
        ----------
        x : torch.Tensor
            A batch of samples from the data distribution (spectrograms).
            Shape: ``[batch_size, height=128, width=128]``
        return_latent_rec : bool, optional
            Whether to return latent samples and reconstructions. Defaults to
            ``False``.

        Returns
        -------
        loss : torch.Tensor
            Negative ELBO times the batch size. Shape: ``[]``
        latent : numpy.ndarray, if `return_latent_rec`
            Latent samples :math:`\hat{z}` drawn from :math:`q(z|x)` (not the
            posterior means). Shape: ``[batch_size, self.z_dim]``
        reconstructions : numpy.ndarray, if `return_latent_rec`
            Reconstructed means. Shape: ``[batch_size, height=128,
            width=128]``
        """
        batch_size = x.shape[0]
        mu, u, d = self.encode(x)
        latent_dist = LowRankMultivariateNormal(mu, u, d)
        z = latent_dist.rsample()
        x_rec = self.decode(z)
        # The loss is a sum over the batch, so the Gaussian normalising
        # constants must be counted once per sample, like every other term.
        # They carry no gradient; this only affects the reported value.
        # E_{q(z|x)} log p(z)
        log_two_pi = float(np.log(2 * np.pi))
        log_pz = -0.5 * (torch.sum(torch.pow(z, 2))
                         + batch_size * self.z_dim * log_two_pi)
        # E_{q(z|x)} log p(x|z)
        pxz_const = -0.5 * batch_size * float(X_DIM) \
            * float(np.log(2 * np.pi / self.model_precision))
        l2s = torch.sum(torch.pow(x.view(batch_size, -1) - x_rec, 2), dim=1)
        log_pxz = pxz_const - 0.5 * self.model_precision * torch.sum(l2s)
        # H[q(z|x)]
        entropy = torch.sum(latent_dist.entropy())
        elbo = log_pz + log_pxz + entropy
        if return_latent_rec:
            rec = x_rec.view(-1, X_SHAPE[0], X_SHAPE[1])
            return -elbo, z.detach().cpu().numpy(), \
                rec.detach().cpu().numpy()
        return -elbo

    def train_epoch(self, train_loader):
        """
        Train the model for a single epoch.

        Parameters
        ----------
        train_loader : torch.utils.data.Dataloader
            ava.models.vae_dataset.SyllableDataset Dataloader for training set

        Returns
        -------
        loss : float
            Average training loss per sample, i.e. a biased estimate of the
            negative ELBO, estimated using samples from `train_loader`.
        """
        self.train()
        train_loss = 0.0
        for data in train_loader:
            self.optimizer.zero_grad()
            data = data.to(self.device)
            loss = self.forward(data)
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
        # `forward` returns a sum over the batch, so normalise by the number
        # of samples rather than the number of batches.
        train_loss /= len(train_loader.dataset)
        print('Epoch: {} Average loss: {:.4f}'.format(self.epoch, \
                train_loss))
        self.epoch += 1
        return train_loss

    def test_epoch(self, test_loader):
        """
        Test the model on a held-out test set, return a loss estimate.

        Parameters
        ----------
        test_loader : torch.utils.data.Dataloader
            ava.models.vae_dataset.SyllableDataset Dataloader for test set

        Returns
        -------
        loss : float
            Average test loss per sample, i.e. an unbiased estimate of the
            negative ELBO, estimated using samples from `test_loader`.
        """
        self.eval()
        test_loss = 0.0
        with torch.no_grad():
            for data in test_loader:
                data = data.to(self.device)
                loss = self.forward(data)
                test_loss += loss.item()
        test_loss /= len(test_loader.dataset)
        print('Test loss: {:.4f}'.format(test_loss))
        return test_loss

    def train_loop(self, loaders, epochs=100, test_freq=2, save_freq=10,
                   vis_freq=1):
        """
        Train the model for multiple epochs, testing and saving along the way.

        Parameters
        ----------
        loaders : dictionary
            Dictionary mapping the keys ``'test'`` and ``'train'`` to
            respective torch.utils.data.Dataloader objects.
        epochs : int, optional
            Number of (possibly additional) epochs to train the model for.
            Defaults to ``100``.
        test_freq : int, optional
            Testing is performed every `test_freq` epochs. Pass ``None`` to
            disable testing. Defaults to ``2``.
        save_freq : int, optional
            The model is saved every `save_freq` epochs (never at epoch 0).
            Pass ``None`` to disable saving. Defaults to ``10``.
        vis_freq : int, optional
            Syllable reconstructions are plotted every `vis_freq` epochs.
            Pass ``None`` to disable plotting. Defaults to ``1``.
        """
        print("="*40)
        print("Training: epochs", self.epoch, "to", self.epoch+epochs-1)
        print("Training set:", len(loaders['train'].dataset))
        print("Test set:", len(loaders['test'].dataset))
        print("="*40)
        # For some number of epochs...
        # The range is fixed up front; `train_epoch` advances `self.epoch`.
        for epoch in range(self.epoch, self.epoch+epochs):
            # Run through the training data and record a loss.
            loss = self.train_epoch(loaders['train'])
            self.loss['train'][epoch] = loss
            # Run through the test data and record a loss.
            if (test_freq is not None) and (epoch % test_freq == 0):
                loss = self.test_epoch(loaders['test'])
                self.loss['test'][epoch] = loss
            # Save the model.
            if (save_freq is not None) and (epoch % save_freq == 0) and \
                    (epoch > 0):
                filename = "checkpoint_"+str(epoch).zfill(3)+'.tar'
                self.save_state(filename)
            # Plot reconstructions.
            if (vis_freq is not None) and (epoch % vis_freq == 0):
                self.visualize(loaders['test'])

    def save_state(self, filename):
        """Save all the model parameters to the given file.

        Parameters
        ----------
        filename : str
            Name of the checkpoint file, relative to `self.save_dir`.
        """
        state = {
            layer_name: layer.state_dict()
            for layer_name, layer in self._get_layers().items()
        }
        state['optimizer_state'] = self.optimizer.state_dict()
        state['loss'] = self.loss
        state['z_dim'] = self.z_dim
        state['epoch'] = self.epoch
        state['lr'] = self.lr
        state['save_dir'] = self.save_dir
        filename = os.path.join(self.save_dir, filename)
        torch.save(state, filename)

    def load_state(self, filename):
        """
        Load all the model parameters from the given ``.tar`` file.

        The ``.tar`` file should be written by `self.save_state`.

        Parameters
        ----------
        filename : str
            File containing a model state.

        Raises
        ------
        AssertionError
            If the checkpoint's ``z_dim`` differs from `self.z_dim`.

        Note
        ----
        - `self.lr`, `self.save_dir`, and `self.z_dim` are not loaded.
        """
        checkpoint = torch.load(filename, map_location=self.device)
        assert checkpoint['z_dim'] == self.z_dim, \
            "checkpoint z_dim does not match the z_dim of this model"
        for layer_name, layer in self._get_layers().items():
            layer.load_state_dict(checkpoint[layer_name])
        self.optimizer.load_state_dict(checkpoint['optimizer_state'])
        self.loss = checkpoint['loss']
        self.epoch = checkpoint['epoch']

    def visualize(self, loader, num_specs=5, gap=(2,6), \
            save_filename='reconstruction.pdf'):
        """
        Plot spectrograms and their reconstructions.

        Spectrograms are chosen at random from the Dataloader Dataset. The
        reconstructions are computed in evaluation mode so that plotting
        neither depends on nor updates the batch-norm statistics; the
        previous train/eval mode is restored afterwards.

        Parameters
        ----------
        loader : torch.utils.data.Dataloader
            Spectrogram Dataloader
        num_specs : int, optional
            Number of spectrogram pairs to plot. Defaults to ``5``.
        gap : int or tuple of two ints, optional
            The vertical and horizontal gap between images, in pixels.
            Defaults to ``(2,6)``.
        save_filename : str, optional
            Where to save the plot, relative to `self.save_dir`. Defaults to
            ``'reconstruction.pdf'``.

        Returns
        -------
        specs : numpy.ndarray
            Spectrograms from `loader`.
        rec_specs : numpy.ndarray
            Corresponding spectrogram reconstructions.

        Raises
        ------
        AssertionError
            If `num_specs` is less than 1 or exceeds the dataset size.
        """
        # Collect random indices.
        assert num_specs <= len(loader.dataset) and num_specs >= 1, \
            "num_specs must be between 1 and the dataset size"
        indices = np.random.choice(np.arange(len(loader.dataset)),
                                   size=num_specs, replace=False)
        # Retrieve spectrograms from the loader. This relies on the dataset
        # accepting an array of indices and returning a sequence of tensors.
        specs = torch.stack(loader.dataset[indices]).to(self.device)
        # Get reconstructions.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                _, _, rec_specs = self.forward(specs, return_latent_rec=True)
        finally:
            self.train(was_training)
        specs = specs.detach().cpu().numpy()
        all_specs = np.stack([specs, rec_specs])
        # Plot.
        save_filename = os.path.join(self.save_dir, save_filename)
        grid_plot(all_specs, gap=gap, filename=save_filename)
        return specs, rec_specs

    def get_latent(self, loader):
        """
        Get latent means for all syllable in the given loader.

        The encoder is run in evaluation mode so that the latent means do not
        depend on how samples are batched and the batch-norm running
        statistics are left untouched; the previous train/eval mode is
        restored afterwards.

        Parameters
        ----------
        loader : torch.utils.data.Dataloader
            ava.models.vae_dataset.SyllableDataset Dataloader.

        Returns
        -------
        latent : numpy.ndarray
            Latent means. Shape: ``[len(loader.dataset), self.z_dim]``

        Note
        ----
        - Make sure your loader is not set to shuffle if you're going to
          match these with labels or other fields later.
        """
        latent = np.zeros((len(loader.dataset), self.z_dim))
        was_training = self.training
        self.eval()
        try:
            i = 0
            with torch.no_grad():
                for data in loader:
                    data = data.to(self.device)
                    mu, _, _ = self.encode(data)
                    mu = mu.detach().cpu().numpy()
                    # Rows are filled in loader order.
                    latent[i:i+len(mu)] = mu
                    i += len(mu)
        finally:
            self.train(was_training)
        return latent
# Executing this module directly is a no-op.
if __name__ == '__main__': pass ###