"""
A Variational Autoencoder (VAE) for spectrogram data.
VAE References
--------------
.. [1] Kingma, Diederik P., and Max Welling. "Auto-encoding variational bayes."
arXiv preprint arXiv:1312.6114 (2013).
`<https://arxiv.org/abs/1312.6114>`_
.. [2] Rezende, Danilo Jimenez, Shakir Mohamed, and Daan Wierstra. "Stochastic
backpropagation and approximate inference in deep generative models." arXiv
preprint arXiv:1401.4082 (2014).
`<https://arxiv.org/abs/1401.4082>`_
"""
__date__ = "November 2018 - November 2019"
import numpy as np
import os
import torch
from torch.distributions import LowRankMultivariateNormal
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from ava.models.vae_dataset import SyllableDataset
from ava.plotting.grid_plot import grid_plot
# Every spectrogram handled by this module is a fixed-size, single-channel
# image; the layer sizes in ``VAE._build_network`` are tied to this shape.
X_SHAPE = (128, 128)
"""Processed spectrogram shape: ``[freq_bins, time_bins]``"""
# Number of pixels in one spectrogram, i.e. its length once flattened.
X_DIM = np.prod(X_SHAPE)
"""Processed spectrogram dimension: ``freq_bins * time_bins``"""
class VAE(nn.Module):
    r"""Variational Autoencoder class for single-channel images.

    Attributes
    ----------
    save_dir : str, optional
        Directory where the model is saved. Defaults to ``''``.
    lr : float, optional
        Model learning rate. Defaults to ``1e-3``.
    z_dim : int, optional
        Latent dimension. Defaults to ``32``.
    model_precision : float, optional
        Precision of the observation model. Defaults to ``10.0``.
    device_name : {'cpu', 'cuda', 'auto'}, optional
        Name of device to train the model on. When ``'auto'`` is passed,
        ``'cuda'`` is chosen if ``torch.cuda.is_available()``, otherwise
        ``'cpu'`` is chosen. Defaults to ``'auto'``.

    Notes
    -----
    The model is trained to maximize the standard ELBO objective:

    .. math:: \mathcal{L} = \mathbb{E}_{q(z|x)} \log p(x,z) + \mathbb{H}[q(z|x)]

    where :math:`p(x,z) = p(z)p(x|z)` and :math:`\mathbb{H}` is differential
    entropy. The prior :math:`p(z)` is a unit spherical normal distribution.
    The conditional distribution :math:`p(x|z)` is set as a spherical normal
    distribution to prevent overfitting. The variational distribution,
    :math:`q(z|x)` is an approximately rank-1 multivariate normal
    distribution. Here, :math:`q(z|x)` and :math:`p(x|z)` are parameterized by
    neural networks. Gradients are passed through stochastic layers via the
    reparameterization trick, implemented by the PyTorch `rsample` method.

    The dimensions of the network are hard-coded for use with 128 x 128
    spectrograms. Although a desired latent dimension can be passed to
    `__init__`, the dimensions of the network limit the practical range of
    values roughly 8 to 64 dimensions. Fiddling with the image dimensions will
    require updating the parameters of the layers defined in
    `_build_network`.
    """

    def __init__(self, save_dir='', lr=1e-3, z_dim=32, model_precision=10.0,
                 device_name="auto"):
        r"""Construct a VAE.

        Parameters
        ----------
        save_dir : str, optional
            Directory where the model is saved. Defaults to the current
            working directory.
        lr : float, optional
            Learning rate of the ADAM optimizer. Defaults to 1e-3.
        z_dim : int, optional
            Dimension of the latent space. Defaults to 32.
        model_precision : float, optional
            Precision of the noise model, p(x|z) = N(mu(z), \Lambda) where
            \Lambda = model_precision * I. Defaults to 10.0.
        device_name: str, optional
            Name of device to train the model on. Valid options are ["cpu",
            "cuda", "auto"]. "auto" will choose "cuda" if it is available.
            Defaults to "auto".

        Raises
        ------
        ValueError
            If ``device_name`` is ``"cuda"`` but CUDA is not available.

        Note
        ----
        - The model is built before its parameters can be loaded from a file.
          This means `self.z_dim` must match `z_dim` of the model being
          loaded.
        """
        # Explicit check rather than `assert`, so it survives `python -O`.
        if device_name == "cuda" and not torch.cuda.is_available():
            raise ValueError(
                "device_name='cuda' was requested but CUDA is not available")
        super().__init__()
        self.save_dir = save_dir
        self.lr = lr
        self.z_dim = z_dim
        self.model_precision = model_precision
        if device_name == "auto":
            device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)
        if self.save_dir != '':
            # exist_ok avoids the check-then-create race of os.path.exists.
            os.makedirs(self.save_dir, exist_ok=True)
        self._build_network()
        self.optimizer = Adam(self.parameters(), lr=self.lr)
        self.epoch = 0
        # Per-epoch losses, keyed by epoch number.
        self.loss = {'train': {}, 'test': {}}
        self.to(self.device)

    def _build_network(self):
        """Define all the network layers.

        The registration order below is deliberately left untouched: it
        fixes the order of ``self.parameters()``, which the optimizer state
        saved in checkpoints is indexed by.
        """
        # Encoder
        self.conv1 = nn.Conv2d(1, 8, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(8, 8, 3, 2, padding=1)
        self.conv3 = nn.Conv2d(8, 16, 3, 1, padding=1)
        self.conv4 = nn.Conv2d(16, 16, 3, 2, padding=1)
        self.conv5 = nn.Conv2d(16, 24, 3, 1, padding=1)
        self.conv6 = nn.Conv2d(24, 24, 3, 2, padding=1)
        self.conv7 = nn.Conv2d(24, 32, 3, 1, padding=1)
        self.bn1 = nn.BatchNorm2d(1)
        self.bn2 = nn.BatchNorm2d(8)
        self.bn3 = nn.BatchNorm2d(8)
        self.bn4 = nn.BatchNorm2d(16)
        self.bn5 = nn.BatchNorm2d(16)
        self.bn6 = nn.BatchNorm2d(24)
        self.bn7 = nn.BatchNorm2d(24)
        # 8192 = 32 channels * 16 * 16 (three stride-2 convs on 128 x 128).
        self.fc1 = nn.Linear(8192, 1024)
        self.fc2 = nn.Linear(1024, 256)
        # Three parallel heads: mean (x1), covariance factor (x2), diagonal
        # (x3).
        self.fc31 = nn.Linear(256, 64)
        self.fc32 = nn.Linear(256, 64)
        self.fc33 = nn.Linear(256, 64)
        self.fc41 = nn.Linear(64, self.z_dim)
        self.fc42 = nn.Linear(64, self.z_dim)
        self.fc43 = nn.Linear(64, self.z_dim)
        # Decoder
        self.fc5 = nn.Linear(self.z_dim, 64)
        self.fc6 = nn.Linear(64, 256)
        self.fc7 = nn.Linear(256, 1024)
        self.fc8 = nn.Linear(1024, 8192)
        self.convt1 = nn.ConvTranspose2d(32, 24, 3, 1, padding=1)
        self.convt2 = nn.ConvTranspose2d(24, 24, 3, 2, padding=1,
                                         output_padding=1)
        self.convt3 = nn.ConvTranspose2d(24, 16, 3, 1, padding=1)
        self.convt4 = nn.ConvTranspose2d(16, 16, 3, 2, padding=1,
                                         output_padding=1)
        self.convt5 = nn.ConvTranspose2d(16, 8, 3, 1, padding=1)
        self.convt6 = nn.ConvTranspose2d(8, 8, 3, 2, padding=1,
                                         output_padding=1)
        self.convt7 = nn.ConvTranspose2d(8, 1, 3, 1, padding=1)
        self.bn8 = nn.BatchNorm2d(32)
        self.bn9 = nn.BatchNorm2d(24)
        self.bn10 = nn.BatchNorm2d(24)
        self.bn11 = nn.BatchNorm2d(16)
        self.bn12 = nn.BatchNorm2d(16)
        self.bn13 = nn.BatchNorm2d(8)
        self.bn14 = nn.BatchNorm2d(8)

    def _get_layers(self):
        """Return a dictionary mapping names to network layers.

        Every layer is registered as a direct child module in
        `_build_network`, so the module registry already holds exactly this
        mapping and cannot drift out of sync with the layer definitions.
        """
        return dict(self.named_children())

    def encode(self, x):
        r"""
        Compute :math:`q(z|x)`.

        .. math:: q(z|x) = \mathcal{N}(\mu, \Sigma)
        .. math:: \Sigma = u u^{T} + \mathtt{diag}(d)

        where :math:`\mu`, :math:`u`, and :math:`d` are deterministic
        functions of `x` and :math:`\Sigma` denotes a covariance matrix.

        Parameters
        ----------
        x : torch.Tensor
            The input images, with shape: ``[batch_size, height=128,
            width=128]``

        Returns
        -------
        mu : torch.Tensor
            Posterior mean, with shape ``[batch_size, self.z_dim]``
        u : torch.Tensor
            Posterior covariance factor, as defined above. Shape:
            ``[batch_size, self.z_dim, 1]``
        d : torch.Tensor
            Posterior diagonal factor, as defined above. Shape:
            ``[batch_size, self.z_dim]``
        """
        # Add the singleton channel dimension expected by Conv2d.
        x = x.unsqueeze(1)
        x = F.relu(self.conv1(self.bn1(x)))
        x = F.relu(self.conv2(self.bn2(x)))
        x = F.relu(self.conv3(self.bn3(x)))
        x = F.relu(self.conv4(self.bn4(x)))
        x = F.relu(self.conv5(self.bn5(x)))
        x = F.relu(self.conv6(self.bn6(x)))
        x = F.relu(self.conv7(self.bn7(x)))
        x = x.view(-1, 8192)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mu = F.relu(self.fc31(x))
        mu = self.fc41(mu)
        u = F.relu(self.fc32(x))
        u = self.fc42(u).unsqueeze(-1)  # Last dimension is rank \Sigma = 1.
        d = F.relu(self.fc33(x))
        d = torch.exp(self.fc43(d))  # d must be positive.
        return mu, u, d

    def decode(self, z):
        r"""
        Compute :math:`p(x|z)`.

        .. math:: p(x|z) = \mathcal{N}(\mu, \Lambda)
        .. math:: \Lambda = \mathtt{model\_precision} \cdot I

        where :math:`\mu` is a deterministic function of `z`,
        :math:`\Lambda` is a precision matrix, and :math:`I` is the identity
        matrix.

        Parameters
        ----------
        z : torch.Tensor
            Batch of latent samples with shape ``[batch_size, self.z_dim]``

        Returns
        -------
        x : torch.Tensor
            Batch of means mu, described above. Shape: ``[batch_size,
            X_DIM=128*128]``
        """
        z = F.relu(self.fc5(z))
        z = F.relu(self.fc6(z))
        z = F.relu(self.fc7(z))
        z = F.relu(self.fc8(z))
        # Undo the encoder's flattening: 8192 = 32 channels * 16 * 16.
        z = z.view(-1, 32, 16, 16)
        z = F.relu(self.convt1(self.bn8(z)))
        z = F.relu(self.convt2(self.bn9(z)))
        z = F.relu(self.convt3(self.bn10(z)))
        z = F.relu(self.convt4(self.bn11(z)))
        z = F.relu(self.convt5(self.bn12(z)))
        z = F.relu(self.convt6(self.bn13(z)))
        # No nonlinearity on the output layer: the mean is unconstrained.
        z = self.convt7(self.bn14(z))
        return z.view(-1, X_DIM)

    def forward(self, x, return_latent_rec=False):
        r"""
        Send `x` round trip and compute a loss.

        In more detail: Given `x`, compute :math:`q(z|x)` and sample:
        :math:`\hat{z} \sim q(z|x)` . Then compute
        :math:`\log p(x|\hat{z})`, the log-likelihood of `x`, the input,
        given :math:`\hat{z}`, the latent sample. We will also need the
        likelihood of :math:`\hat{z}` under the model's prior:
        :math:`p(\hat{z})`, and the entropy of the latent conditional
        distribution, :math:`\mathbb{H}[q(z|x)]` . ELBO can then be estimated
        as:

        .. math::

            \frac{1}{N} \sum_{i=1}^N \mathbb{E}_{\hat{z} \sim q(z|x_i)}
            \log p(x_i,\hat{z}) + \mathbb{H}[q(z|x_i)]

        where :math:`N` denotes the number of samples from the data
        distribution and the expectation is estimated using a single latent
        sample, :math:`\hat{z}`. In practice, the outer expectation is
        estimated using minibatches.

        Parameters
        ----------
        x : torch.Tensor
            A batch of samples from the data distribution (spectrograms).
            Shape: ``[batch_size, height=128, width=128]``
        return_latent_rec : bool, optional
            Whether to return latent samples and reconstructions. Defaults to
            ``False``.

        Returns
        -------
        loss : torch.Tensor
            Negative ELBO times the batch size. Shape: ``[]``
        latent : numpy.ndarray, if `return_latent_rec`
            Latent samples :math:`\hat{z}` (not the posterior means). Shape:
            ``[batch_size, self.z_dim]``
        reconstructions : numpy.ndarray, if `return_latent_rec`
            Reconstructed means. Shape: ``[batch_size, height=128,
            width=128]``
        """
        batch_size = x.shape[0]
        mu, u, d = self.encode(x)
        latent_dist = LowRankMultivariateNormal(mu, u, d)
        z = latent_dist.rsample()
        x_rec = self.decode(z)
        # Every term below is summed over the batch, so the Gaussian
        # normalization constants must be counted once per sample as well.
        # They do not affect gradients, only the reported loss value.
        # E_{q(z|x)} log p(z)
        prior_const = float(batch_size * self.z_dim * np.log(2 * np.pi))
        log_pz = -0.5 * (torch.sum(torch.pow(z, 2)) + prior_const)
        # E_{q(z|x)} log p(x|z)
        lik_const = float(
            batch_size * X_DIM * np.log(2 * np.pi / self.model_precision))
        l2s = torch.sum(torch.pow(x.view(batch_size, -1) - x_rec, 2), dim=1)
        log_pxz = -0.5 * (self.model_precision * torch.sum(l2s) + lik_const)
        # H[q(z|x)]
        entropy = torch.sum(latent_dist.entropy())
        elbo = log_pz + log_pxz + entropy
        if return_latent_rec:
            rec = x_rec.view(-1, X_SHAPE[0], X_SHAPE[1])
            return -elbo, z.detach().cpu().numpy(), \
                rec.detach().cpu().numpy()
        return -elbo

    def train_epoch(self, train_loader):
        """
        Train the model for a single epoch.

        Parameters
        ----------
        train_loader : torch.utils.data.Dataloader
            ava.models.vae_dataset.SyllableDataset Dataloader for training
            set

        Returns
        -------
        loss : float
            Negative ELBO per sample, averaged over `train_loader.dataset`.
            This is a biased estimate because the parameters change while
            it is being accumulated.
        """
        self.train()
        train_loss = 0.0
        for data in train_loader:
            self.optimizer.zero_grad()
            data = data.to(self.device)
            loss = self.forward(data)
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()
        train_loss /= len(train_loader.dataset)
        print('Epoch: {} Average loss: {:.4f}'.format(self.epoch, \
                train_loss))
        self.epoch += 1
        return train_loss

    def test_epoch(self, test_loader):
        """
        Test the model on a held-out test set, return a loss estimate.

        The model is left in evaluation mode afterwards.

        Parameters
        ----------
        test_loader : torch.utils.data.Dataloader
            ava.models.vae_dataset.SyllableDataset Dataloader for test set

        Returns
        -------
        loss : float
            Negative ELBO per sample, an unbiased estimate computed using
            samples from `test_loader`.
        """
        self.eval()
        test_loss = 0.0
        with torch.no_grad():
            for data in test_loader:
                data = data.to(self.device)
                loss = self.forward(data)
                test_loss += loss.item()
        test_loss /= len(test_loader.dataset)
        print('Test loss: {:.4f}'.format(test_loss))
        return test_loss

    def train_loop(self, loaders, epochs=100, test_freq=2, save_freq=10,
                   vis_freq=1):
        """
        Train the model for multiple epochs, testing and saving along the
        way.

        Parameters
        ----------
        loaders : dictionary
            Dictionary mapping the keys ``'test'`` and ``'train'`` to
            respective torch.utils.data.Dataloader objects.
        epochs : int, optional
            Number of (possibly additional) epochs to train the model for.
            Defaults to ``100``.
        test_freq : int, optional
            Testing is performed every `test_freq` epochs. Pass ``None`` to
            disable. Defaults to ``2``.
        save_freq : int, optional
            The model is saved every `save_freq` epochs. Pass ``None`` to
            disable. Defaults to ``10``.
        vis_freq : int, optional
            Syllable reconstructions are plotted every `vis_freq` epochs.
            Pass ``None`` to disable. Defaults to ``1``.
        """
        print("="*40)
        print("Training: epochs", self.epoch, "to", self.epoch+epochs-1)
        print("Training set:", len(loaders['train'].dataset))
        print("Test set:", len(loaders['test'].dataset))
        print("="*40)
        # The range is fixed up front, so it is unaffected by `train_epoch`
        # incrementing `self.epoch`.
        for epoch in range(self.epoch, self.epoch+epochs):
            # Run through the training data and record a loss.
            loss = self.train_epoch(loaders['train'])
            self.loss['train'][epoch] = loss
            # Run through the test data and record a loss.
            if (test_freq is not None) and (epoch % test_freq == 0):
                loss = self.test_epoch(loaders['test'])
                self.loss['test'][epoch] = loss
            # Save the model (never at epoch 0).
            if (save_freq is not None) and (epoch % save_freq == 0) and \
                    (epoch > 0):
                filename = "checkpoint_"+str(epoch).zfill(3)+'.tar'
                self.save_state(filename)
            # Plot reconstructions.
            if (vis_freq is not None) and (epoch % vis_freq == 0):
                self.visualize(loaders['test'])

    def save_state(self, filename):
        """
        Save all the model parameters to the given file.

        Besides one state dict per layer, the checkpoint stores the
        optimizer state, the loss history, `z_dim`, `epoch`, `lr` and
        `save_dir`.

        Parameters
        ----------
        filename : str
            Checkpoint filename, relative to `self.save_dir`.
        """
        state = {name: layer.state_dict()
                 for name, layer in self._get_layers().items()}
        state['optimizer_state'] = self.optimizer.state_dict()
        state['loss'] = self.loss
        state['z_dim'] = self.z_dim
        state['epoch'] = self.epoch
        state['lr'] = self.lr
        state['save_dir'] = self.save_dir
        filename = os.path.join(self.save_dir, filename)
        torch.save(state, filename)

    def load_state(self, filename):
        """
        Load all the model parameters from the given ``.tar`` file.

        The ``.tar`` file should be written by `self.save_state`.

        Parameters
        ----------
        filename : str
            File containing a model state.

        Raises
        ------
        ValueError
            If the checkpoint's latent dimension differs from `self.z_dim`.

        Note
        ----
        - `self.lr`, `self.save_dir`, and `self.z_dim` are not loaded.
        """
        # NOTE: torch.load unpickles the file, so only load checkpoints that
        # come from a trusted source.
        checkpoint = torch.load(filename, map_location=self.device)
        if checkpoint['z_dim'] != self.z_dim:
            raise ValueError(
                "checkpoint z_dim ({}) does not match model z_dim "
                "({})".format(checkpoint['z_dim'], self.z_dim))
        for layer_name, layer in self._get_layers().items():
            layer.load_state_dict(checkpoint[layer_name])
        self.optimizer.load_state_dict(checkpoint['optimizer_state'])
        self.loss = checkpoint['loss']
        self.epoch = checkpoint['epoch']

    def visualize(self, loader, num_specs=5, gap=(2,6), \
                  save_filename='reconstruction.pdf'):
        """
        Plot spectrograms and their reconstructions.

        Spectrograms are chosen at random from the Dataloader Dataset. The
        reconstructions are computed in evaluation mode; the model's
        previous train/eval mode is restored afterwards.

        Parameters
        ----------
        loader : torch.utils.data.Dataloader
            Spectrogram Dataloader
        num_specs : int, optional
            Number of spectrogram pairs to plot. Defaults to ``5``.
        gap : int or tuple of two ints, optional
            The vertical and horizontal gap between images, in pixels.
            Defaults to ``(2,6)``.
        save_filename : str, optional
            Where to save the plot, relative to `self.save_dir`. Defaults to
            ``'reconstruction.pdf'``.

        Returns
        -------
        specs : numpy.ndarray
            Spectrograms from `loader`.
        rec_specs : numpy.ndarray
            Corresponding spectrogram reconstructions.

        Raises
        ------
        ValueError
            If `num_specs` is less than 1 or exceeds the dataset size.
        """
        # Collect random indices.
        if not 1 <= num_specs <= len(loader.dataset):
            raise ValueError(
                "num_specs must be between 1 and the dataset size "
                "({}), got {}".format(len(loader.dataset), num_specs))
        indices = np.random.choice(np.arange(len(loader.dataset)),
                                   size=num_specs, replace=False)
        # Retrieve spectrograms from the loader.
        specs = torch.stack(loader.dataset[indices]).to(self.device)
        # Get reconstructions. In training mode BatchNorm would normalize
        # with the statistics of this tiny batch and fold them into its
        # running averages, so switch to eval mode for the forward pass.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                _, _, rec_specs = self.forward(specs, return_latent_rec=True)
        finally:
            self.train(was_training)
        specs = specs.detach().cpu().numpy()
        all_specs = np.stack([specs, rec_specs])
        # Plot.
        save_filename = os.path.join(self.save_dir, save_filename)
        grid_plot(all_specs, gap=gap, filename=save_filename)
        return specs, rec_specs

    def get_latent(self, loader):
        """
        Get latent means for all syllables in the given loader.

        The encoder is run in evaluation mode so the result does not depend
        on how the data is batched; the model's previous train/eval mode is
        restored afterwards.

        Parameters
        ----------
        loader : torch.utils.data.Dataloader
            ava.models.vae_dataset.SyllableDataset Dataloader.

        Returns
        -------
        latent : numpy.ndarray
            Latent means. Shape: ``[len(loader.dataset), self.z_dim]``

        Note
        ----
        - Make sure your loader is not set to shuffle if you're going to
          match these with labels or other fields later.
        """
        latent = np.zeros((len(loader.dataset), self.z_dim))
        was_training = self.training
        self.eval()
        try:
            start = 0
            with torch.no_grad():
                for data in loader:
                    data = data.to(self.device)
                    mu, _, _ = self.encode(data)
                    mu = mu.detach().cpu().numpy()
                    latent[start:start+len(mu)] = mu
                    start += len(mu)
        finally:
            self.train(was_training)
        return latent
if __name__ == '__main__':
    # Intentionally a no-op: running this file as a script does nothing.
    pass
###