{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "51807e80-8b23-451a-9e1e-9ebc05705106",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "281ae9d2-0bbb-440f-83f5-4419f95abb56",
   "metadata": {},
   "source": [
    "## The initial library\n",
    "Below the initial library (previous one with new functionnalities), you can look at it if you want to now how it works but it is not necessary yet.\n",
    "Some operation are implemented :\n",
    "* Addition if x is a node you can add using `x + k` (notice that k is not necessarilly a node)\n",
    "* Multiplication (or hadamard product) using `x * y`\n",
    "* Selection using `x[i, :]` (same syntax as numpy) you also canset value in the node using `x[i, j] = value` (where x and value are nodes) Notice that this last functionalitie create a new derivate node at each selection/attribution\n",
    "* Sum using x.sum() (with axis eventually)\n",
    "* Average using x.mean()\n",
    "* ...\n",
    "\n",
    "NB : !!! It is clearly not an efficient implementation !!!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "51bbd660-778e-4999-94ae-d6aa165d76ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Operation(object):\n",
    "    @staticmethod\n",
    "    def forward(*args):\n",
    "        raise NotImplementedError(\"It is an abstract method\")\n",
    "    \n",
    "    def __call__(self, *args):\n",
    "        output_node = self.forward(*args)\n",
    "        output_node.set_func(self)\n",
    "        return output_node\n",
    "        \n",
    "    @staticmethod\n",
    "    def backward(*args):\n",
    "        pass\n",
    "class Addition(Operation):\n",
    "    @staticmethod\n",
    "    def forward(x, y):\n",
    "        output_array = x.value + y.value\n",
    "        output_node = ComputationGraphNode(output_array)\n",
    "        output_node.set_input_nodes(x, y)\n",
    "        return output_node\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(x, y, gradient):\n",
    "        return (gradient, gradient)\n",
    "\n",
    "class Selection(Operation):\n",
    "    @staticmethod\n",
    "    def forward(x, slice):\n",
    "        np_x = x.value\n",
    "\n",
    "        output_array = np_x.__getitem__(slice)\n",
    "        \n",
    "        output_node = ComputationGraphNode(output_array)\n",
    "        output_node.set_input_nodes(x)\n",
    "        output_node.set_func_parameters(slice)\n",
    "\n",
    "        return output_node\n",
    "        \n",
    "    @staticmethod\n",
    "    def backward(x, slice, gradient):\n",
    "        np_x = x.value\n",
    "\n",
    "        cgrad = np_x.copy()\n",
    "        cgrad.fill(0)\n",
    "        cgrad.__setitem__(slice, gradient)\n",
    "        \n",
    "        return cgrad,\n",
    "\n",
    "class Multiplication(Operation):\n",
    "    @staticmethod\n",
    "    def forward(x, y):\n",
    "        np_x = x.value\n",
    "        np_y = y.value\n",
    "\n",
    "        output_array = np_x * np_y \n",
    "        \n",
    "        output_node = ComputationGraphNode(output_array)\n",
    "        output_node.set_input_nodes(x, y)\n",
    "\n",
    "        return output_node\n",
    "        \n",
    "    @staticmethod\n",
    "    def backward(x, y, gradient):\n",
    "        np_x = x.value\n",
    "        np_y = y.value\n",
    "        return (np_y * gradient, np_x * gradient) \n",
    "        \n",
    "class Sum(Operation):\n",
    "    @staticmethod\n",
    "    def forward(x, axis):\n",
    "        np_x = x.value\n",
    "\n",
    "        output_array = np.sum(np_x, axis=axis)\n",
    "        \n",
    "        output_node = ComputationGraphNode(output_array)\n",
    "        output_node.set_input_nodes(x)\n",
    "        output_node.set_func_parameters(axis)\n",
    "\n",
    "        return output_node\n",
    "        \n",
    "    @staticmethod\n",
    "    def backward(x, axis, gradient):\n",
    "        np_g = np.ones(x.value.shape)\n",
    "        \n",
    "        return np_g * gradient, \n",
    "\n",
    "class Mean(Operation):\n",
    "    @staticmethod\n",
    "    def forward(x, axis):\n",
    "        np_x = x.value\n",
    "\n",
    "        output_array = np.mean(np_x, axis=axis)\n",
    "        output_node = ComputationGraphNode(output_array)\n",
    "        output_node.set_input_nodes(x)\n",
    "        output_node.set_func_parameters(axis)\n",
    "\n",
    "        return output_node\n",
    "        \n",
    "    @staticmethod\n",
    "    def backward(x, axis, gradient):\n",
    "        np_g = np.ones(x.value.shape)\n",
    "        divider = np.prod(np_g.shape)\n",
    "        if axis is not None:\n",
    "            divider = np_g.shape[axis]\n",
    "        \n",
    "        return (np_g/divider) * gradient, \n",
    "\n",
    "class Max(Operation):\n",
    "    @staticmethod\n",
    "    def forward(x, axis):\n",
    "        np_x = x.value\n",
    "\n",
    "        output_array = np.max(np_x, axis=axis)\n",
    "        \n",
    "        output_node = ComputationGraphNode(output_array)\n",
    "        output_node.set_input_nodes(x)\n",
    "        output_node.set_func_parameters(axis)\n",
    "\n",
    "        return output_node\n",
    "        \n",
    "    @staticmethod\n",
    "    def backward(x, axis, gradient):\n",
    "        np_g = np.zeros(x.value.shape)\n",
    "        np_g[np.argmax(np_g, axis=axis)] = 1.\n",
    "        \n",
    "        return np_g * gradient,    \n",
    "\n",
    "class Pad(Operation):\n",
    "    @staticmethod\n",
    "    def forward(x, where):\n",
    "        np_x = x.value\n",
    "\n",
    "        output_array = np.pad(np_x, where)\n",
    "        \n",
    "        output_node = ComputationGraphNode(output_array)\n",
    "        output_node.set_input_nodes(x)\n",
    "        output_node.set_func_parameters(where)\n",
    "        return output_node\n",
    "        \n",
    "    @staticmethod\n",
    "    def backward(x, where, gradient):\n",
    "        np_g = np.ones(x.value.shape)\n",
    "        def unpad(x, pad_width):\n",
    "            slices = []\n",
    "            for c in pad_width:\n",
    "                e = None if c[1] == 0 else -c[1]\n",
    "                slices.append(slice(c[0], e))\n",
    "            return x[tuple(slices)]\n",
    "        return np_g * unpad(gradient, where),    \n",
    "\n",
    "class SetSelection(Operation):\n",
    "    @staticmethod\n",
    "    def forward(x, slice, value):\n",
    "        np_x = np.copy(x.value)\n",
    "        if value.value.ndim == 1 and value.value.shape[0] == 1:\n",
    "            np_x.__setitem__(slice, value.value[0])\n",
    "        else:\n",
    "            np_x.__setitem__(slice, value.value)\n",
    "        output_array = np_x\n",
    "        \n",
    "        output_node = ComputationGraphNode(output_array)\n",
    "        output_node.set_input_nodes(x, value)\n",
    "        output_node.set_func_parameters(slice)\n",
    "        return output_node\n",
    "        \n",
    "    @staticmethod\n",
    "    def backward(x, value, slice, gradient):\n",
    "        np_x = x.value\n",
    "        np_value = np.ones(value.value.shape)\n",
    "        cgrad = np_x.copy()\n",
    "        cgrad.fill(1)\n",
    "        cgrad.__setitem__(slice, gradient.__getitem__(slice))\n",
    "        return cgrad, np_value * gradient.__getitem__(slice)\n",
    "import copy\n",
    "class ComputationGraphNode(object):\n",
    "    \n",
    "    def __init__(self, data, require_grad=False):\n",
    "        # we initialise the value of the node and the grad\n",
    "        if(not isinstance(data, np.ndarray)):\n",
    "            if(isinstance(data, int) or isinstance(data, float)):\n",
    "                data = [data]\n",
    "            data = np.array(data)\n",
    "        self.value = data\n",
    "        self.grad = None\n",
    "        \n",
    "        self.require_grad = require_grad\n",
    "        self.func = None\n",
    "        self.input_nodes = None\n",
    "        self.func_parameters = []\n",
    "        \n",
    "    @property\n",
    "    def shape(self):\n",
    "        return self.value.shape\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.value)\n",
    "    \n",
    "    def set_input_nodes(self, *nodes):\n",
    "        self.input_nodes = list(nodes)\n",
    "\n",
    "    def set_func_parameters(self, *func_parameters):\n",
    "        self.func_parameters = list(func_parameters)\n",
    "    \n",
    "    def set_func(self, func):\n",
    "        self.func = func\n",
    "\n",
    "    def zero_grad(self):\n",
    "        if self.grad is not None:\n",
    "            self.grad.fill(0)\n",
    "\n",
    "    def set_gradient(self, gradient):\n",
    "        \"\"\"\n",
    "        Accumulate gradient for this tensor\n",
    "        \"\"\"\n",
    "        if gradient.shape != self.value.shape:\n",
    "            print(gradient.shape, self.value.shape)\n",
    "            raise RuntimeError(\"Invalid gradient dimension\")\n",
    "        if self.grad is None:\n",
    "            self.grad = gradient\n",
    "        else:\n",
    "            self.grad += gradient\n",
    "    \n",
    "    def backward(self, g=None):\n",
    "        if g is None:\n",
    "            g = self.value.copy()\n",
    "            g.fill(1.)\n",
    "        self.set_gradient(g)\n",
    "        if self.func is not None:\n",
    "            grad_list = self.func.backward(*(self.input_nodes + self.func_parameters + [g]))\n",
    "            for input_node, ngrad in zip(self.input_nodes, grad_list):\n",
    "                input_node.backward(ngrad)\n",
    "    \n",
    "    def __add__(self, y):\n",
    "        if not isinstance(y, ComputationGraphNode):\n",
    "            y = ComputationGraphNode(y)\n",
    "        return Addition()(self, y)\n",
    "\n",
    "    def __mul__(self, y):\n",
    "        if not isinstance(y, ComputationGraphNode):\n",
    "            y = ComputationGraphNode(y)\n",
    "        return Multiplication()(self, y)\n",
    "\n",
    "    def sum(self, axis=None):\n",
    "        return Sum()(self, axis)\n",
    "        \n",
    "    def mean(self, axis=None):\n",
    "        return Mean()(self, axis)\n",
    "        \n",
    "    def max(self, axis=None):\n",
    "        return Max()(self, axis) \n",
    "\n",
    "    def pad(self, where=None):\n",
    "        return Pad()(self, where)\n",
    "    \n",
    "    def __getitem__(self, slice):\n",
    "        return Selection()(self, slice)\n",
    "\n",
    "    def __str__(self):\n",
    "        return self.value.__str__()\n",
    "\n",
    "    def __repr__(self):\n",
    "        return self.value.__str__()\n",
    "        \n",
    "\n",
    "    def __setitem__(self, slice, value):\n",
    "        # Be carrefull this is clearly unneficient, it will works for the example but not for too large input\n",
    "        new_node = SetSelection()(self, slice, value)\n",
    "        intermediate_node = copy.copy(self)\n",
    "        self.value = new_node.value\n",
    "        self.grad = new_node.grad\n",
    "        self.require_grad = new_node.require_grad\n",
    "        self.func = new_node.func\n",
    "        self.input_nodes = [intermediate_node, value]\n",
    "        self.func_parameters = new_node.func_parameters\n",
    "        \n",
    "class Parameter(ComputationGraphNode):\n",
    "    def __init__(self, data, name=\"default\"):\n",
    "        super().__init__(data, require_grad=True)\n",
    "        self.name  = name\n",
    "\n",
    "    def backward(self, g=None):\n",
    "        if g is not None:\n",
    "            self.set_gradient(g)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6451abf5-3e64-4725-a9c7-b98702b19106",
   "metadata": {},
   "source": [
    "## Lab 5: Implement Convolutional Neural Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0ce7c453-1748-4dc6-a9d1-c8dd31ced0f3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaEAAAGdCAYAAAC7EMwUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAvbElEQVR4nO3df3Dc9X3n8dfuanf1a7X+rR/YFiLYBDBwAROwQ8CQ4kO5MhCncyRcc+baciH86HmcDK3hZtB0pjZDBw+5cXHbJEehgcLNFCg3EMAdsN3UcWtTOHyGEhMMCCwhW7Z+S/vze39w1iBsw/ttS3ws+fmY2Rm0++atz3e/3923v9rd18aiKIoEAEAA8dALAACcuhhCAIBgGEIAgGAYQgCAYBhCAIBgGEIAgGAYQgCAYBhCAIBgKkIv4NPK5bL27dunTCajWCwWejkAAKcoitTf36+mpibF4599rnPSDaF9+/Zp3rx5oZcBADhB7e3tmjt37mfWTNgQevDBB/Vnf/Zn6ujo0LnnnqsHHnhAX//61z/3/8tkMpKkm1YuVCqVMP2ukRH7GdNwzpdSVJa9viLpO3NLJ5Pm2mLJ17tUsN13khQve3v76isrp5trF5x9oav3BV9ZYq6dO/d0V+90ZZW5Nllh35cf857l24/DUqno7G1X4dxOTyhYPp9z9U4k7K8m7H2v3dX7bx59zFX/mz1vmmt7ew64eudyQ+bayHGcSFLic85SPqmq2l5bKpX1b7vbR5/PP8uEDKEnnnhCq1at0oMPPqivfe1r+su//Eu1trbqjTfe0Pz58z/z/z38J7hUKmEeQmXHk2jRGZVXdtQnU74nllTSMSi8Qyjm6e17abDofAJNpexPXJWVaVfvmpoac22t4QExdi0MoU+b2CGUcvX2DKGamlpX71TKt5aKCvtTaSJhf2xK+tw/Z32SdwjFHfeh5/4+zPKSyoS8MWH9+vX6/d//ff3BH/yBzj77bD3wwAOaN2+eNm7cOBG/DgAwSY37EMrn83rllVe0fPnyMdcvX75c27ZtO6I+l8upr69vzAUAcGoY9yF04MABlUol1dfXj7m+vr5enZ2dR9SvW7dO2Wx29MKbEgDg1DFhnxP69N8Coyg66t8H16xZo97e3tFLe7vvBUQAwOQ17m9MmDVrlhKJxBFnPV1dXUecHUlSOp1WOu17MRoAMDWM+5lQKpXSRRddpE2bNo25ftOmTVq6dOl4/zoAwCQ2IW/RXr16tb73ve9p8eLFWrJkif7qr/5K77//vm655ZaJ+HUAgElqQobQDTfcoO7ubv3Jn/yJOjo6tGjRIj333HNqbm6eiF8HAJikYlHk/PTmBOvr61M2m9X3/rM9MaFUsvfPlZ2bG9k/+JescCxEUtLxx9CY7B+c/Hgt9g/nxVXt6l1ff46r/opvfNtcu3DhIlfvRMLz4Unfvvc8NIrFgqt32Xkcej6wWCz6PqxaKNjXnnKkfEhSZZX9OOzt6Xb1VmHAXJqsmeVq/fqbe1z1zzzztLn25ReecfUuy/68EnN+oDTpOA2pdXzWu1Qs67VX3lNvb6/q6uo+s5YUbQBAMAwhAEAwDCEAQDAMIQBAMAwhAEAwDCEAQDAMIQBAMAwhAEAwDCEAQDAMIQBAMBOSHTce+gYiJZO2aJNY3B6BUi6XXetIJO2RGYlyztU7V8yba9OVvl117nmXmmsXL77e1Xv+/C+76uOOuJxCwX6fSFKx6Ng/Cd99GEX2Y6XkyY6SFJV99cMDQ+bavoP7Xb1zg4Pm2oP7u1y9z77QfhxW1mZdvbsPdphrp1XPcPX+d+f6jvH337vQXPtPW/7B1XtosNdcm0raos4Oiyfs9YmEI2oqOvK74465BntXAADGF0MIABAMQwgAEAxDCAAQDEMIABAMQwgAEAxDCAAQDEMIABAMQwgAEAxDCAAQDEMIABDMSZsdVxqpULxkyzXyZBrFKwqudaQcUUwVcV/vWNKe
r3TRxde5en/z399hrq1KpV29C3l71pgkFQr2+6VYLPp6F+35buWYL1crEbfvn2Sq0tXbmx3nyfiSM39v1yv/aq598H/+3NX76suXmmv/cM1/d/WunlZvru09dMDVe/qsBlf9N6/5hrn23b1vuXo/9jc/M9fGfYe4Egn7eUgU2ZtHjpg5zoQAAMEwhAAAwTCEAADBMIQAAMEwhAAAwTCEAADBMIQAAMEwhAAAwTCEAADBMIQAAMGctLE906rqlEralhdP2mNhYhW+yJnqOnssTHV1xtV70Xn/0Vx72dducfWOFUfMtf3D/a7eyWTSVR85Mjz6PnzH1Xv/vg/Ntd7Ynmlz7NEts884x9W77Mk1kZRI2O/z2c0LXL3PGLA/fmpqnnL1/j+73jTX7tq22dX7rEu+bq796N23Xb3rps1y1SccsUr/7Y7/5updLtr3zxNP/LWrd4UjOqxcTjlq7c+bnAkBAIJhCAEAgmEIAQCCYQgBAIJhCAEAgmEIAQCCYQgBAIJhCAEAgmEIAQCCYQgBAIJhCAEAgjlps+NOr29QZdqYlxW3ZyvlY32udSSr7Llq02fMdPU+84yl5tp4Ke/q3dvxnrk2WedbdyqVdtUXBg6aax/66cOu3r/619fNtYvP/pKr943/ZaW5Nm6P4JIkVaSrXPWxmP0XePfPv/vqJebadfc/4Oq9f/8Bc23LWS2u3h/+5g1z7btvvubq/aXzLnLV53L2x6fzUNEdt/+hubZpbr2r9+aXnzHX9vTacxqLRfsaOBMCAAQz7kOora1NsVhszKWhwZ5GDAA4dUzIn+POPfdc/cM//MPoz56YcwDAqWNChlBFRQVnPwCAzzUhrwnt2bNHTU1Namlp0Xe+8x29886xv6gsl8upr69vzAUAcGoY9yF0ySWX6JFHHtELL7ygn/zkJ+rs7NTSpUvV3d191Pp169Ypm82OXubNmzfeSwIAnKTGfQi1trbq29/+ts477zz91m/9lp599llJ0sMPH/2tt2vWrFFvb+/opb29fbyXBAA4SU3454Rqamp03nnnac+ePUe9PZ1OK532fa4BADA1TPjnhHK5nN588001NjZO9K8CAEwy4z6EfvSjH2nLli3au3ev/vmf/1m/8zu/o76+Pq1caf/0OQDg1DDuf4774IMP9N3vflcHDhzQ7Nmzdemll2r79u1qbm529TltVpOqKm1/pusf7DH3PTjiWob6e0rm2tNOO8/Ve9Y0+31SHOx19S6MDJlrU3WzXL3jMd+/Xd77t/9rrt38q52u3vsO9Jhrz/6SLxZmZpP9TTLOu0QlT66JpELZ/lm7/QeO/iagY+nrtR9bc5tPd/VuPsN+jJfLZVfvygH7O2mHh+2PB0kq5Ydd9YlUrbl2sN/3DuBC3r723/3OTa7eXzn/K+ban/60zVxbKBQlHftd0Z807kPo8ccfH++WAIApiuw4AEAwDCEAQDAMIQBAMAwhAEAwDCEAQDAMIQBAMAwhAEAwDCEAQDAMIQBAMAwhAEAwE/5VDscrKqcVlW3Zcb0D9hyu9zsHXevoGS6Ya884M+vqnY7b/w0wNOTLjuvr3m8vrqh29U461i1JuRF7DteSryxy9U5npptrf/uaq1y9q2rteWBR0X6cSNKhvn5X/XPPbDLXvv9rW2bXYf09B+3F6UpX7+/85xvNtRd/1Z5jJkkz6+eaa0vxlKt330HHfSKprt7+GBoeHHD1rq6pMdcePNDl6j1tuj03Mpu15wDm83lzLWdCAIBgGEIAgGAYQgCAYBhCAIBgGEIAgGAYQgCAYBhCAIBgGEIAgGAYQgCAYBhCAIBgTtrYnv39/arM50y1b31gj6r4cH+3ax3plD2m5Kyzznf19kS9xCsmLnakd9AesSFJFfJF1JzW1GSuXbX6D129+wfsMUyVmTpX76hcNtcmnPvnX//ln1z1/+tv/qe5tvOjDlfvOkcszIFDh1y9i2X7sXLxkq+6elc7YpWmOyJ+JGmo37ed6bqZ
5trCyJCrdzHhOFeI7MesJNVlpplrzzr7CnPtyPCwpCdMtZwJAQCCYQgBAIJhCAEAgmEIAQCCYQgBAIJhCAEAgmEIAQCCYQgBAIJhCAEAgmEIAQCCYQgBAII5abPjambOVlWlLbct023Pg5sf983dCxcvMdd+6fQzXL1zQ3324sjVWpnp08y17+xtd/Wed/rprvrq2fbcrphzQ+ODA+baSuPxNLqWeMJcW3DkAErSB++/76o/cOiAubYvN+zqXU7EzLXpmmpX73Rmurk2ly+5eldX2vdPQ2Ojq3e87FvLcJ89qzFW9GU15h1Zc/EK+30iSVFk387T5jWba4eG7JmOnAkBAIJhCAEAgmEIAQCCYQgBAIJhCAEAgmEIAQCCYQgBAIJhCAEAgmEIAQCCYQgBAIJhCAEAgjlps+POvnCJampqTLULL7jI3DdW9GVCzZlpz74qjthzzLzKzmyyWQ3zzLVvvfmWq/dQvuyqnz59jrm2XPDlak1PVdmLHTlZ3rWMjPj2z/zT57vqv3zOInPtb36zx9V70JHz5ZWtqzPX9vT4Hj9VDdPMtQlnZmRZvgy2YrFork1nsq7enuDIyJlLt/f9DnNtfsSeSZgfGTHXciYEAAjGPYS2bt2qa6+9Vk1NTYrFYnr66afH3B5Fkdra2tTU1KSqqiotW7ZMu3fvHq/1AgCmEPcQGhwc1AUXXKANGzYc9fb77rtP69ev14YNG7Rjxw41NDTo6quvVn9//wkvFgAwtbhfE2ptbVVra+tRb4uiSA888IDuvvturVixQpL08MMPq76+Xo899pi+//3vn9hqAQBTyri+JrR37151dnZq+fLlo9el02ldccUV2rZt21H/n1wup76+vjEXAMCpYVyHUGdnpySpvr5+zPX19fWjt33aunXrlM1mRy/z5tnf1QUAmNwm5N1xsdjYrwuOouiI6w5bs2aNent7Ry/t7b6vmgYATF7j+jmhhoYGSR+fETV+4jvdu7q6jjg7OiydTiudTo/nMgAAk8S4ngm1tLSooaFBmzZtGr0un89ry5YtWrp06Xj+KgDAFOA+ExoYGNDbb789+vPevXv12muvacaMGZo/f75WrVqltWvXasGCBVqwYIHWrl2r6upq3XjjjeO6cADA5OceQjt37tSVV145+vPq1aslSStXrtRf//Vf684779Tw8LBuvfVWHTp0SJdccolefPFFZTIZ1+9pOWOh+f8pO+JY3NE6Qz3m0sKwr3chb4/YKOXtMRiSVFU7zV4c90WUvPPrf3PVz114tmMtR3/t8Jjl6UpzrTcSKFa2R7HknZ+DS1X4/ggxf+5p5tpD+7tcvSsdfw6vqEi5end++J65tsP5evDM6bZYL0k6dKDb1TuZ9D01zp3RYK6trvU9F8Zlj8k6eNC3nYe6D5hrE7I/zyaUM9e6h9CyZcsURcfOMorFYmpra1NbW5u3NQDgFEN2HAAgGIYQACAYhhAAIBiGEAAgGIYQACAYhhAAIBiGEAAgGIYQACAYhhAAIBiGEAAgmHH9KofxlE6nlE7bcqoKhYK5b7kw7FpHVLbnNpWKvmyy3NCgubact2cxffw/2EtPa2pytX7j129/ftEnjPT3mmsrqqpdvUsle76byseOmzqamKM8P+Q7rvI99swuSYoV7fv/WN/ddSzVFfangXLZ/liTpHLOfoznB33fqpwbsd/nqWTS1Xvg4NG/hPNY8rnTzbXVNb5jPFlZZS/27XrVz6w111ZU2DMmBwfteYScCQEAgmEIAQCCYQgBAIJhCAEAgmEIAQCCYQgBAIJhCAEAgmEIAQCCYQgBAIJhCAEAgjlpY3sGBwcVj9tm5NDQkL1x3hHzIimVH7G3dsTwSNLIgD2mpJizr0OS+nsOmWszad9hkIx88UTdnfvMtTMbTnP1lvEYkaRy3rfueNyegVJ0xNNIUrG/x1WfTdm3c3rGFwvz+v/dba4tOeOjLjpznrl2oPNdV++hHvuxkqmtdPVO5Xz1Qwfs
MT/ptCOGR1JNJmOuPfPMZldvT8RTKmWLUZOkvr5+cy1nQgCAYBhCAIBgGEIAgGAYQgCAYBhCAIBgGEIAgGAYQgCAYBhCAIBgGEIAgGAYQgCAYBhCAIBgTtrsuA8//FA1NTWm2uERe65aptqXCTXDMaZzngw7Sflhe3084dtVMU/uWWSvlaQ64345rOPd35hra2vrXL3jVfa1fPCWPSNNkqbNmmOura2153tJ0szpWVf99Ep7blf3/i5X7wHHcWhfxccSZfuxtfNX/+TqXS7ZH/dfXuDLVKvK+o7Dob4D5tqKJnueniTNmj3DXFtd5Xt+c3HkzClhP1I4EwIABMMQAgAEwxACAATDEAIABMMQAgAEwxACAATDEAIABMMQAgAEwxACAATDEAIABHPSxvYMDg6aa2OOOIlCsexaRy6ZsBdHrtaqqq4118Yqkq7eUdm+nfmcPf5EkjJZe4yIJB3c9565tnDWua7e1VXV5tpdr/2rq/c//sv/Mdd+bfFXXL0batKu+h077Gvv3Nfp6p2O2/8tWpnwHYe737Lv+8aDRVfvr1621F5czLt65x1RYJI00tttrh3u2OvqXVh4lrk2XmN/TplI8Zj9eZMzIQBAMAwhAEAw7iG0detWXXvttWpqalIsFtPTTz895vabbrpJsVhszOXSSy8dr/UCAKYQ9xAaHBzUBRdcoA0bNhyz5pprrlFHR8fo5bnnnjuhRQIApib3GxNaW1vV2tr6mTXpdFoNDQ3HvSgAwKlhQl4T2rx5s+bMmaOFCxfq5ptvVlfXsb9kK5fLqa+vb8wFAHBqGPch1NraqkcffVQvvfSS7r//fu3YsUNXXXWVcrncUevXrVunbDY7epk3z/etgwCAyWvcPyd0ww03jP73okWLtHjxYjU3N+vZZ5/VihUrjqhfs2aNVq9ePfpzX18fgwgAThET/mHVxsZGNTc3a8+ePUe9PZ1OK532fXAPADA1TPjnhLq7u9Xe3q7GxsaJ/lUAgEnGfSY0MDCgt99+e/TnvXv36rXXXtOMGTM0Y8YMtbW16dvf/rYaGxv17rvv6q677tKsWbP0rW99a1wXDgCY/NxDaOfOnbryyitHfz78es7KlSu1ceNG7dq1S4888oh6enrU2NioK6+8Uk888YQymYzr91RXV6u62pYLlk6nzH2rKn1/+qupqrT3duY2DXXac7WiYsHVO5aw3yfxyJen570Pe7oGzLWD3b7cs1jZnje26PzzXb2ffO5lc+3/+MnPXb1/Z+l5rvo5mWnm2rl10129Bwr2XLWEI2dOktIxe9bcwpm+14Ir0/bnlN4ee7abJJWc2XGFgr1+6MCHrt4Huvaba2uzvn2frHBkY04Q9xBatmyZoujYSZ0vvPDCCS0IAHDqIDsOABAMQwgAEAxDCAAQDEMIABAMQwgAEAxDCAAQDEMIABAMQwgAEAxDCAAQDEMIABDMhH+Vw/FKJBJKJGy5RpVJ+2bUOHPP6hyZdzVNTa7eBxyZd70fvOPqnayw//uir/sjV+9fv/G6qz4qDplrD334G1fvkQH7N/HOnF3v6n3r7/+uufZP7v8LV+/nX33784s+4bfO/JK5dvFpZ7h6d4wMm2srkvZjVpJqUrb8R0lKFO3rkKRU+ehflHk0g329rt6Dh3xZcwmV7L0dtZI07MhT7K+f4+qdzdaZa4sl+7pzeXuWHmdCAIBgGEIAgGAYQgCAYBhCAIBgGEIAgGAYQgCAYBhCAIBgGEIAgGAYQgCAYBhCAIBgTtrYnlw+r4pk0lQbi8XMfYuRbx2e+ijmnOlxe31uaNDV+sN3/81c+6tfbXP17u3vd9W3NM8315YKBd9a9neYa3NDvnWfdWazuXbVf/1Prt5/9dATrvpNv95jrj1/VqOr91mzT3NU+47xtCPmp6LS17t2+jRzbWEw7+odlcuu+j7Hcaiiby39779lX0e9b9/39vWYa/P5orl2YGDAXMuZEAAgGIYQACAYhhAAIBiGEAAgGIYQ
ACAYhhAAIBiGEAAgGIYQACAYhhAAIBiGEAAgGIYQACCYkzY7Lp8rqCJhyxErFErmvoNDI6519PTaM5D2H+x19a4p2/PgEqm0q3eiZoa5dubMWa7einy5WvEK+2FWLPhytXoOdZtrp0VNrt6d7e+aa8+aX+/qfceN/8FV/8T/3mSuffPgflfv7uE+c21uxJftFzfmP0rSsuv/o6v3jNMXmms/2NXl6l0u23PSJKlpwSJzbdc7b7p6D3XvM9eO9B5y9e6XPdsvl8uZawcH7c9tnAkBAIJhCAEAgmEIAQCCYQgBAIJhCAEAgmEIAQCCYQgBAIJhCAEAgmEIAQCCYQgBAII5aWN7yuVI5XJkqo0cMTLlmK3nYfG4vT7tS9aRklWOUntkhiTNnne6uTY3bI/YkKTa9t+46pPVGXNtPu+L7amZNtNcW1WXdfUuOyKEDnX7YmFOm3eaq/4H37NH2ux8c6+r9zv7Osy1Xc5oqpbTzzDXXvy1xa7eFQn7YzNZZT8GJanjN7td9dUZ+7HV+OWvuHoXP3rHXDvQ9aGr9/74dHNtbmTYXDs0NGSu5UwIABCMawitW7dOF198sTKZjObMmaPrr79eb7311piaKIrU1tampqYmVVVVadmyZdq92/evCgDAqcE1hLZs2aLbbrtN27dv16ZNm1QsFrV8+fIxian33Xef1q9frw0bNmjHjh1qaGjQ1Vdfrf7+/nFfPABgcnO9JvT888+P+fmhhx7SnDlz9Morr+jyyy9XFEV64IEHdPfdd2vFihWSpIcfflj19fV67LHH9P3vf3/8Vg4AmPRO6DWh3t6PX6ScMePj767Zu3evOjs7tXz58tGadDqtK664Qtu2bTtqj1wup76+vjEXAMCp4biHUBRFWr16tS677DItWvTxFzp1dnZKkurrx37BV319/ehtn7Zu3Tpls9nRy7x58453SQCASea4h9Dtt9+u119/XX/7t397xG2xWGzMz1EUHXHdYWvWrFFvb+/opb29/XiXBACYZI7rc0J33HGHnnnmGW3dulVz584dvb6hoUHSx2dEjY2No9d3dXUdcXZ0WDqdVtr9ARsAwFTgOhOKoki33367nnzySb300ktqaWkZc3tLS4saGhq0adOm0evy+by2bNmipUuXjs+KAQBThutM6LbbbtNjjz2mv//7v1cmkxl9nSebzaqqqkqxWEyrVq3S2rVrtWDBAi1YsEBr165VdXW1brzxxgnZAADA5OUaQhs3bpQkLVu2bMz1Dz30kG666SZJ0p133qnh4WHdeuutOnTokC655BK9+OKLymR8sRkAgKnPNYSi6POzmmKxmNra2tTW1na8axr9XdZMuCg6+psejsaaR3dYdU21uXZats7VuyIqmmuj4oCrd8ywrw6bPfd0V2+V7euWpKGRgrk2XpF09Z41355NViz6cumKjuy4eMr3uubg8IirPp2xH1uXX3q+q/fXdKG5Nh63P9YkKdffba7Nlva7evfss2dGlhz7UpIOHfCtpVQumWvPv/ybrt5Dg/b7UCXfYzOS/T6Mxe2v3nhqyY4DAATDEAIABMMQAgAEwxACAATDEAIABMMQAgAEwxACAATDEAIABMMQAgAEwxACAARzXF/l8EUoFIoqFGwRFBUV9s2YPWeWax0zpk8z18Zkj6eRpFjBHq1TKNvjNSRJkT1GJOaolaSq2lpXfbHYY65N1WVdvauzM8y1ffv3uXqn0vbIplLOFwujuC+eqL+/3966MOTqXRw4ZK5NVla6evf1HrT3dkYCVToO24Ee+zokKV/0xd94jq2Brg9cvbP188218ZRv/yRiCUe1/fkqWWHvy5kQACAYhhAAIBiGEAAgGIYQACAYhhAAIBiGEAAgGIYQACAYhhAAIBiGEAAgGIYQACAYhhAAIJiTNjuuIlmhiqQtX2vu3NPMfevrfdlxxcKwuXZkcMTVu1C2ZzGN5H3ZcVHRkWOXH3T1Lg7b7xNJSqXS5toqZzaZivbMtnjM
l02mkv0+HBny3YcVjlw6SXrv/XfMtY2Njb61pOxPA6lqX25gbWQ/bhMVvjy94YEec+3Brg5Xb88xK0npiipz7dChA67es+rnmmsTdfYsRUmqzNv3fSFuf6wVCvbHDmdCAIBgGEIAgGAYQgCAYBhCAIBgGEIAgGAYQgCAYBhCAIBgGEIAgGAYQgCAYBhCAIBgTtrYnjPPPEO1tRlTbTZrq5OkQt4XOZN3xLFEkT2GR5IUs/8boOBs3d+fM9dWFH3Nk8mEq75csm+nK25IUr7/kLk2VZHy9S7Y78NC3h5pIklVdTNd9UPDQ+baomPdkjQta19LPOmLsxkY6DfXJtL26BtJGhkYMNfW1Phikg595DsOMzNmm2urMllX7+KI/TkomZnm6l1R4Yjt8RzjjqcUzoQAAMEwhAAAwTCEAADBMIQAAMEwhAAAwTCEAADBMIQAAMEwhAAAwTCEAADBMIQAAMEwhAAAwZy02XE11VWqrbFlSRUL9pynQm7EtY5S3t67HHO1VrFkD1iKVfvyprLJSnNtz0dFV+940peTlhu2Z3wV8wddvYf7e821NXW++9Cz75MVvjy9qlrfWs5YcLa5dqSv29U74ThWvLl0I47HW9e+91y988WSuTZjzKE8bO7801318aL9MVEatufpSVIxP81T7OqdSNXaaxP2Y9xTy5kQACAY1xBat26dLr74YmUyGc2ZM0fXX3+93nrrrTE1N910k2Kx2JjLpZdeOq6LBgBMDa4htGXLFt12223avn27Nm3apGKxqOXLl2twcGzU+DXXXKOOjo7Ry3PPPTeuiwYATA2u14Sef/75MT8/9NBDmjNnjl555RVdfvnlo9en02k1NDSMzwoBAFPWCb0m1Nv78YvCM2bMGHP95s2bNWfOHC1cuFA333yzurq6jtkjl8upr69vzAUAcGo47iEURZFWr16tyy67TIsWLRq9vrW1VY8++qheeukl3X///dqxY4euuuoq5XJHf1fNunXrlM1mRy/z5s073iUBACaZ436L9u23367XX39dv/zlL8dcf8MNN4z+96JFi7R48WI1Nzfr2Wef1YoVK47os2bNGq1evXr0576+PgYRAJwijmsI3XHHHXrmmWe0detWzZ079zNrGxsb1dzcrD179hz19nQ6rXTa9731AICpwTWEoijSHXfcoaeeekqbN29WS0vL5/4/3d3dam9vV2Nj43EvEgAwNbleE7rtttv085//XI899pgymYw6OzvV2dmp4eFhSdLAwIB+9KMf6Ve/+pXeffddbd68Wddee61mzZqlb33rWxOyAQCAyct1JrRx40ZJ0rJly8Zc/9BDD+mmm25SIpHQrl279Mgjj6inp0eNjY268sor9cQTTyiT8cVmAACmPvef4z5LVVWVXnjhhRNa0GG53LBSKVv+UDxuP6Er2+PaPpay52rF5WuejNsz29IzZ7t6x1Q211Y4c8969vneVJnM29dSHPLlag0PD5lr+/vsOXOSVJepMdeWR4Zdvfe3v+2qnzOn3lx7oGi/TySp/9ABc22x4MsmK0X2YyU37LsPBwfsx0rZkdMoSRWy59JJUnV2urk2XTfL1buctL9mHkvZ8jYPq6xMOqrt92E5sj+3kR0HAAiGIQQACIYhBAAIhiEEAAiGIQQACIYhBAAIhiEEAAiGIQQACIYhBAAIhiEEAAjmuL9PaKKVc8MqJ21xMmXF7I1Tvq+NSFXaY3tKxYKrd9IRl5OoSLl6F0v22JHKuqyrd23et53lmP3fOr05X++KtL13IeeLbunvHzDXRpE9mkiS9n/4hqv+wAfV5tqycy11tbXm2upaXwZkrMYeC5Mb9MUqTZ9u7101fY6rd11Ds6u+2N9trs00fv63D3zStKbP/rqcT/KmkiWTjlgyx+O4XCK2BwAwCTCEAADBMIQAAMEwhAAAwTCEAADBMIQAAMEwhAAAwTCEAADBMIQAAMEwhAAAwTCEAADBnLzZcQOHVIryptoobs9Vi5d9
uVpK2TPYFPPdnZ6VxMr2LCZJijny9BKpKlfv6hkzfWtJ2/tHcd992N3xgbm2FNmz+iQpP2zPjisMHHT1TjtyAyWp46NOc+2H+31raaqvN9c2z7fnmElSotKeeVeR9B2HKuXMpZVpe0aaJM1o8uW7RbEzzLXptD3zTpISCftjIp709R4p2J9XYjH7c4qnljMhAEAwDCEAQDAMIQBAMAwhAEAwDCEAQDAMIQBAMAwhAEAwDCEAQDAMIQBAMAwhAEAwJ21sz1B3p+LDtsiPeFXG3DcV883dfDky1yaSvkigCkccR+T+94J9LYkKe+yRJBUTBVd9Za19/0ybe7qrtxyRQN0fvOdqnXDEQXV22GN1JKnSEWsiSZnqtLl2zow6V++U41lgeNgWpXVYomCvTzmjdWoy0821ueFhV+/c8JCrvq6x2VxbLIy4evf22+OjKhwxSZJUWWk/rmbOrDXXplL2WCrOhAAAwTCEAADBMIQAAMEwhAAAwTCEAADBMIQAAMEwhAAAwTCEAADBMIQAAMEwhAAAwTCEAADBnLTZcaViWaViyVacs+c85Yb6XOsYdGRlxZ15YNOnZ821iYQ9i0mSyiXjfSepHNlrJSnuXEsiZc8EKzjWLUmZGTPNtVFkzwGUpL4uex5cVZ19X0pSynmsDBywH4dVaV8WYCplr6+ssmeNSVKxYM8Z9GQpSlJltT3LLBH33d+FoX5Xfblsz2qsSNnzDiWpVHLk76WSrt61tfasudpq++M4FhXNtZwJAQCCcQ2hjRs36vzzz1ddXZ3q6uq0ZMkS/eIXvxi9PYoitbW1qampSVVVVVq2bJl279497osGAEwNriE0d+5c3Xvvvdq5c6d27typq666Stddd93ooLnvvvu0fv16bdiwQTt27FBDQ4Ouvvpq9ff7Tm0BAKcG1xC69tpr9c1vflMLFy7UwoUL9ad/+qeqra3V9u3bFUWRHnjgAd19991asWKFFi1apIcfflhDQ0N67LHHJmr9AIBJ7LhfEyqVSnr88cc1ODioJUuWaO/evers7NTy5ctHa9LptK644gpt27btmH1yuZz6+vrGXAAApwb3ENq1a5dqa2uVTqd1yy236KmnntI555yjzs6P30lUX18/pr6+vn70tqNZt26dstns6GXevHneJQEAJin3EDrrrLP02muvafv27frBD36glStX6o033hi9Pfapt55GUXTEdZ+0Zs0a9fb2jl7a29u9SwIATFLuzwmlUimdeeaZkqTFixdrx44d+vGPf6w/+qM/kiR1dnaqsbFxtL6rq+uIs6NPSqfTSqd9nz0AAEwNJ/w5oSiKlMvl1NLSooaGBm3atGn0tnw+ry1btmjp0qUn+msAAFOQ60zorrvuUmtrq+bNm6f+/n49/vjj2rx5s55//nnFYjGtWrVKa9eu1YIFC7RgwQKtXbtW1dXVuvHGGydq/QCAScw1hD766CN973vfU0dHh7LZrM4//3w9//zzuvrqqyVJd955p4aHh3Xrrbfq0KFDuuSSS/Tiiy8qk8m4Fxb9/4ut2B7Hks/bIzAkKS973EfSeV6ZH7bHpXhje+SIhSnlfPdJIu2LHYnF7YeZN1onFrPf6dWZOlfv0kjOXNvQfJar9/ChLld90RE31d/f4+qdnTbDvo68PYZHkpJJ+zFeLPn2vRz7PlU7zdc65o3JssfUVFbXuHr39gzb11EecfVOxO33oeceGRm2P3ZcQ+hnP/vZZ94ei8XU1tamtrY2T1sAwCmK7DgAQDAMIQBAMAwhAEAwDCEAQDAMIQBAMAwhAEAwDCEAQDAMIQBAMAwhAEAw7hTtiXY4tmVwyB5VUZG0R4kUCvY4G0kakj2OwxvbE4/bgzBKvtau2J6iN7an4FtNLJ401w4MDLh6Fwv2fV/M+SJNBoeGzLVDw/bjVZJGHJFAkjTs2Ee5gv2Y9fYuRb44m2LZUZzwPYCGhu37s1zh2z+lCvu+l6S447gtO+KGJN9joqLCGe/leGaJ
Rfba/v+/ZksMVyzyhnVNsA8++IAvtgOAKaC9vV1z5879zJqTbgiVy2Xt27dPmUxmzJfh9fX1ad68eWpvb1ddnS+IcjJhO6eOU2EbJbZzqhmP7YyiSP39/WpqalL8c0JST7o/x8Xj8c+cnHV1dVP6ADiM7Zw6ToVtlNjOqeZEtzObzZrqeGMCACAYhhAAIJhJM4TS6bTuuecepdPp0EuZUGzn1HEqbKPEdk41X/R2nnRvTAAAnDomzZkQAGDqYQgBAIJhCAEAgmEIAQCCmTRD6MEHH1RLS4sqKyt10UUX6R//8R9DL2lctbW1KRaLjbk0NDSEXtYJ2bp1q6699lo1NTUpFovp6aefHnN7FEVqa2tTU1OTqqqqtGzZMu3evTvMYk/A523nTTfddMS+vfTSS8Ms9jitW7dOF198sTKZjObMmaPrr79eb7311piaqbA/Lds5Ffbnxo0bdf75549+IHXJkiX6xS9+MXr7F7kvJ8UQeuKJJ7Rq1SrdfffdevXVV/X1r39dra2tev/990MvbVyde+656ujoGL3s2rUr9JJOyODgoC644AJt2LDhqLffd999Wr9+vTZs2KAdO3aooaFBV199tfr7+7/glZ6Yz9tOSbrmmmvG7NvnnnvuC1zhiduyZYtuu+02bd++XZs2bVKxWNTy5cs1ODg4WjMV9qdlO6XJvz/nzp2re++9Vzt37tTOnTt11VVX6brrrhsdNF/ovowmga9+9avRLbfcMua6L3/5y9Ef//EfB1rR+LvnnnuiCy64IPQyJoyk6Kmnnhr9uVwuRw0NDdG99947et3IyEiUzWajv/iLvwiwwvHx6e2MoihauXJldN111wVZz0Tp6uqKJEVbtmyJomjq7s9Pb2cUTc39GUVRNH369OinP/3pF74vT/ozoXw+r1deeUXLly8fc/3y5cu1bdu2QKuaGHv27FFTU5NaWlr0ne98R++8807oJU2YvXv3qrOzc8x+TafTuuKKK6bcfpWkzZs3a86cOVq4cKFuvvlmdXV1hV7SCent7ZUkzZgxQ9LU3Z+f3s7DptL+LJVKevzxxzU4OKglS5Z84fvypB9CBw4cUKlUUn19/Zjr6+vr1dnZGWhV4++SSy7RI488ohdeeEE/+clP1NnZqaVLl6q7uzv00ibE4X031ferJLW2turRRx/VSy+9pPvvv187duzQVVddpVzO951CJ4soirR69WpddtllWrRokaSpuT+Ptp3S1Nmfu3btUm1trdLptG655RY99dRTOuecc77wfXnSpWgfS+xTX9IWRdER101mra2to/993nnnacmSJfrSl76khx9+WKtXrw64sok11ferJN1www2j/71o0SItXrxYzc3NevbZZ7VixYqAKzs+t99+u15//XX98pe/POK2qbQ/j7WdU2V/nnXWWXrttdfU09Ojv/u7v9PKlSu1ZcuW0du/qH150p8JzZo1S4lE4ogJ3NXVdcSknkpqamp03nnnac+ePaGXMiEOv/PvVNuvktTY2Kjm5uZJuW/vuOMOPfPMM3r55ZfHfOXKVNufx9rOo5ms+zOVSunMM8/U4sWLtW7dOl1wwQX68Y9//IXvy5N+CKVSKV100UXatGnTmOs3bdqkpUuXBlrVxMvlcnrzzTfV2NgYeikToqWlRQ0NDWP2az6f15YtW6b0fpWk7u5utbe3T6p9G0WRbr/9dj355JN66aWX1NLSMub2qbI/P287j2Yy7s+jiaJIuVzui9+X4/5Whwnw+OOPR8lkMvrZz34WvfHGG9GqVauimpqa6N133w29tHHzwx/+MNq8eXP0zjvvRNu3b49++7d/O8pkMpN6G/v7+6NXX301evXVVyNJ0fr166NXX301eu+996IoiqJ77703ymaz0ZNPPhnt2rUr+u53vxs1NjZGfX19gVfu81nb2d/fH/3whz+Mtm3bFu3duzd6+eWXoyVLlkSnnXbapNrOH/zgB1E2m402b94cdXR0jF6GhoZGa6bC/vy8
7Zwq+3PNmjXR1q1bo71790avv/56dNddd0XxeDx68cUXoyj6YvflpBhCURRFf/7nfx41NzdHqVQquvDCC8e8ZXIquOGGG6LGxsYomUxGTU1N0YoVK6Ldu3eHXtYJefnllyNJR1xWrlwZRdHHb+u95557ooaGhiidTkeXX355tGvXrrCLPg6ftZ1DQ0PR8uXLo9mzZ0fJZDKaP39+tHLlyuj9998PvWyXo22fpOihhx4arZkK+/PztnOq7M/f+73fG30+nT17dvSNb3xjdABF0Re7L/kqBwBAMCf9a0IAgKmLIQQACIYhBAAIhiEEAAiGIQQACIYhBAAIhiEEAAiGIQQACIYhBAAIhiEEAAiGIQQACIYhBAAI5v8BQU5TfwzFejYAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dog_img = Image.open(\"dog.jpg\")\n",
    "rescaled_img = dog_img.crop((1500,1000,4500,4000)).resize((32,32))\n",
    "img = np.array(rescaled_img)\n",
    "pltimg = plt.imshow(img, cmap=\"gray\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a238c00-8a2c-4197-b00c-cc0985de6c93",
   "metadata": {},
   "source": [
    "## The convolution kernel \n",
    "We define the following kernel:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b454261a-7361-4c6c-a305-76ef1f406ca7",
   "metadata": {},
   "outputs": [],
   "source": [
    "kernel_array = np.expand_dims(np.array([[0,-1,0], [-1, 4 , -1], [0, -1, 0]]), -1).repeat( 3, axis=-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d97ae03-9e9b-4560-9242-82f532d5ebd1",
   "metadata": {},
   "source": [
    "## The convolution module"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef780799-e307-4dca-8644-a0a3f6b4bc66",
   "metadata": {},
   "source": [
    "Let start with an example using the library defined at the top of the notebook. If all necessary operation to implement the convolution, using our library with basics operations will leads to computational error (python does not realy like recursion).\n",
    "\n",
    "\n",
    "**Question 1:** Create a function convolution with two loops that will performs the convolution of an image (**img**) with a kernel (**kernel**) given in parameters. You should use up to 2 loops (visit all pixels of the input image) and apply the convolution. Notice that we consider the mean pooling approach. We will not consider stride at first and only one channel for the output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "681be377-287e-4caf-8339-6922a0c55dd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convolution_with_loop(img, kernel):\n",
    "    iwidth, iheight, ichannel = img.shape\n",
    "    kwidth, kheight, kchannel = kernel.shape\n",
    "\n",
    "    out_width = \n",
    "    out_height = \n",
    "\n",
    "    output = ComputationGraphNode(np.zeros((out_width, out_height)))\n",
    "\n",
    "    return output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c92ceefa-941b-4577-bf69-93bf8fe867b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "kernel = Parameter(kernel_array, 'kernel')\n",
    "input = ComputationGraphNode(img/256.)\n",
    "convolued_image = convolution_with_loop(input, kernel)\n",
    "fig = plt.imshow(convolued_image.value/convolued_image.value.max(), cmap='gray')\n",
    "convolued_image.mean().backward()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51fe3666-946f-4012-815b-1e73a877c120",
   "metadata": {
    "scrolled": true
   },
   "source": [
    "**Question 2:** What happens when we increase the size of the image? Why? What is the solution? (enginering question)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b6f8b37-580e-4bbf-ae2f-4d65a19b54c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "rescaled_img = dog_img.crop((1500,1000,4500,4000)).resize((128,128))\n",
    "img = np.array(rescaled_img)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c398c9d-6e0e-4876-844d-953d78213a2a",
   "metadata": {},
   "source": [
    "**Question 3:** Implement the forward pass of the convolution operator, supporting different strides, paddings and numbers of output channels (using loops). All parameters here will be numpy arrays; you can use loops."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 404,
   "id": "37b2d653-2799-415e-9fe0-f70447f2052d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def Conv2D(input_data, kernel, stride=1, padding=1):\n",
    "    raise NotImplementedError()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0365a070-7510-4fe2-bc41-40e220f69e24",
   "metadata": {},
   "source": [
    "**Question 4:** Try different kernels and plot the different outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99101627-5c55-47f0-a7df-901649f18777",
   "metadata": {},
   "outputs": [],
   "source": [
    "dog_img = Image.open(\"dog.jpg\")\n",
    "rescaled_img = dog_img.crop((1500,1000,4500,4000)).resize((256,256))\n",
    "img = np.array(rescaled_img)/256.\n",
    "\n",
    "raise NotImplementedError()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ca88082-459b-4794-9565-8aa3ec55fb72",
   "metadata": {},
   "source": [
    "**Question 5:** (Bonus) Implement the forward and backward operations (but go to the next lab exercise first)\n",
    "\n",
    "\n",
    "**NB:** The backward passes for both the weights and the input are also convolutions (a classic convolution if there is no stride, otherwise a dilated convolution)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4656f0d0-0820-46e8-a4b7-0a46b257f879",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Convolution2DLoop(Operation):\n",
    "    @staticmethod\n",
    "    def forward(input_node, kernels_node, bias_node, stride=1, padding=1):\n",
    "        pass\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(input_node, kernels_node, bias_node, stride, padding, output_gradient):\n",
    "        pass"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "604eeb8a",
   "metadata": {},
   "source": [
    "# Implement with pytorch\n",
    "\n",
    "\n",
    "Implement a PyTorch version of the CNN and train it on MNIST (see Lab3). Read the PyTorch documentation!!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42c15a7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import libs that we will use\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "\n",
    "# To load the data we will use the script of Gaetan Marceau Caron\n",
    "# You can download it from the course website and move it to the same directory that contains this ipynb file\n",
    "import dataset_loader\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d205eff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download mnist dataset \n",
    "if(\"mnist.pkl.gz\" not in os.listdir(\".\")):\n",
    "    # this link doesn't work any more,\n",
    "    # search on Google for the file \"mnist.pkl.gz\"\n",
    "    # and download it\n",
    "    !wget https://github.com/mnielsen/neural-networks-and-deep-learning/raw/master/data/mnist.pkl.gz\n",
    "\n",
    "# if you have it somewhere else, you can comment the lines above\n",
    "# and overwrite the path below\n",
    "mnist_path = \"./mnist.pkl.gz\"\n",
    "\n",
    "# load the 3 splits\n",
    "train_data, dev_data, test_data = dataset_loader.load_mnist(mnist_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
