{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "daecd60e-5c5c-401c-9471-3ea002dee555",
   "metadata": {},
   "source": [
    "# Lab exercise 6: The Variational Auto-Encoder\n",
    "The VAE that we will develop is based on the following generative story:\n",
    "* $z\\sim p_\\theta(z)$\n",
    "* $x \\sim p_\\theta(x|z;θ)$\n",
    "\n",
    "where the latent representations z take value in $R^n$\n",
    ". The prior ditribution $p(z)$ is a multivariate Gaussian where each coordinate is independent. We fix the mean and variance of each coordinate to 0 and 1, respectively. The conditional distribution $p(x|z;θ)$ is parameterized by a neural network: it is the decoder! The generated pixels x are independent Gaussians with a fixed variance.\n",
    "\n",
    "Note: this kind of VAE will be quite bad at generating MNIST picture. Therefore, when you do you experiments, you should both generate picture and display the mean parameters of the output distributions. This is a well known problem of VAE, you can try to play with the network architecture and the parameters to improve generation.\n",
    "\n",
    "Although the decoder is similar to the auto-encoder decoder, the encoder is different: it must return two tensors, the tensor of means and the tensor of variances. As the variance of a Gaussian distribution is constrained to be strictly positive, it is usual to instead return the log-variance (or log squared variance), which is unconstrained. If you exponentiate the log-variance, you get the variance which will be strictly positive as the exponential function only returns positive values.\n",
    "\n",
    "Similarly to the auto-encoder, there are several hyperparameters you can try to tune. However, for the VAE I strongly advise you to:\n",
    "* set the latent space dim to 2\n",
    "* use [gradient clipping](https://pytorch-org.translate.goog/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html?_x_tr_sl=en&_x_tr_tl=fr&_x_tr_hl=fr&_x_tr_pto=sc) (bound the gradient) \n"
   ]
  },
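  {
   "cell_type": "markdown",
   "id": "vae-sketch-gradient-clipping",
   "metadata": {},
   "source": [
    "As a sketch of what gradient clipping does (the linear model and the deliberately large loss below are arbitrary, for illustration only): `torch.nn.utils.clip_grad_norm_` rescales all gradients in place so that their global L2 norm is at most `max_norm`, and returns the total norm measured before clipping. It is called after `loss.backward()` and before `optimizer.step()`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "model = nn.Linear(10, 1)  # arbitrary model, for illustration only\n",
    "x = torch.randn(32, 10)\n",
    "loss = (model(x) * 1000).pow(2).mean()  # deliberately large loss, hence large gradients\n",
    "loss.backward()\n",
    "\n",
    "max_norm = 0.5\n",
    "norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)\n",
    "# Global norm after clipping: L2 norm of the per-parameter gradient norms\n",
    "norm_after = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))\n",
    "\n",
    "print(float(norm_before), float(norm_after))  # norm_after is at most max_norm\n",
    "```"
   ]
  },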
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bd298077-53e3-438b-a751-7eb0877f1b90",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cb75ca7-0fb1-46bf-aea9-ad2f1112e01b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class VAEEncoder(nn.Module):\n",
    "    def __init__(self, dim_input, dim_latent):\n",
    "        super().__init__()\n",
    "        # TODO\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        # TODO\n",
    "\n",
    "        # mu = mean\n",
    "        # log_sigma_squared = log variance\n",
    "        # The idea is that you use two different output projection:\n",
    "        # one for the mean, one for the log_sigma_squared\n",
    "        # but all other layers are shared\n",
    "        return mu, log_sigma_squared\n",
    "        \n",
    "class VAEDecoder(nn.Module):\n",
    "    def __init__(self, dim_latent, dim_output):\n",
    "        super().__init__()\n",
    "        # TODO\n",
    "\n",
    "    def forward(self, z):\n",
    "        # TODO\n",
    "\n",
    "        return img"
   ]
  },
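  {
   "cell_type": "markdown",
   "id": "vae-sketch-architecture",
   "metadata": {},
   "source": [
    "One possible architecture following the comments above (shared hidden layers, two output projections), given as a hedged sketch: the hidden size of 256, the ReLU activations and the sigmoid on the decoder output (which assumes pixel intensities in [0, 1]) are assumptions you are free to change. The classes are suffixed with `Sketch` so they do not clash with your own `VAEEncoder` and `VAEDecoder`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "class VAEEncoderSketch(nn.Module):\n",
    "    def __init__(self, dim_input, dim_latent, dim_hidden=256):  # dim_hidden is an assumption\n",
    "        super().__init__()\n",
    "        # Layers shared by the two heads\n",
    "        self.shared = nn.Sequential(nn.Linear(dim_input, dim_hidden), nn.ReLU())\n",
    "        # Two separate output projections\n",
    "        self.to_mu = nn.Linear(dim_hidden, dim_latent)\n",
    "        self.to_log_sigma_squared = nn.Linear(dim_hidden, dim_latent)\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        h = self.shared(inputs)\n",
    "        return self.to_mu(h), self.to_log_sigma_squared(h)\n",
    "\n",
    "class VAEDecoderSketch(nn.Module):\n",
    "    def __init__(self, dim_latent, dim_output, dim_hidden=256):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(dim_latent, dim_hidden), nn.ReLU(),\n",
    "            nn.Linear(dim_hidden, dim_output), nn.Sigmoid(),  # means in [0, 1] (assumption)\n",
    "        )\n",
    "\n",
    "    def forward(self, z):\n",
    "        return self.net(z)\n",
    "\n",
    "enc = VAEEncoderSketch(28 * 28, 2)\n",
    "dec = VAEDecoderSketch(2, 28 * 28)\n",
    "mu, log_sigma_squared = enc(torch.rand(8, 28 * 28))\n",
    "img = dec(mu)\n",
    "print(mu.shape, log_sigma_squared.shape, img.shape)\n",
    "```"
   ]
  },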
  {
   "cell_type": "markdown",
   "id": "aff6103f-3808-475e-aedc-e264249b5385",
   "metadata": {},
   "source": [
    "To compute the training loss, you must compute two terms:\n",
    "* A Monte-Carlo estimation of the reconstruction loss (you can start considering sampling only one z)\n",
    "* The KL divergence between the distributions computed by the encoder and the prior\n",
    "\n",
    "For the reconstruction loss, you can use the mean square error loss.\n",
    "\n",
    "To sample values, you can use the reparameterization trick as follows: \n",
    "```\n",
    "e = torch.normal(0, 1., mu.shape)\n",
    "z = mu + e * torch.sqrt(torch.exp(log_sigma_squared))\n",
    "```\n",
    "\n",
    "For the formula of the $KL(q_\\phi(z|x)|| p_\\theta(z))$ you should have it from the exercise we did during the lecture. You can also check appendix of the [original paper](https://arxiv.org/pdf/1312.6114.pdf) \n",
    "\n",
    "You have in the following to implements:\n",
    "* The training loop\n",
    "* The two losses\n",
    "* Choose which architecture for encoder and decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74997eff-4a76-4ecc-9ca8-21b2e8595da0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def KL(mu,log_sigma):\n",
    "    '''\n",
    "    We suppose that the prior is a centered gaussian\n",
    "    '''\n",
    "    #TODO\n",
    "\n",
    "def reconstruction_loss(input_image, predicted_image):\n",
    "    #TODO\n",
    "def training_loop(dataset, encoder, decoder, n_sample_monte_carlo=1, max_epoch=50, learning_rate=1e-3, batch_size=128, max_grad_norm=5e-1):\n",
    "    data_loader = DataLoader(list(dataset), batch_size=batch_size, shuffle=True, drop_last=True)\n",
    "    optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=learning_rate)\n",
    "    for epoch in range(max_epoch):\n",
    "        losses = []\n",
    "        for x in dataloader:\n",
    "            optimizer.zero_grad()\n",
    "            mu, log_sigma = encoder(x)\n",
    "            # do sampling\n",
    "            z =  # sampled z\n",
    "            y = decoder(z) # distribution given\n",
    "\n",
    "            loss = reconstruction_loss(x, y) + KL(mu, log_sigma)\n",
    "            losses.append(loss.item())\n",
    "            \n",
    "            loss.backward()\n",
    "            \n",
    "            torch.nn.utils.clip_grad_norm_(list(encoder.parameters()) + list(decoder.parameters()), max_grad_norm)\n",
    "            optimizer.step()\n",
    "        print(f\"The sum of losses for epoch {epoch} is {np.sum(losses)}\")\n",
    "                    \n",
    "\n"
   ]
  },
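  {
   "cell_type": "markdown",
   "id": "vae-sketch-kl-check",
   "metadata": {},
   "source": [
    "A sketch of the closed-form KL divergence for a batch, checked against `torch.distributions.kl_divergence`. `kl_to_standard_normal` is a hypothetical helper, not the required `KL` function, and summing over latent dimensions then averaging over the batch is an assumption: use the same batch reduction as in your reconstruction loss.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.distributions import Normal, kl_divergence\n",
    "\n",
    "def kl_to_standard_normal(mu, log_sigma_squared):\n",
    "    # -1/2 * sum_j (1 + log sigma_j^2 - mu_j^2 - sigma_j^2), then mean over the batch\n",
    "    kl_per_example = -0.5 * (1 + log_sigma_squared - mu ** 2 - torch.exp(log_sigma_squared)).sum(dim=1)\n",
    "    return kl_per_example.mean()\n",
    "\n",
    "torch.manual_seed(0)\n",
    "mu = torch.randn(8, 2)  # arbitrary encoder means\n",
    "log_sigma_squared = torch.randn(8, 2)  # arbitrary encoder log-variances\n",
    "\n",
    "q = Normal(mu, torch.exp(0.5 * log_sigma_squared))  # Normal expects the standard deviation\n",
    "prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))\n",
    "reference = kl_divergence(q, prior).sum(dim=1).mean()\n",
    "\n",
    "print(torch.allclose(kl_to_standard_normal(mu, log_sigma_squared), reference, atol=1e-5))  # True\n",
    "```"
   ]
  },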
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a018307b-571c-4a6c-9068-a29a900c672a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataset_loader # file given in lab 3\n",
    "# Download mnist dataset \n",
    "if(\"mnist.pkl.gz\" not in os.listdir(\".\")):\n",
    "    # this link doesn't work any more,\n",
    "    # seach on google for the file \"mnist.pkl.gz\"\n",
    "    # and download it\n",
    "    !wget https://github.com/mnielsen/neural-networks-and-deep-learning/raw/master/data/mnist.pkl.gz\n",
    "\n",
    "# if you have it somewhere else, you can comment the lines above\n",
    "# and overwrite the path below\n",
    "mnist_path = \"./mnist.pkl.gz\"\n",
    "\n",
    "\n",
    "train_data, dev_data, test_data = dataset_loader.load_mnist(mnist_path)\n",
    "\n",
    "dataset = train_data[0] # we are not interrested having labels\n",
    "\n",
    "encoder = # TODO\n",
    "decoder = # TODO\n",
    "\n",
    "training_loop(dataset, encoder, decoder, .....)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eff03ef5-cb1f-4152-8b0d-ef87f083b9ef",
   "metadata": {},
   "source": [
    "## Generating new images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7131239-b7d8-40b8-94b9-525ec7f5b966",
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 10\n",
    "\n",
    "z = torch.normal(0, 1., mu.shape) # sampling\n",
    "\n",
    "generated = decoder(z)\n",
    "imgs = np.zeros((n*28, n*28))\n",
    "\n",
    "plt.figure()\n",
    "for i in range(n*n):\n",
    "    imgs[i//n  * 28: i//n * 28 + 28,i%n * 28 : i%n * 28 + 28] = 1 - generated[i].detach().numpy().reshape(28,28)\n",
    "plt.axis('off')\n",
    "plt.imshow(imgs, cmap='Greys')"
   ]
  },
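  {
   "cell_type": "markdown",
   "id": "vae-sketch-mosaic",
   "metadata": {},
   "source": [
    "The loop above places image `i` at row `i // n` and column `i % n` of an `n` by `n` grid. A sketch of an equivalent vectorized version, with random arrays standing in for the decoder outputs (convert real outputs with `generated.detach().numpy()` first):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "n = 10\n",
    "generated = np.random.rand(n * n, 28 * 28)  # stand-in for decoder outputs, one flattened image per row\n",
    "\n",
    "# (n*n, 784) -> (n, n, 28, 28) -> (n, 28, n, 28) -> (n*28, n*28)\n",
    "imgs_vectorized = 1 - generated.reshape(n, n, 28, 28).transpose(0, 2, 1, 3).reshape(n * 28, n * 28)\n",
    "\n",
    "# Reference: the loop from the cell above\n",
    "imgs_loop = np.zeros((n * 28, n * 28))\n",
    "for i in range(n * n):\n",
    "    imgs_loop[i // n * 28: i // n * 28 + 28, i % n * 28: i % n * 28 + 28] = 1 - generated[i].reshape(28, 28)\n",
    "\n",
    "print(np.array_equal(imgs_vectorized, imgs_loop))  # True\n",
    "```"
   ]
  },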
  {
   "cell_type": "markdown",
   "id": "0248392e-7f08-4667-8795-b5878a74dd73",
   "metadata": {},
   "source": [
    "## Latent space visualization\n",
    "\n",
    "It is quite useful to visualize the latent space the variational auto-encoder. You can visualize it either for the training data or the dev data. Note that if you want to visualize a latent space when its dimension is greater than two (useful for the first part!), you could project it in 2 dimensions using PCA (its already implemented in scikit-learn!) "
   ]
  },
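  {
   "cell_type": "markdown",
   "id": "vae-sketch-pca",
   "metadata": {},
   "source": [
    "For latent spaces with more than two dimensions, a minimal PCA projection with scikit-learn might look as follows; the random `latent_codes` array is a stand-in for the encoder outputs (means or samples) you collected, and its dimension of 16 is arbitrary.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "latent_codes = rng.normal(size=(500, 16))  # stand-in for encoder outputs, shape (N, dim_latent)\n",
    "\n",
    "pca = PCA(n_components=2)\n",
    "projected = pca.fit_transform(latent_codes)  # shape (N, 2), ready for plt.scatter\n",
    "\n",
    "print(projected.shape)  # (500, 2)\n",
    "print(pca.explained_variance_ratio_)  # fraction of the variance kept by each axis\n",
    "```\n",
    "\n",
    "`pca.explained_variance_ratio_` tells you how much of the latent variance the 2D picture actually shows."
   ]
  },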
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d230a871-afb7-480f-ae03-07bb9c21a60a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.cm as cm\n",
    "\n",
    "labels = np.array(dev_data[1])\n",
    "images = np.array(dev_data[0])\n",
    "dl = DataLoader(list(zip(images, labels)), batch_size=128, shuffle=True, drop_last=True)\n",
    "az = []\n",
    "ay = []\n",
    "with torch.no_grad():\n",
    "    for x, y in dl:\n",
    "        mu, log_sigma = encoder(x)\n",
    "\n",
    "        z = torch.randn(mu.shape) * torch.exp(0.5 * log_sigma) + mu\n",
    "        az += z.tolist()\n",
    "        ay += y.tolist()\n",
    "az = np.array(az)\n",
    "ay = np.array(ay)\n",
    "\n",
    "colors = cm.rainbow(np.linspace(0, 1, 10))\n",
    "for i in range(10):\n",
    "    plt.scatter(az[ay == i][:, 0], az[ay == i][:, 1])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
