{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, random\n",
    "import numpy as np\n",
    "import torch\n",
    "import dataset_loader\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Lab exercises 6\n",
    "\n",
    "The goal of this lab exercise is to help you learn how to use Pytorch, which you will need to use for the project.\n",
    "\n",
    "It is important the you read the documentation to understand how to use Pytorch functions, what kind of transformation they apply etc. You have to take time to read it carefully to understand what you are doing.\n",
    "\n",
    "- https://pytorch.org/docs/stable/nn.html\n",
    "- https://pytorch.org/docs/stable/torch.html\n",
    "\n",
    "Each time you use a function, check the manual, even if you think that you know how it works. If you don't read it, you are just making stup\\*d decisions.\n",
    "\n",
    "# 1. Pytorch basics\n",
    "\n",
    "Instead of manipulating numpy arrays, we will manipulate pytorch tensors.\n",
    "A lot of things are defined in the same way, except the you can use autograd!\n",
    "\n",
    "Note that when using pytorch and the autograd mechanism, you want to avoid in-place operations, for reason I didn't have time cover in the course, sorry. :(\n",
    "The only time you will need in place operation in this course is for parameter initialization.\n",
    "It is easy to identify in place operations: their function name ends with an underscore!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a tensor of shape (2, 4) containing random values.\n",
    "# by default, it will be a float tensor and will not ask for gradient\n",
    "\n",
    "t = torch.rand(2, 4)\n",
    "\n",
    "# you can also create a tensor full of 0 or 1\n",
    "t_zeros = torch.zeros(2, 4)\n",
    "t_ones = torch.ones(2, 4)\n",
    "\n",
    "print(t)\n",
    "print(t_zeros)\n",
    "print(t_ones)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# there also exists *_like functions that creates a tensor with exactly\n",
    "# the same properties as its argument (shape, gradient requirement, type, etc)\n",
    "\n",
    "t2 = torch.rand_like(t)\n",
    "t2_zeros = torch.zeros_like(t)\n",
    "t2_ones = torch.ones_like(t)\n",
    "\n",
    "print(t2)\n",
    "print(t2_zeros)\n",
    "print(t2_ones)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# you can also create a tensor of long values (i.e. integers),\n",
    "# which will be usefull to represent labels :)\n",
    "\n",
    "t_zeros_long = torch.zeros(10, dtype=torch.long)\n",
    "print(t_zeros_long)\n",
    "\n",
    "print(t_zeros_long.dtype, t_zeros.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# you can also initialize the tensor with values\n",
    "t_long = torch.LongTensor([0,1,10,20])\n",
    "print(t_long)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# now, let's turn to the serious stuff: gradient computation! :)\n",
    "\n",
    "t = torch.rand(2, 10)\n",
    "\n",
    "# by default no gradient will be required for t :(\n",
    "print(t.requires_grad)\n",
    "\n",
    "# so we ask for it explicitly (note the underscore: in place operation!)\n",
    "t.requires_grad_(True)\n",
    "print(t.requires_grad)\n",
    "\n",
    "# We can also set this to true at creation\n",
    "t = torch.rand(2, 10, requires_grad=True)\n",
    "print(t.requires_grad)\n",
    "\n",
    "# now, let's do a stupid operation and compute the gradient\n",
    "# this sum over all element of t,\n",
    "# it return a tensor with a single value\n",
    "z = t.sum()\n",
    "print(z.shape, z.requires_grad)\n",
    "\n",
    "# backpropagation!\n",
    "z.backward()\n",
    "\n",
    "# print the gradient of t,\n",
    "# it should be a vector full of 1, do you understand why?\n",
    "print(t.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if I call backward a second time, it will accumulate the gradient\n",
    "# so it will be a tensort full of 2\n",
    "z.backward()\n",
    "print(t.grad)\n",
    "\n",
    "# we can reset the gradient of the tensor to zero\n",
    "# => note the in-place operation\n",
    "t.grad.zero_()\n",
    "print(t.grad)\n",
    "\n",
    "# and if we backprop again, it will be a tensor full of one again\n",
    "z.backward()\n",
    "print(t.grad)\n",
    "\n",
    "# and this highlight one of the major source of bug in pytorch:\n",
    "# Do not forget to reset your gradients!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# similarly to the previous lab exercise,\n",
    "# the parameter of your network must be encapsulated in a Parameter object:\n",
    "# https://pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html\n",
    "# You should understand why :)\n",
    "#\n",
    "# however, Pytorch comes with a lot of modules already made!\n",
    "# they are in the torch.nn that we often just rename as nn\n",
    "\n",
    "# https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear\n",
    "linear = nn.Linear(10, 20)\n",
    "\n",
    "# parameter of the linear transformation:\n",
    "# projection matrix W and bias\n",
    "# look, its a parameter object!\n",
    "print(type(linear.weight))\n",
    "print(type(linear.bias))\n",
    "\n",
    "# the Linear class is not a subtype of parameter\n",
    "# but of Module, which represent a network part.\n",
    "print(type(linear), isinstance(linear, nn.Module))\n",
    "\n",
    "# so remember that we often want to compute values on batches.\n",
    "# in pytorch, datapoints will be rows instead of columns\n",
    "# (contrary to the two previous lab exercises with numpy!)\n",
    "# So, for example, we can create a batch with two random elements,\n",
    "# each one of size 10, i.e. the input shape is (2, 10)\n",
    "t_inputs = torch.rand(2, 10)\n",
    "print(t_inputs.shape)\n",
    "\n",
    "# compute the hidden representation after the linear transformation!\n",
    "# note that we use the object as a function for this\n",
    "t_outputs = linear(t_inputs)\n",
    "print(t_outputs.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data loading and conversion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download mnist dataset \n",
    "if(\"mnist.pkl.gz\" not in os.listdir(\".\")):\n",
    "    # this link doesn't work any more,\n",
    "    # seach on google for the file \"mnist.pkl.gz\"\n",
    "    # and download it\n",
    "    !wget https://github.com/mnielsen/neural-networks-and-deep-learning/raw/master/data/mnist.pkl.gz\n",
    "\n",
    "# if you have it somewhere else, you can comment the lines above\n",
    "# and overwrite the path below\n",
    "mnist_path = \"./mnist.pkl.gz\"\n",
    "# load the 3 splits\n",
    "train_data, dev_data, test_data = dataset_loader.load_mnist(mnist_path)\n",
    "\n",
    "def build_torch_inputs(data):\n",
    "    x, y = data\n",
    "    ret = list()\n",
    "    \n",
    "    for i in range(x.shape[0]):\n",
    "        input_tensor = torch.from_numpy(x[i]).reshape(1, -1)\n",
    "        output_value = int(y[i])\n",
    "        \n",
    "        ret.append({\n",
    "            \"input_tensor\": input_tensor,\n",
    "            \"output_value\": output_value\n",
    "        })\n",
    "        \n",
    "    return ret\n",
    "        \n",
    "train_data = build_torch_inputs(train_data)\n",
    "dev_data = build_torch_inputs(dev_data)\n",
    "test_data = build_torch_inputs(test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train_data is a list,\n",
    "# each element is a dictionnary with two keys:\n",
    "# - input_tensor: the input image as a row vector\n",
    "# - output_value: the gold label\n",
    "\n",
    "print(train_data[10][\"input_tensor\"].shape)\n",
    "print(train_data[10][\"output_value\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instead of computing the loss on a single input or on the full dataset,\n",
    "# it is more common to compute it on a subset of the data, called a batch or minibatch.\n",
    "# For example, if the we use a batch of size 10,\n",
    "# the input of the network will be a tensor of shape (10, 784)\n",
    "# where each row is a single input.\n",
    "#\n",
    "# In the data, we already transformed each input into a row vector,\n",
    "# so we only need to concatenate them.\n",
    "#\n",
    "# Here we who an example, of course this is done dynamically during training\n",
    "\n",
    "# Constructing a batched input\n",
    "batch_size = 10\n",
    "first_element = 20 # index in the training set of the first element of the batch\n",
    "\n",
    "# the cat() function concatenates a list of tensor along a dimension\n",
    "batch_input = torch.cat(\n",
    "    [\n",
    "        data[\"input_tensor\"]\n",
    "        for data in train_data[first_element:first_element + batch_size]\n",
    "    ],\n",
    "    # we want to concatenate on the batch dimension,\n",
    "    # i.e. the first dimension\n",
    "    dim=0\n",
    ")\n",
    "print(batch_input.shape)  # batch of ten flat images (10, 784)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# just a helper function\n",
    "def build_batch(data):\n",
    "    batch_inputs = torch.cat(\n",
    "        [data[\"input_tensor\"] for data in data],\n",
    "        dim=0\n",
    "    )\n",
    "\n",
    "    labels = torch.LongTensor([data[\"output_value\"] for data in data ])\n",
    "    \n",
    "    return batch_inputs, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Network definition and training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A network network is a class extending nn.Module\n",
    "class MLPClassifier(nn.Module):\n",
    "    # constructor, you can define any argument you need\n",
    "    # to parameterize your network\n",
    "    def __init__(self, input_dim, hidden_dim, output_dim):\n",
    "        # you must always call the parent constructor,\n",
    "        # other it will fail when you will run the network :)\n",
    "        super().__init__()\n",
    "        \n",
    "        # Create the projections:\n",
    "        # - the first one project from input to hidden space\n",
    "        # - the second one from hidden space to output space (i.e. logits, weights of each class)\n",
    "        # Note that if you want to use list or dictionnaries instead of directly\n",
    "        # setting attributes of the object, you need to use special containers:\n",
    "        # https://pytorch.org/docs/stable/nn.html#containers\n",
    "        # do you understand why?\n",
    "        self.hidden_proj = nn.Linear(input_dim, hidden_dim)\n",
    "        self.output_proj = nn.Linear(hidden_dim, output_dim)\n",
    "                \n",
    "        # custom initialization\n",
    "        # note that:\n",
    "        # - we encapsulate in torch.no_grad() to disable autograd here\n",
    "        # - we use inplace functions (i.e. with an underscore at the end)\n",
    "        with torch.no_grad():\n",
    "            torch.nn.init.kaiming_uniform_(self.hidden_proj.weight.data)\n",
    "            torch.nn.init.kaiming_uniform_(self.output_proj.weight.data)\n",
    "            \n",
    "            self.hidden_proj.bias.zero_()\n",
    "            self.output_proj.bias.zero_()\n",
    "        \n",
    "    # the forward function is the one that will be called to compute outputs.\n",
    "    # note that we never call it directly:\n",
    "    # we will use the object as a function, as with linear layers\n",
    "    def forward(self, inputs):\n",
    "        # first proj\n",
    "        z = self.hidden_proj(inputs)\n",
    "        # apply relu\n",
    "        z = torch.relu(z)\n",
    "\n",
    "        # apply output proj and return\n",
    "        return self.output_proj(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example of a training loop! :)\n",
    "\n",
    "# hyper-parameters\n",
    "n_epochs = 10\n",
    "batch_size = 10\n",
    "\n",
    "# Build the network\n",
    "network = MLPClassifier(784, 200, 10)\n",
    "\n",
    "# Build the optimizer, i.e. the object that will update parameters\n",
    "# using the gradient information.\n",
    "# \n",
    "# SGD is standard gradient descent, but there are many alternative!\n",
    "# https://pytorch.org/docs/stable/optim.html\n",
    "#\n",
    "# The first argument of an optimizer is the set of parameters if will update,\n",
    "# we can use network.parameters() to get all the parameters of our network\n",
    "\n",
    "# set momentum=0 for standard gradient descent\n",
    "optimizer = torch.optim.SGD(network.parameters(), lr=1e-3, momentum=0.9)\n",
    "\n",
    "# Adam is a very good alternative.\n",
    "#optimizer = torch.optim.Adam(network.parameters())\n",
    "\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "    print(\"%i / %i\" % (epoch+1, n_epochs))\n",
    "    \n",
    "    # shuffle the dataset\n",
    "    # its a good practice to do this at the beginning of each epoch\n",
    "    random.shuffle(train_data)\n",
    "    \n",
    "    # pass the network in training mode,\n",
    "    # i.e. dropout will be applied if the dropout module is called\n",
    "    network.train()\n",
    "    \n",
    "    for first_element in range(0, len(train_data), batch_size):\n",
    "        # IMPORTANT\n",
    "        # as gradient is accumulated, we need to set all gradients to 0\n",
    "        # there are several ways of doing that,\n",
    "        # the simplest is to call optimizer.zero_grad()\n",
    "        # that set all parameters tracked by the optimizer to 0\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        # build our batched input\n",
    "        batch_input, labels = build_batch(train_data[first_element:first_element + batch_size])\n",
    "        \n",
    "        # compute the output weights/logits\n",
    "        logits = network(batch_input)\n",
    "        \n",
    "        # compute the loss\n",
    "        # https://pytorch.org/docs/stable/nn.functional.html#cross-entropy\n",
    "        # the torch.nn.functional packages (renamed as F) contains many\n",
    "        # useful functions that are not network subpart (neither parameters or modules)\n",
    "        loss = F.cross_entropy(logits, labels)\n",
    "        \n",
    "        # compute the gradient\n",
    "        loss.backward()\n",
    "\n",
    "        # update parameters wrt to gradient information!\n",
    "        optimizer.step()\n",
    "        \n",
    "    # at the end of each epoch we evaluate on dev\n",
    "    # eval on dev data\n",
    "    n_correct = 0\n",
    "    # disable auto-grad as we don't need that during evaluation\n",
    "    # this speed things a little bit + use less memory\n",
    "    with torch.no_grad(): \n",
    "        # pass network in eval mode,\n",
    "        # i.e. if the dropout module is called,\n",
    "        # it won't be applied\n",
    "        network.eval()\n",
    "        \n",
    "        for first_element in range(0, len(dev_data), batch_size):\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            batch_input, labels = build_batch(dev_data[first_element:first_element + batch_size])\n",
    "\n",
    "            logits = network(batch_input)\n",
    "            \n",
    "            # logits is a tensor of shape (batch dim, n labels),\n",
    "            # to compute the prediction we just compute the argmax\n",
    "            # along the label dimension\n",
    "            prediction = logits.argmax(dim=1)\n",
    "            \n",
    "            # compare prediction to gold and add to the counter\n",
    "            # Be carefull: the .item() is used to get a float value\n",
    "            # instead of a pytorch tensor\n",
    "            n_correct += (prediction == labels).sum().item()\n",
    "        \n",
    "    print(\"Dev acc: %.2f\" % (100 * n_correct / len(dev_data)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TODO\n",
    "\n",
    "The goal of this lab exercise is that you play a little bit with the code above so you can learn how to use pytorch.\n",
    "I list here a sequence of things that you should be able to implement.\n",
    "\n",
    "It is really important that you learn how to do that, it will be important for the project.\n",
    "Of course you need to create the network variant and test it. :)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Regularization\n",
    "\n",
    "You can try two types of regularization (they can be combined together):\n",
    "\n",
    "- weight decay: it is a parameter of the optimizer\n",
    "- dropout\n",
    "\n",
    "For dropout, you need to create a dropout layer as part of your network. :)\n",
    "It will be automatically enabled/disabled when you call network.train()/.eval().\n",
    "https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html#torch.nn.Dropout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = torch.ones(2, 4)\n",
    "print(t)\n",
    "\n",
    "dropout = nn.Dropout(0.5)\n",
    "\n",
    "# activate train mode\n",
    "dropout.train()\n",
    "t2 = dropout(t)\n",
    "\n",
    "print(t2)\n",
    "\n",
    "dropout.eval()\n",
    "t3 = dropout(t)\n",
    "print(t3)\n",
    "\n",
    "# WARNING => of course you don't directly call these functions in the dropout object,\n",
    "# but instead you call the one of the network that will recursively call it to all\n",
    "# its module attributes!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A commong trick for training neural networks is gradient clipping: if the norm of the gradient is too big, we rescale the gradient. This trick can be used to prevent exploding gradients and also to make \"too big steps\" in the wrong direction due the use of approximate gradient computation in SGD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_loss.backward()  # compute gradient\n",
    "torch.nn.utils.clip_grad_value_(network.parameters(), 5.)  # clip gradient if its norm exceed 5\n",
    "optimizer.step()  # update parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Deeper network\n",
    "\n",
    "The second exercise will be to create a deep network!\n",
    "\n",
    "We will explore 2 different ways of doing that.\n",
    "\n",
    "**(1)** The most simple technique is to build a list of linear projection in the constructor and set it as an attribute of the network. But **warning**: you should not use a Python list directly, you must instead use a nn.ModuleList(). Luckily, it works as a list: you can append objects and loops of the content. You will also need to update the initialization (to do a loop over all layers!) and the forward pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example 1\n",
    "class DeepMLPClassifier1(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, output_dim, n_hidden_layers):\n",
    "        super().__init__()\n",
    "        \n",
    "        # TODO...\n",
    "            \n",
    "        self.output_proj = nn.Linear(d, output_dim)\n",
    "        self.dropout = nn.Dropout(0.5)\n",
    "        \n",
    "        with torch.no_grad():\n",
    "            torch.nn.init.kaiming_uniform_(self.output_proj.weight.data)\n",
    "            self.output_proj.bias.zero_()\n",
    "            \n",
    "            # TODO\n",
    "        \n",
    "    def forward(self, inputs):\n",
    "        # TODO...\n",
    "\n",
    "        return self.output_proj(z)\n",
    "    \n",
    "network = DeepMLPClassifier1(100, 200, 10, 3)\n",
    "\n",
    "# small check that must pass,\n",
    "# but you should also train it correctly to see if results improve!\n",
    "batch = torch.rand(10, 100)\n",
    "output = network(batch)\n",
    "print(output.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(2)** The second technique is based on a nn.Sequential() object: https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html#torch.nn.Sequential\n",
    "\n",
    "The idea behind a Sequential() object is that it is a list of sub-modules. When you call the object, it will just execute one module after the other, passing as input of the next one the result of the previous one.\n",
    "\n",
    "Here is an example on how to use this to construct a single projection with non-linearity and dropout:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example of usage: we define a projection as a Sequential object\n",
    "seq = nn.Sequential(\n",
    "    nn.Linear(10, 5),\n",
    "    nn.ReLU(),\n",
    "    nn.Dropout(0.5)  \n",
    ")\n",
    "\n",
    "# batched input\n",
    "inputs = torch.rand(3, 10)\n",
    "\n",
    "# will call successively the 3 subnetworks,\n",
    "# i.e. it will apply linear transformation,\n",
    "# then relu and then dropout\n",
    "outputs = seq(inputs)\n",
    "print(outputs.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unfortunately, it is a little bit more difficult to create than nn.Sequential() because it doesn't have an append() method... but you can use list comprehension + transform the list as a sequence of argument to the constructor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the list is a single argument\n",
    "print([1, 2, 3])\n",
    "\n",
    "# here we call print with 3 different arguments,\n",
    "# notice how the output is different\n",
    "print(1, 2, 3)\n",
    "\n",
    "# so, how do we call a function by passing the values\n",
    "# from a list as separate argument?\n",
    "# Well, like this:\n",
    "print(*[1, 2, 3])\n",
    "\n",
    "# notive that this last output is similar to the second one,\n",
    "# and different from the first! :)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example 2\n",
    "class DeepMLPClassifier2(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, output_dim, n_hidden_layers):\n",
    "        super().__init__()\n",
    "        \n",
    "        # TODO\n",
    "        \n",
    "        self.output_proj = nn.Linear(input_dim if n_hidden_layers == 0 else hidden_dim, output_dim)\n",
    "                \n",
    "        with torch.no_grad():\n",
    "            torch.nn.init.kaiming_uniform_(self.output_proj.weight.data)\n",
    "            self.output_proj.bias.zero_()\n",
    "            \n",
    "            #TODO\n",
    "            \n",
    "        \n",
    "    def forward(self, inputs):\n",
    "        # TODO\n",
    "        return self.output_proj(z)\n",
    "    \n",
    "network = DeepMLPClassifier2(100, 200, 10, 3)\n",
    "\n",
    "# small check that must pass,\n",
    "# but you should also train it correctly to see if results improve!\n",
    "batch = torch.rand(10, 100)\n",
    "output = network(batch)\n",
    "print(output.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Custom module\n",
    "\n",
    "Now, instead of using the Linear layer, we are going to implement a custom submodule that will:\n",
    "\n",
    "- apply an linear/affine transformation\n",
    "- apply a non-linearity\n",
    "- apply dropout\n",
    "\n",
    "However, we will do that **wihout** nn.Linear(). Remember that matrix multiplication is done with operator @.\n",
    "\n",
    "Instead, you need to use you own parameters for the projection matrix and bias.\n",
    "In the course example and in the previous lab, we use a projection defined as follows: Ax + b where x is the input, A the projection matrix and b the bias vector. However, this assume that the input is a column vector, or, if batched, a matrix where each input is a column of x.\n",
    "\n",
    "In Pytorch, we use a different format: a single input is a row vector, and a batched input is a matrix where each row in a different input data. So first, let's think a little bit:\n",
    "\n",
    "- How is the linear projection defined in this case? (no batch, just a single row vector x as input)\n",
    "- what is the shape of A? of b?\n",
    "- in the case of a batched input, you need to be careful so that broadcasting is applied correctly. Think about this and what is implies for parameters shape.\n",
    "\n",
    "To create a parameter in the constructor, you can do: self.whatever = nn.parameter.Parameter(torch.empty(..., ...))\n",
    "\n",
    "it create a tensor that is unitialized!\n",
    "\n",
    "\n",
    "**ANSWERS**\n",
    "\n",
    "xA + b with shapes:\n",
    "\n",
    "- x: (1, input dim)\n",
    "- A: (input dim, output dim)\n",
    "- b: (1, outputdim)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        super().__init__()\n",
    "        \n",
    "        # TODO\n",
    "            \n",
    "    def forward(self, inputs):\n",
    "        return # TODO\n",
    "\n",
    "\n",
    "class MLPClassifier(nn.Module):\n",
    "    def __init__(self, input_dim, hidden_dim, output_dim, n_hidden_layers):\n",
    "        super().__init__()\n",
    "        \n",
    "        # TODO\n",
    "            \n",
    "        self.output_proj = nn.Linear(input_dim if n_hidden_layers == 0 else hidden_dim, output_dim)\n",
    "                \n",
    "        with torch.no_grad():\n",
    "            torch.nn.init.kaiming_uniform_(self.output_proj.weight.data)\n",
    "            self.output_proj.bias.zero_()            \n",
    "        \n",
    "    def forward(self, inputs):\n",
    "        # TODO\n",
    "        return self.output_proj(z)\n",
    "    \n",
    "network = MLPClassifier(100, 200, 10, 3)\n",
    "\n",
    "# small check that must pass,\n",
    "# but you should also train it correctly to see if results improve!\n",
    "batch = torch.rand(10, 100)\n",
    "output = network(batch)\n",
    "print(output.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
