{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f3710942-d267-47db-b6cf-15cce55fef50",
   "metadata": {},
   "source": [
    "# Neural network: first experiments with a linear model\n",
    "\n",
    "In this lab exercise we will code a neural network using numpy, without a neural network library.\n",
    "Next week, the lab exercise will be to extend this program with hidden layers and activation functions.\n",
    "\n",
    "The task is digit recognition: the neural network has to predict which digit in $\\{0...9\\}$ is written in the input picture. We will use the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset, a standard benchmark in machine learning.\n",
    "\n",
    "The model is a simple linear  classifier $o = \\operatorname{softmax}(Wx + b)$ where:\n",
    "* $x$ is an input image that is represented as a column vector, each value being the \"color\" of a pixel\n",
    "* $W$ and $b$ are the parameters of the classifier\n",
    "* $\\operatorname{softmax}$ transforms the output weight (logits) into probabilities\n",
    "* $o$ is column vector that contains the probability of each category\n",
    "\n",
    "We will train this model via stochastic gradient descent by minimizing the negative log-likelihood of the data:\n",
    "$$\n",
    "    \\hat{W}, \\hat{b} = \\operatorname{argmin}_{W, b} \\sum_{x, y} - \\log p(y | x)\n",
    "$$\n",
    "Although this is a linear model, it classifies raw data without any manual feature extraction step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "780e4c48-dd03-40d8-857c-759b997f9b35",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import libs that we will use\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "\n",
    "# To load the data we will use the script of Gaetan Marceau Caron\n",
    "# You can download it from the course webiste and move it to the same directory that contains this ipynb file\n",
    "import dataset_loader\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3e162f3f-18cb-4fd3-b003-da463965f333",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2024-03-20 18:00:35--  https://github.com/mnielsen/neural-networks-and-deep-learning/raw/master/data/mnist.pkl.gz\n",
      "Résolution de github.com (github.com)… 140.82.121.3\n",
      "Connexion à github.com (github.com)|140.82.121.3|:443… connecté.\n",
      "requête HTTP transmise, en attente de la réponse… 302 Found\n",
      "Emplacement : https://raw.githubusercontent.com/mnielsen/neural-networks-and-deep-learning/master/data/mnist.pkl.gz [suivant]\n",
      "--2024-03-20 18:00:35--  https://raw.githubusercontent.com/mnielsen/neural-networks-and-deep-learning/master/data/mnist.pkl.gz\n",
      "Résolution de raw.githubusercontent.com (raw.githubusercontent.com)… 185.199.108.133, 185.199.110.133, 185.199.111.133, ...\n",
      "Connexion à raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443… connecté.\n",
      "requête HTTP transmise, en attente de la réponse… 200 OK\n",
      "Taille : 17051982 (16M) [application/octet-stream]\n",
      "Enregistre : ‘mnist.pkl.gz’\n",
      "\n",
      "mnist.pkl.gz        100%[===================>]  16,26M  17,2MB/s    ds 0,9s    \n",
      "\n",
      "2024-03-20 18:00:37 (17,2 MB/s) - ‘mnist.pkl.gz’ enregistré [17051982/17051982]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Download mnist dataset \n",
    "if(\"mnist.pkl.gz\" not in os.listdir(\".\")):\n",
    "    # this link doesn't work any more,\n",
    "    # seach on google for the file \"mnist.pkl.gz\"\n",
    "    # and download it\n",
    "    !wget https://github.com/mnielsen/neural-networks-and-deep-learning/raw/master/data/mnist.pkl.gz\n",
    "\n",
    "# if you have it somewhere else, you can comment the lines above\n",
    "# and overwrite the path below\n",
    "mnist_path = \"./mnist.pkl.gz\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5fca3e04-965a-44b1-9419-0b532c3352b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the 3 splits\n",
    "train_data, dev_data, test_data = dataset_loader.load_mnist(mnist_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6371bc90-6323-4109-8419-93f6fa309cae",
   "metadata": {},
   "source": [
    "Each dataset is a list with two elemets:\n",
    "* data[0] contains images\n",
    "* data[1] contains labels\n",
    "\n",
    "Data is stored as numpy.ndarray. You can use data[0][i] to retrieve image number i and data[1][i] to retrieve its label."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "22acb142-3f16-4bb4-9704-1b243f27df5c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "label: 1\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f2eccd79350>"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZb0lEQVR4nO3df0zU9x3H8df569QKh4hwMNGhrdpVpalTRmydnURkifFXFm27RJtOo8Vmyro2LK3WbQmbJl3Thmm2dLomVVuTqqnZXCwWSDewETXGbCNC2MQIuJpxh6ho5LM/jLeegnp455uD5yP5Jt7d93v37rdfefrlvhwe55wTAAAP2QDrAQAA/RMBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJgZZD3C7zs5OnT9/XgkJCfJ4PNbjAAAi5JxTW1ubMjIyNGBA9+c5vS5A58+fV2ZmpvUYAIAH1NjYqDFjxnT7eK8LUEJCgqSbgycmJhpPAwCIVDAYVGZmZujreXdiFqDS0lJt3bpVzc3Nys7O1nvvvaeZM2fec7tb33ZLTEwkQAAQx+71NkpMLkL46KOPVFRUpE2bNun48ePKzs5Wfn6+Lly4EIuXAwDEoZgE6O2339aqVav04osv6lvf+pa2b9+u4cOH6w9/+EMsXg4AEIeiHqBr166ppqZGeXl5/3+RAQOUl5enqqqqO9bv6OhQMBgMWwAAfV/UA/TVV1/pxo0bSktLC7s/LS1Nzc3Nd6xfUlIin88XWrgCDgD6B/MfRC0uLlYgEAgtjY2N1iMBAB6CqF8Fl5KSooEDB6qlpSXs/paWFvn9/jvW93q98nq90R4DANDLRf0MaMiQIZo+fbrKyspC93V2dqqsrEy5ubnRfjkAQJyKyc8BFRUVacWKFfr2t7+tmTNn6p133lF7e7tefPHFWLwcACAOxSRAy5Yt03/+8x9t3LhRzc3NevLJJ3Xo0KE7LkwAAPRfHuecsx7i64LBoHw+nwKBAJ+EAABx6H6/jptfBQcA6J8IEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAE4OsBwD6o0AgEPE2+fn5EW8zceLEiLeRpKKiooi3efLJJ3v0Wui/OAMCAJggQAAAE1EP0FtvvSWPxxO2TJ48OdovAwCIczF5D+iJJ57QZ5999v8XGcRbTQCAcDEpw6BBg+T3+2Px1ACAPiIm7wGdOXNGGRkZGj9+vF544QWdPXu223U7OjoUDAbDFgBA3xf1AOXk5Gjnzp06dOiQtm3bpoaGBj3zzDNqa2vrcv2SkhL5fL7QkpmZGe2RAAC9UNQDVFBQoB/84AeaNm2a8vPz9ac//Umtra36+OOPu1y/uLhYgUAgtDQ2NkZ7JABALxTzqwOSkpI0ceJE1dXVdfm41+uV1+uN9RgAgF4m5j8HdOnSJdXX1ys9PT3WLwUAiCNRD9Crr76qiooK/etf/9Lf/vY3LV68WAMHDtRzzz0X7ZcCAMSxqH8L7ty5c3ruued08eJFjR49Wk8//bSqq6s1evToaL8UACCORT1Ae/bsifZTAn1Od++J3s2XX375ULaRpOvXr0e8ze7du3v0Wui/+Cw4AIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkC
BAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMDHIegCgP6qsrLQeATDHGRAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIPIwUMbN261XqEu0pPT7ceAf0AZ0AAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAk+jBTAHTZt2mQ9AvoBzoAAACYIEADARMQBqqys1IIFC5SRkSGPx6P9+/eHPe6c08aNG5Wenq5hw4YpLy9PZ86cida8AIA+IuIAtbe3Kzs7W6WlpV0+vmXLFr377rvavn27jh49qkceeUT5+fm6evXqAw8LAOg7Ir4IoaCgQAUFBV0+5pzTO++8ozfeeEMLFy6UJH3wwQdKS0vT/v37tXz58gebFgDQZ0T1PaCGhgY1NzcrLy8vdJ/P51NOTo6qqqq63Kajo0PBYDBsAQD0fVENUHNzsyQpLS0t7P60tLTQY7crKSmRz+cLLZmZmdEcCQDQS5lfBVdcXKxAIBBaGhsbrUcCADwEUQ2Q3++XJLW0tITd39LSEnrsdl6vV4mJiWELAKDvi2qAsrKy5Pf7VVZWFrovGAzq6NGjys3NjeZLAQDiXMRXwV26dEl1dXWh2w0NDTp58qSSk5M1duxYrV+/Xr/85S/12GOPKSsrS2+++aYyMjK0aNGiaM4NAIhzEQfo2LFjevbZZ0O3i4qKJEkrVqzQzp079dprr6m9vV2rV69Wa2urnn76aR06dEhDhw6N3tQAgLjncc456yG+LhgMyufzKRAI8H4Q4sLvfve7iLdZu3ZtxNs8zL+q//3vfyPexufzxWASxKP7/TpufhUcAKB/IkAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgImIfx0D0Jd1dnZGvE1TU1PE2zysT7YeOXJkj7YbMIB/myL2OMoAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABN8GCnwNcFgMOJtNm/eHINJoqOnsyUkJER5EuBOnAEBAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJiIOUGVlpRYsWKCMjAx5PB7t378/7PGVK1fK4/GELfPnz4/WvACAPiLiALW3tys7O1ulpaXdrjN//nw1NTWFlt27dz/QkACAvmdQpBsUFBSooKDgrut4vV75/f4eDwUA6Pti8h5QeXm5UlNTNWnSJK1du1YXL17sdt2Ojg4Fg8GwBQDQ90U9QPPnz9cHH3ygsrIy/frXv1ZFRYUKCgp048aNLtcvKSmRz+cLLZmZmdEeCQDQC0X8Lbh7Wb58eejPU6dO1bRp0zRhwgSVl5dr7ty5d6xfXFysoqKi0O1gMEiEAKAfiPll2OPHj1dKSorq6uq6fNzr9SoxMTFsAQD0fTEP0Llz53Tx4kWlp6fH+qUAAHEk4m/BXbp0KexspqGhQSdPnlRycrKSk5O1efNmLV26VH6/X/X19Xrttdf06KOPKj8/P6qDAwDiW8QBOnbsmJ599tnQ7Vvv36xYsULbtm3TqVOn9Mc//lGtra3KyMjQvHnz9Itf/EJerzd6UwMA4l7EAZozZ46cc90+/pe//OWBBgLQtREjRkS8TVcX/gC9BZ8FBwAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABNR/5XcAGLj8uXLEW9z/PjxHr3W448/3qPtgEhwBgQAMEGA
AAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmODDSIE44ZyLeJsrV67EYBIgOjgDAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBM8GGkQJxIT0+PeJsf/ehHMZgEiA7OgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAE3wYKRAn3n//fesRgKjiDAgAYIIAAQBMRBSgkpISzZgxQwkJCUpNTdWiRYtUW1sbts7Vq1dVWFioUaNGacSIEVq6dKlaWlqiOjQAIP5FFKCKigoVFhaqurpahw8f1vXr1zVv3jy1t7eH1tmwYYM+/fRT7d27VxUVFTp//ryWLFkS9cEBAPEtoosQDh06FHZ7586dSk1NVU1NjWbPnq1AIKD3339fu3bt0ve+9z1J0o4dO/T444+rurpa3/nOd6I3OQAgrj3Qe0CBQECSlJycLEmqqanR9evXlZeXF1pn8uTJGjt2rKqqqrp8jo6ODgWDwbAFAND39ThAnZ2dWr9+vWbNmqUpU6ZIkpqbmzVkyBAlJSWFrZuWlqbm5uYun6ekpEQ+ny+0ZGZm9nQkAEAc6XGACgsLdfr0ae3Zs+eBBiguLlYgEAgtjY2ND/R8AID40KMfRF23bp0OHjyoyspKjRkzJnS/3+/XtWvX1NraGnYW1NLSIr/f3+Vzeb1eeb3enowBAIhjEZ0BOee0bt067du3T0eOHFFWVlbY49OnT9fgwYNVVlYWuq+2tlZnz55Vbm5udCYGAPQJEZ0BFRYWateuXTpw4IASEhJC7+v4fD4NGzZMPp9PL730koqKipScnKzExES98sorys3N5Qo4AECYiAK0bds2SdKcOXPC7t+xY4dWrlwpSfrNb36jAQMGaOnSpero6FB+fr5++9vfRmVYAEDfEVGAnHP3XGfo0KEqLS1VaWlpj4cCrHR2dka8zf38vbidx+OJeJvRo0dHvA3Qm/FZcAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADDRo9+ICvRVL7/8csTb9OSTrQFwBgQAMEKAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmODDSIGvmTBhgvUI3fr9738f8Tbbtm2LwSRAdHAGBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCY8DjnnPUQXxcMBuXz+RQIBJSYmGg9DvqZ1tbWiLcZOXJkxNt4PJ6Itzl27FjE2zz11FMRbwM8qPv9Os4ZEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgYpD1AEBvkpSUFPE2vezzfIG4wRkQAMAEAQIAmIgoQCUlJZoxY4YSEhKUmpqqRYsWqba2NmydOXPmyOPxhC1r1qyJ6tAAgPgXUYAqKipUWFio6upqHT58WNevX9e8efPU3t4ett6qVavU1NQUWrZs2RLVoQEA8S+iixAOHToUdnvnzp1KTU1VTU2NZs+eHbp/+PDh8vv90ZkQANAnPdB7QIFAQJKUnJwcdv+HH36olJQUTZkyRcXFxbp8+XK3z9HR0aFgMBi2AAD6vh5fht3Z2an169dr1qxZmjJlSuj+559/XuPGjVNGRoZOnTql119/XbW1tfrkk0+6fJ6SkhJt3ry5p2MAAOKUx/XwhxjWrl2rP//5z/riiy80ZsyYbtc7cuSI5s6dq7q6Ok2YMOGOxzs6OtTR0RG6HQwGlZmZqUAgoMTExJ6MBgAwFAwG5fP57vl1vEdnQOvWrdPBgwdVWVl51/hIUk5OjiR1GyCv1yuv19uTMQAA
cSyiADnn9Morr2jfvn0qLy9XVlbWPbc5efKkJCk9Pb1HAwIA+qaIAlRYWKhdu3bpwIEDSkhIUHNzsyTJ5/Np2LBhqq+v165du/T9739fo0aN0qlTp7RhwwbNnj1b06ZNi8l/AAAgPkX0HpDH4+ny/h07dmjlypVqbGzUD3/4Q50+fVrt7e3KzMzU4sWL9cYbb9z3+zn3+71DAEDvFJP3gO7VqszMTFVUVETylACAforPggMAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmBhkPcDtnHOSpGAwaDwJAKAnbn39vvX1vDu9LkBtbW2SpMzMTONJAAAPoq2tTT6fr9vHPe5eiXrIOjs7df78eSUkJMjj8YQ9FgwGlZmZqcbGRiUmJhpNaI/9cBP74Sb2w03sh5t6w35wzqmtrU0ZGRkaMKD7d3p63RnQgAEDNGbMmLuuk5iY2K8PsFvYDzexH25iP9zEfrjJej/c7cznFi5CAACYIEAAABNxFSCv16tNmzbJ6/Vaj2KK/XAT++Em9sNN7Ieb4mk/9LqLEAAA/UNcnQEBAPoOAgQAMEGAAAAmCBAAwETcBKi0tFTf/OY3NXToUOXk5OjLL7+0Humhe+utt+TxeMKWyZMnW48Vc5WVlVqwYIEyMjLk8Xi0f//+sMedc9q4caPS09M1bNgw5eXl6cyZMzbDxtC99sPKlSvvOD7mz59vM2yMlJSUaMaMGUpISFBqaqoWLVqk2trasHWuXr2qwsJCjRo1SiNGjNDSpUvV0tJiNHFs3M9+mDNnzh3Hw5o1a4wm7lpcBOijjz5SUVGRNm3apOPHjys7O1v5+fm6cOGC9WgP3RNPPKGmpqbQ8sUXX1iPFHPt7e3Kzs5WaWlpl49v2bJF7777rrZv366jR4/qkUceUX5+vq5evfqQJ42te+0HSZo/f37Y8bF79+6HOGHsVVRUqLCwUNXV1Tp8+LCuX7+uefPmqb29PbTOhg0b9Omnn2rv3r2qqKjQ+fPntWTJEsOpo+9+9oMkrVq1Kux42LJli9HE3XBxYObMma6wsDB0+8aNGy4jI8OVlJQYTvXwbdq0yWVnZ1uPYUqS27dvX+h2Z2en8/v9buvWraH7Wltbndfrdbt37zaY8OG4fT8459yKFSvcwoULTeaxcuHCBSfJVVRUOOdu/r8fPHiw27t3b2idf/zjH06Sq6qqshoz5m7fD845993vftf9+Mc/thvqPvT6M6Br166ppqZGeXl5ofsGDBigvLw8VVVVGU5m48yZM8rIyND48eP1wgsv6OzZs9YjmWpoaFBzc3PY8eHz+ZSTk9Mvj4/y8nKlpqZq0qRJWrt2rS5evGg9UkwFAgFJUnJysiSppqZG169fDzseJk+erLFjx/bp4+H2/XDLhx9+qJSUFE2ZMkXFxcW6fPmyxXjd6nUfRnq7r776Sjdu3FBaWlrY/WlpafrnP/9pNJWNnJwc7dy5U5MmTVJTU5M2b96sZ555RqdPn1ZCQoL1eCaam5slqcvj49Zj/cX8+fO1ZMkSZWVlqb6+Xj/72c9UUFCgqqoqDRw40Hq8qOvs7NT69es1a9YsTZkyRdLN42HIkCFKSkoKW7cvHw9d7QdJev755zVu3DhlZGTo1KlTev3111VbW6tPPvnEcNpwvT5A+L+CgoLQn6dNm6acnByNGzdOH3/8sV566SXDydAbLF++PPTnqVOnatq0aZowYYLKy8s1d+5cw8lio7CwUKdPn+4X74PeTXf7YfXq1aE/T506Venp6Zo7d67q6+s1YcKEhz1ml3r9t+BSUlI0cODAO65iaWlpkd/vN5qqd0hKStLEiRNVV1dnPYqZW8cAx8edxo8fr5SUlD55fKxbt04HDx7U559/HvbrW/x+v65d
u6bW1taw9fvq8dDdfuhKTk6OJPWq46HXB2jIkCGaPn26ysrKQvd1dnaqrKxMubm5hpPZu3Tpkurr65Wenm49ipmsrCz5/f6w4yMYDOro0aP9/vg4d+6cLl682KeOD+ec1q1bp3379unIkSPKysoKe3z69OkaPHhw2PFQW1urs2fP9qnj4V77oSsnT56UpN51PFhfBXE/9uzZ47xer9u5c6f7+9//7lavXu2SkpJcc3Oz9WgP1U9+8hNXXl7uGhoa3F//+leXl5fnUlJS3IULF6xHi6m2tjZ34sQJd+LECSfJvf322+7EiRPu3//+t3POuV/96lcuKSnJHThwwJ06dcotXLjQZWVluStXrhhPHl132w9tbW3u1VdfdVVVVa6hocF99tln7qmnnnKPPfaYu3r1qvXoUbN27Vrn8/lceXm5a2pqCi2XL18OrbNmzRo3duxYd+TIEXfs2DGXm5vrcnNzDaeOvnvth7q6Ovfzn//cHTt2zDU0NLgDBw648ePHu9mzZxtPHi4uAuScc++9954bO3asGzJkiJs5c6arrq62HumhW7ZsmUtPT3dDhgxx3/jGN9yyZctcXV2d9Vgx9/nnnztJdywrVqxwzt28FPvNN990aWlpzuv1urlz57ra2lrboWPgbvvh8uXLbt68eW706NFu8ODBbty4cW7VqlV97h9pXf33S3I7duwIrXPlyhX38ssvu5EjR7rhw4e7xYsXu6amJruhY+Be++Hs2bNu9uzZLjk52Xm9Xvfoo4+6n/70py4QCNgOfht+HQMAwESvfw8IANA3ESAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAm/gfpWESk6H2QqAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "index = 900\n",
    "label = train_data[1][index]\n",
    "picture = train_data[0][index]\n",
    "\n",
    "print(\"label: %i\" % label)\n",
    "plt.imshow(picture.reshape(28,28), cmap='Greys')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e65aa50a-ea30-4708-b9ca-7dae904687f0",
   "metadata": {},
   "source": [
    "**Question 1:** What are the characteristics of training data? (number of samples, dimension of input, number of labels)\n",
    "\n",
    "The documentation of ndarray class is available here: https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "91efaaab-2287-417d-9a62-b4a631b0b5e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def getDimDataset(data):\n",
    "    raise NotImplementedError('Implement the function')\n",
    "    return n_training, n_feature, n_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8938454-4727-41c7-802c-3955fba9f7e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "getDimDataset(train_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03da32ee-fa5c-48c8-af0f-580eba03f6c5",
   "metadata": {},
   "source": [
    "# 1. Building functions\n",
    "\n",
    "We now need to build functions that are required for the neural network.\n",
    "$$\n",
    "    o = \\operatorname{softmax}(Wx + b) \\\\\n",
    "    L(x, y) = -\\log p(y | x) = -\\log o[y]\n",
    "$$\n",
    "\n",
    "Note that in numpy, operator @ is used for matrix multiplication while * is used for element-wise multiplication.\n",
    "The documentation for linear algebra in numpy is available here: https://docs.scipy.org/doc/numpy/reference/routines.linalg.html\n",
    "\n",
    "The first operation is the affine transformation $v = Wx + b$.\n",
    "To compute the gradient, it is often convenient to write the forward pass as $v[i] = b[i] + \\sum_j W[i, j] x[j]$."
   ]
  },
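  {
   "cell_type": "markdown",
   "id": "a3c1f0d2-5b7e-4c19-9e2a-7d4b6f8e1a01",
   "metadata": {},
   "source": [
    "The cell below is a minimal sketch of the difference between `@` and `*`. It also shows why the index form $v[i] = b[i] + \\sum_j W[i, j] x[j]$ describes the same computation as the matrix form $Wx + b$. The shapes are arbitrary toy values, not the MNIST dimensions.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "W = rng.standard_normal((3, 4))  # toy projection matrix: 3 outputs, 4 inputs\n",
    "b = rng.standard_normal(3)       # toy bias, one entry per output\n",
    "x = rng.standard_normal(4)       # toy input vector\n",
    "\n",
    "# matrix form: @ is the matrix-vector product\n",
    "v_matrix = W @ x + b\n",
    "\n",
    "# index form: v[i] = b[i] + sum_j W[i, j] * x[j]\n",
    "v_index = np.array([b[i] + sum(W[i, j] * x[j] for j in range(4)) for i in range(3)])\n",
    "\n",
    "# * is element-wise: x is broadcast across the rows of W and nothing is summed\n",
    "elementwise = W * x\n",
    "\n",
    "print(np.allclose(v_matrix, v_index))  # True\n",
    "print(elementwise.shape)               # (3, 4)\n",
    "```"
   ]
  },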
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "f4989ef8-ec12-487a-988b-a3d657d454a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10, 5)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[-0.2601596 ],\n",
       "       [ 0.05727994],\n",
       "       [-0.56649217],\n",
       "       [-0.7666645 ],\n",
       "       [ 3.21327553],\n",
       "       [-0.70473022],\n",
       "       [-1.26487746],\n",
       "       [-1.23618592],\n",
       "       [-3.42500955],\n",
       "       [ 4.44976421]])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W = np.random.randn(5,10)\n",
    "x = np.random.randn(1,5)\n",
    "print(W.transpose(1,0).shape)\n",
    "np.matmul(W.transpose(1,0), x.transpose(1,0))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c167d320-9a24-43de-9311-0846e2f1a280",
   "metadata": {},
   "source": [
    "**Question 2:**  Complete the two functions `affine_transform` and `backward_affine_transform`. The last function compute the gradient the loss function according to weights of the linear module. The gradient of the loss according to output of the linear module is given as last parameter of the function backward_affine_transform.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "ee5da951-0405-4aaa-9207-b02aff3116cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input:\n",
    "# - W: projection matrix\n",
    "# - b: bias\n",
    "# - x: input features\n",
    "# Output:\n",
    "# - vector\n",
    "def affine_transform(W, b, x):\n",
    "    raise NotImplementedError('Implement the function')\n",
    "# Input:\n",
    "# - W: projection matrix\n",
    "# - b: bias\n",
    "# - x: input features\n",
    "# - g: incoming gradient\n",
    "# Output:\n",
    "# - g_W: gradient wrt W\n",
    "# - g_b: gradient wrt b\n",
    "def backward_affine_transform(W, b, x, g):\n",
    "    raise NotImplementedError('Implement the function')\n",
    "    return g_W, g_b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "80dfc98c-5ec4-42ac-9d5d-dc671eec3b2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "W = np.asarray([[ 0.63024213,  0.53679375, -0.92079597],\n",
    " [-0.1155045,   0.62780356, -0.67961305],\n",
    " [ 0.08465286, -0.06561815, -0.39778322],\n",
    " [ 0.8242268,   0.58907262, -0.52208052],\n",
    " [-0.43894227, -0.56993247,  0.09520727]])\n",
    "\n",
    "\n",
    "b = np.asarray([ 0.42706842,  0.69636598, -0.85611933, -0.08682553,  0.83160079])\n",
    "x = np.asarray([-0.32809223, -0.54751413,  0.81949319])\n",
    "\n",
    "o_gold = np.asarray([-0.82819732, -0.16640748, -1.17394705, -1.10761496,  1.36568213])\n",
    "g = np.asarray([-0.08938868,  0.44083873, -0.2260743,  -0.96196726, -0.53428805])\n",
    "g_W_gold = np.asarray([[ 0.02932773,  0.04894156, -0.07325341],\n",
    " [-0.14463576, -0.24136543,  0.36126434],\n",
    " [ 0.07417322,  0.12377887, -0.18526635],\n",
    " [ 0.31561399,  0.52669067, -0.78832562],\n",
    " [ 0.17529576,  0.29253025, -0.43784542]])\n",
    "g_b_gold = np.asarray([-0.08938868,  0.44083873, -0.2260743,  -0.96196726, -0.53428805])\n",
    "\n",
    "\n",
    "# quick test of the forward pass\n",
    "o = affine_transform(W, b, x)\n",
    "if o.shape != o_gold.shape:\n",
    "    raise RuntimeError(\"Unexpected output dimension: got %s, expected %s\" % (str(o.shape), str(o_gold.shape)))\n",
    "if not np.allclose(o, o_gold):\n",
    "    raise RuntimeError(\"Output of the affine_transform function is incorrect\")\n",
    "    \n",
    "# quick test if the backward pass\n",
    "g_W, g_b = backward_affine_transform(W, b, x, g)\n",
    "if g_W.shape != g_W_gold.shape:\n",
    "        raise RuntimeError(\"Unexpected gradient dimension for W: got %s, expected %s\" % (str(g_W.shape), str(g_W_gold.shape)))\n",
    "if g_b.shape != g_b_gold.shape:\n",
    "        raise RuntimeError(\"Unexpected gradient dimension for b: got %s, expected %s\" % (str(g_b.shape), str(g_b_gold.shape)))\n",
    "if not np.allclose(g_W, g_W_gold):\n",
    "    raise RuntimeError(\"Gradient of W is incorrect\")\n",
    "if not np.allclose(g_b, g_b_gold):\n",
    "    raise RuntimeError(\"Gradient of b is incorrect\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed21a0ba-a70e-4d81-a2ef-ffff48ca7b29",
   "metadata": {},
   "source": [
    "The softmax function:\n",
    "$$\n",
    "     o = \\operatorname{softmax}(w)\n",
    "$$\n",
    "where $w$ is a vector of logits in $\\mathbb R$ and $o$ a vector of probabilities such that:\n",
    "$$\n",
    "    o[i] = \\frac{\\exp(w[i])}{\\sum_j \\exp(w[j])}\n",
    "$$\n",
    "We do not need to implement the backward for this experiment.\n",
    "\n",
    "**Question 3** Implement the function softmax"
   ]
  },
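  {
   "cell_type": "markdown",
   "id": "b4d2e1f3-6c8f-4d2a-8f3b-8e5c7a9f2b02",
   "metadata": {},
   "source": [
    "Two properties follow directly from this definition and are handy for checking an implementation. First, mathematically every $o[i]$ is strictly positive and $\\sum_i o[i] = 1$. Second, adding the same constant $c$ to every logit leaves the output unchanged, because the factor $\\exp(c)$ cancels:\n",
    "$$\n",
    "    \\operatorname{softmax}(w + c)[i] = \\frac{\\exp(w[i] + c)}{\\sum_j \\exp(w[j] + c)} = \\frac{\\exp(c) \\exp(w[i])}{\\exp(c) \\sum_j \\exp(w[j])} = \\operatorname{softmax}(w)[i]\n",
    "$$"
   ]
  },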
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3079bfae-61fd-4bb3-aa5b-5e8f560334cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input:\n",
    "# - x: vector of logits\n",
    "# Output\n",
    "# - vector of probabilities\n",
    "def softmax(x):\n",
    "    raise NotImplementedError('Implement the function')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b72da73-9779-4374-9fdc-16e35e0181e9",
   "metadata": {},
   "source": [
    "**WARNING:** is your implementation numerically stable?\n",
    "\n",
    "The $\\exp$ function results in computations that overflows (i.e. results in numbers that cannot be represented with floating point numbers).\n",
    "Therefore, it is always convenient to use the following trick to improve stability: https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/"
   ]
  },
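  {
   "cell_type": "markdown",
   "id": "c5e3f2a4-7d9a-4e3b-9a4c-9f6d8b0a3c03",
   "metadata": {},
   "source": [
    "To see the problem concretely, the sketch below evaluates the formula directly with numpy on the same logits as the test cell that follows. `naive_softmax` is a hypothetical helper written only for this illustration. It deliberately omits the stabilization trick, so it shows the failure rather than the fix.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def naive_softmax(w):\n",
    "    # direct translation of o[i] = exp(w[i]) / sum_j exp(w[j]), without stabilization\n",
    "    e = np.exp(np.asarray(w, dtype=float))\n",
    "    return e / e.sum()\n",
    "\n",
    "with np.errstate(over='ignore', invalid='ignore'):\n",
    "    # exp(1000000) overflows to inf, and inf / inf evaluates to nan\n",
    "    o = naive_softmax([1000000, 1, 100])\n",
    "    # float64 already overflows for arguments slightly above 709.78\n",
    "    big = np.exp(710.0)\n",
    "\n",
    "print(o)    # first entry is nan, the others are 0\n",
    "print(big)  # inf\n",
    "```"
   ]
  },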
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5b75087-b8bd-4d9c-bc6a-65f4d68114ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example for testing the numerical stability of softmax\n",
    "# It should return [1., 0. ,0.], not [nan, 0., 0.]\n",
    "z = [1000000,1,100]\n",
    "print(softmax(z))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55067f3f-c6cf-408a-a574-ae805bb804e8",
   "metadata": {},
   "source": [
    "**Question 4**: From the result of the cell above, what can you say about the softmax output, even when it is stable?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04b0495c-9cc4-4877-a14b-1b3a170661a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Just too simple test for the softmax function\n",
    "x = np.asarray([0.92424884, -0.92381088, -0.74666024, -0.87705478, -0.54797015])\n",
    "y_gold = np.asarray([0.57467369, 0.09053556, 0.10808233, 0.09486917, 0.13183925])\n",
    "\n",
    "y = softmax(x)\n",
    "if not np.allclose(y, y_gold):\n",
    "    raise RuntimeError(\"Output of the softmax function is incorrect\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0aabf403-e4a2-4acb-a835-0ae5b7d7a2b4",
   "metadata": {},
   "source": [
    "Finally, we build the loss function and its gradient for training the network.\n",
    "\n",
    "The loss function is the negative log-likelihood defined as:\n",
    "$$\n",
    "    \\mathcal L(x, gold) = -\\log \\frac{\\exp(x[gold])}{\\sum_j \\exp(x[j])} = -x[gold] + \\log \\sum_j \\exp(x[j])\n",
    "$$\n",
    "This function is also called the cross-entropy loss (in Pytorch, different names are used dependending if the inputs are probabilities or raw logits).\n",
    "\n",
    "Similarly to the softmax, we have to rely on the log-sum-exp trick to stabilize the computation: https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/\n",
    "\n",
    "**Question 5:** Implement the forward and backward function for the negative loglikelihood"
   ]
  },
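  {
   "cell_type": "markdown",
   "id": "d6f4a3b5-8e0b-4f4c-8b5d-0a7e9c1b4d04",
   "metadata": {},
   "source": [
    "To apply the trick from the linked post to the second term of $\\mathcal L$, factor $\\exp(m)$ out of the sum, with $m = \\max_j x[j]$:\n",
    "$$\n",
    "    \\log \\sum_j \\exp(x[j]) = \\log \\Big( \\exp(m) \\sum_j \\exp(x[j] - m) \\Big) = m + \\log \\sum_j \\exp(x[j] - m)\n",
    "$$\n",
    "Every exponent $x[j] - m$ is at most $0$, so no term can overflow. The term with $x[j] = m$ equals $1$, so the argument of the $\\log$ is at least $1$ and never underflows to $0$."
   ]
  },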
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14cd59b9-7a7f-446a-81fa-8eaf3abf6de8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input:\n",
    "# - x: vector of logits\n",
    "# - gold: index of the gold class\n",
    "# Output:\n",
    "# - scalare equal to -log(softmax(x)[gold])\n",
    "def nll(x, gold):\n",
    "    raise NotImplementedError('Implement the function')\n",
    "\n",
    "# Input:\n",
    "# - x: vector of logits\n",
    "# - gold: index of the gold class\n",
    "# - gradient (scalar)\n",
    "# Output:\n",
    "# - gradient wrt x\n",
    "def backward_nll(x, gold, g):\n",
    "    raise NotImplementedError('Implement the function')\n",
    "    return g_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5df5bdcc-6452-434e-b343-87c08abf0964",
   "metadata": {},
   "outputs": [],
   "source": [
    "# test\n",
    "x = np.asarray([-0.13590009, -0.83649656,  0.03130881,  0.42559402,  0.08488182])\n",
    "y_gold = 1.5695014420179738\n",
    "g_gold = np.asarray([ 0.17609875,  0.08739591, -0.79185107,  0.30875221,  0.2196042 ])\n",
    "\n",
    "y = nll(x, 2)\n",
    "g = backward_nll(x, 2, 1.)\n",
    "\n",
    "if not np.allclose(y, y_gold):\n",
    "    raise RuntimeError(\"Output is incorrect\")\n",
    "\n",
    "if g.shape != g_gold.shape:\n",
    "        raise RuntimeError(\"Unexpected gradient dimension: got %s, expected %s\" % (str(g.shape), str(g_gold.shape)))\n",
    "if not np.allclose(g, g_gold):\n",
    "    raise RuntimeError(\"Gradient is incorrect\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "29b0d56d-82ea-478c-a72e-042e5bc63240",
   "metadata": {},
   "source": [
    "The following code test the implementation of the gradient using finite-difference approximation, see: https://timvieira.github.io/blog/post/2017/04/21/how-to-test-gradient-implementations/\n",
    "\n",
    "Your implementation should pass this test."
   ]
  },
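  {
   "cell_type": "markdown",
   "id": "e7a5b4c6-9f1c-4a5d-9c6e-1b8f0d2c5e05",
   "metadata": {},
   "source": [
    "As a standalone sketch of the idea behind this test, the central difference $\\frac{f(w + \\alpha e_i) - f(w - \\alpha e_i)}{2 \\alpha}$, where $e_i$ is the $i$-th basis vector, can be compared with a gradient known in closed form. The toy function $f(w) = \\sum_i w[i]^2$ (gradient $2w$) and the helper `central_difference` are illustrative choices, not part of the lab.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def f(w):\n",
    "    # toy scalar function whose gradient is 2 * w\n",
    "    return np.sum(w ** 2)\n",
    "\n",
    "def central_difference(func, w, alpha=1e-3):\n",
    "    # approximate d func / d w[i] by perturbing one coordinate of a 1-D array at a time\n",
    "    grad = np.zeros_like(w)\n",
    "    for i in range(w.size):\n",
    "        old = w[i]\n",
    "        w[i] = old + alpha\n",
    "        right = func(w)\n",
    "        w[i] = old - alpha\n",
    "        left = func(w)\n",
    "        w[i] = old  # restore the original value\n",
    "        grad[i] = (right - left) / (2.0 * alpha)\n",
    "    return grad\n",
    "\n",
    "w = np.array([0.5, -1.0, 2.0])\n",
    "approx = central_difference(f, w)\n",
    "print(np.allclose(approx, 2 * w))  # True\n",
    "```\n",
    "\n",
    "For a quadratic, the central difference is exact up to rounding. For other smooth functions its error is of order $\\alpha^2$, which is why `check_gradient` compares values with a tolerance."
   ]
  },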
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f66cd42-9cae-496d-84df-5deaf41817c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# this is python re-implementation of the test from the Dynet library\n",
    "# https://github.com/clab/dynet/blob/master/dynet/grad-check.cc\n",
    "\n",
    "def is_almost_equal(grad, computed_grad):\n",
    "    #print(grad, computed_grad)\n",
    "    f = abs(grad - computed_grad)\n",
    "    m = max(abs(grad), abs(computed_grad))\n",
    "\n",
    "    if f > 0.01 and m > 0.:\n",
    "        f /= m\n",
    "\n",
    "    if f > 0.01 or math.isnan(f):\n",
    "        return False\n",
    "    else:\n",
    "        return True\n",
    "\n",
    "def check_gradient(function, weights, true_grad, alpha = 1e-3):\n",
    "    # because input can be of any dimension,\n",
    "    # we build a view of the underlying data with the .shape(-1) method\n",
    "    # then we can access any element of the tensor as a elements of a list\n",
    "    # with a single dimension\n",
    "    weights_view = weights.reshape(-1)\n",
    "    true_grad_view = true_grad.reshape(-1)\n",
    "    for i in range(weights_view.shape[0]):\n",
    "        old = weights_view[i]\n",
    "\n",
    "        weights_view[i] = old - alpha\n",
    "        value_left = function(weights).reshape(-1)\n",
    "\n",
    "        weights_view[i] = old + alpha\n",
    "        value_right = function(weights).reshape(-1)\n",
    "\n",
    "        weights_view[i] = old\n",
    "        grad = (value_right - value_left) / (2. * alpha)\n",
    "\n",
    "        if not is_almost_equal(grad, true_grad_view[i]):\n",
    "            return False\n",
    "\n",
    "        return True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5e74c83-d581-4499-acb5-6b6ac57de0f7",
   "metadata": {},
   "source": [
    "# 2. Parameter initialization\n",
    "\n",
    "We are now going to build the function that will be used to initialize the parameters of the neural network before training.\n",
    "Note that for parameter initialization you must use **in-place** operations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b81523ec-c37d-4061-ad8c-59956121eae0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a random ndarray\n",
    "a = np.random.uniform(-1, 1, (5,))\n",
    "\n",
    "# this does not change the data of the ndarray created above!\n",
    "# it creates a new ndarray and replace the reference stored in a\n",
    "a = np.zeros((5, ))\n",
    "\n",
    "# this will change the underlying data of the ndarray that a points to\n",
    "a[:] = 0\n",
    "\n",
    "# similarly, this creates a new array and change the object pointed by a\n",
    "a = a + 1\n",
    "\n",
    "# while this change the underlying data of a\n",
    "a += 1"
   ]
  },
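  {
   "cell_type": "markdown",
   "id": "f8b6c5d7-0a2d-4b6e-8d7f-2c9a1e3d6f06",
   "metadata": {},
   "source": [
    "This distinction matters for the initialization functions below. They receive the parameter arrays as arguments and return nothing, so only an in-place write is visible to the caller. Here is a minimal sketch with two hypothetical helpers written only for this illustration:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def reset_by_rebinding(a):\n",
    "    # rebinds the local name only: the caller's array is left untouched\n",
    "    a = np.zeros_like(a)\n",
    "\n",
    "def reset_in_place(a):\n",
    "    # writes into the caller's buffer\n",
    "    a[:] = 0.0\n",
    "\n",
    "p = np.ones(3)\n",
    "reset_by_rebinding(p)\n",
    "print(p)  # [1. 1. 1.]\n",
    "reset_in_place(p)\n",
    "print(p)  # [0. 0. 0.]\n",
    "```"
   ]
  },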
  {
   "cell_type": "markdown",
   "id": "b8bc53dc-2ca1-4619-a0e2-bdce3f65ee1d",
   "metadata": {},
   "source": [
    "For an affine transformation, it is common to:\n",
    "* initialize the bias to 0\n",
    "* initialize the projection matrix with Glorot initialization (also known as Xavier initialization)\n",
    "\n",
    "The formula for Glorot initialization can be found in equation 16 (page 5) of the original paper: http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf\n",
    "\n",
    "**Question 6:** Fill the two initilization functions below"
   ]
  },
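  {
   "cell_type": "markdown",
   "id": "09c7d6e8-1b3e-4c7f-9e8a-3d0b2f4e7a07",
   "metadata": {},
   "source": [
    "For reference, the normalized initialization in that equation draws each entry of $W$ from a uniform distribution. Here $n_j$ and $n_{j+1}$ are the sizes of the layer's input and output; for $v = Wx + b$, these are the number of columns and rows of $W$:\n",
    "$$\n",
    "    W \\sim U\\left[-\\frac{\\sqrt{6}}{\\sqrt{n_j + n_{j+1}}}, \\frac{\\sqrt{6}}{\\sqrt{n_j + n_{j+1}}}\\right]\n",
    "$$"
   ]
  },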
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5171ab4-848a-47f0-92db-81acb31aeabb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def zero_init(b):\n",
    "    raise NotImplementedError('Implement the function')\n",
    "\n",
    "def glorot_init(W):\n",
    "    raise NotImplementedError('Implement the function')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a1aede1-96af-4f2c-97ec-273b95a256c8",
   "metadata": {},
   "source": [
    "# 3. Building and training the neural network\n",
    "\n",
    "In our simple example, creating the neural network is simply instantiating the parameters $W$ and $b$.\n",
    "They must be ndarray object with the correct dimensions."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ccc4a37-d975-45cb-b225-f22ab642ca4b",
   "metadata": {},
   "source": [
    "**Question 7:** Fill the function that create and initialize the parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fad7abf3-fe48-4738-9eb3-2b716cb11cbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_parameters(dim_input, dim_output):\n",
    "    W = # TODO\n",
    "    b = # TODO\n",
    "    \n",
    "    return W, b"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91b83f22-cc7d-4c05-9607-8d8c809558cc",
   "metadata": {},
   "source": [
    "The recent success of deep learning is (partly) due to the ability to train very big neural networks.\n",
    "However, researchers became interested in building small neural networks to improve computational efficiency and memory usage.\n",
    "Therefore, we often want to compare neural networks by their number of parameters, i.e. the size of the memory required to store the parameters."
   ]
  },
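  {
   "cell_type": "markdown",
   "id": "1ad8e7f9-2c4f-4d8a-8f9b-4e1c3a5f8b08",
   "metadata": {},
   "source": [
    "Here is an illustration with arbitrary toy sizes (not the MNIST model). The number of scalar parameters of an affine layer is the number of entries of $W$ plus the number of entries of $b$. The `nbytes` attribute of an `ndarray` gives the memory they occupy, 8 bytes per entry for numpy's default `float64`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "W_toy = np.zeros((4, 3))  # toy layer: 3 inputs, 4 outputs\n",
    "b_toy = np.zeros(4)\n",
    "\n",
    "n_params = W_toy.size + b_toy.size     # 4 * 3 + 4\n",
    "n_bytes = W_toy.nbytes + b_toy.nbytes  # 8 bytes per float64 entry\n",
    "print(n_params, n_bytes)               # 16 128\n",
    "```"
   ]
  },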
  {
   "cell_type": "markdown",
   "id": "0a3e4375-f265-4bb2-9f37-889713f7e277",
   "metadata": {},
   "source": [
    "**Question 8:** Fill the function that  print the number of parameters of the linear model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7729a857-04ab-4a7a-b85c-cf171b62501d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_n_parameters(W, b):\n",
    "    n = # TODO\n",
    "    print(\"Number of parameters: %i\" % (n))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ce568ce-415c-4750-9169-5c67dc9653cf",
   "metadata": {},
   "source": [
    "We can now create the neural network and print its number of parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8f54c2a-98a1-4c06-94dc-511b05cd7bf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "dim_input = # TODO\n",
    "dim_output = # TODO\n",
    "W, b = create_parameters(dim_input, dim_output)\n",
    "print_n_parameters(W, b)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d99f7f13-4f6d-4c03-a9bf-5776571dec58",
   "metadata": {},
   "source": [
    "Finally, the training loop!\n",
    "\n",
    "The training loop should be structured as follows:\n",
    "* we do **epochs** over the data, i.e. one epoch is one loop over the dataset\n",
    "* at each epoch, we first loop over the data and update the network parameters with respect to the loss gradient\n",
    "* at the end of each epoch, we evaluate the network on the dev dataset\n",
    "* after all epochs are done, we evaluate our network on the test dataset and compare its performance with the performance on dev\n",
    "\n",
    "During training, it is useful to print the following information:\n",
    "* the mean loss over the epoch: it should be decreasing!\n",
    "* the accuracy on the dev set: it should be increasing!\n",
    "* the accuracy on the train set: it shoud be increasing!\n",
    "\n",
    "If you observe a decreasing loss (+increasing accuracy on test data) but decreasing accuracy on dev data, your network is overfitting!\n",
    "\n",
    "Once you have build **and tested** this a simple training loop, you should introduce the following improvements:\n",
    "* instead of evaluating on dev after each loop on the training data, you can also evaluate on dev n times per epoch\n",
    "* shuffle the data before each epoch\n",
    "* instead of memorizing the parameters of the last epoch only, you should have a copy of the parameters that produced the best value on dev data during training and evaluate on test with those instead of the parameters after the last epoch\n",
    "* learning rate decay: if you do not observe improvement on dev, you can try to reduce the step size\n",
    "\n",
    "After you conducted (successful?) experiments, you should write a report with results (in the notebook)."
   ]
  },
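  {
   "cell_type": "markdown",
   "id": "2be9f8aa-3d5a-4e9b-9a0c-5f2d4b6a9c09",
   "metadata": {},
   "source": [
    "Two of these improvements depend on numpy details. The sketch below uses toy arrays as stand-ins for the data and the parameters. Indexing images and labels with the same permutation keeps each pair aligned. Keeping the best parameters requires an explicit `.copy()`, because a plain assignment only creates a second reference to an array that later in-place updates keep modifying.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# toy stand-ins for (images, labels): row k of images starts with the value 2 * k\n",
    "images = np.arange(12, dtype=float).reshape(6, 2)\n",
    "labels = np.arange(6)\n",
    "\n",
    "# shuffle both arrays with one permutation so that pairs stay aligned\n",
    "order = rng.permutation(len(labels))\n",
    "images_shuffled, labels_shuffled = images[order], labels[order]\n",
    "aligned = np.all(images_shuffled[:, 0] == 2 * labels_shuffled)\n",
    "\n",
    "params = np.zeros(3)\n",
    "best_alias = params        # same underlying data as params\n",
    "best_copy = params.copy()  # independent snapshot\n",
    "params += 1.0              # in-place update, as in a gradient step\n",
    "\n",
    "print(aligned)     # True\n",
    "print(best_alias)  # [1. 1. 1.]: followed the update\n",
    "print(best_copy)   # [0. 0. 0.]: kept the snapshot\n",
    "```"
   ]
  },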
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55b5b6ae-43be-4017-8a19-9cbd122302f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# before training, we initialize the parameters of the network\n",
    "zero_init(b)\n",
    "glorot_init(W)\n",
    "\n",
    "n_epochs = 5 # number of epochs\n",
    "step = 0.01 # step size for gradient updates\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "    # TODO\n",
    "    # ...\n",
    "    \n",
    "# Test evaluation\n",
    "# TODO"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
