{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2a880972-e0ba-4b47-8b37-7567e5a84d31",
   "metadata": {},
   "source": [
    "# Implementation of a transformer model\n",
    "In this exercise we will implement a decoder-only transformer model based on the Llama 2 architecture. The skeleton of the code is provided and you should complete it (mostly the self-attention mechanism)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24473635-9661-44ad-be96-dd9635e9fd14",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time \n",
    "from time import gmtime, strftime\n",
    "import math\n",
    "import shutil\n",
    "import os\n",
    "\n",
    "\n",
    "\n",
    "# neural network utilities\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch import optim\n",
    "from torch.nn import functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# datasets and tokenization\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "from rope_embedding import RoPE\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73434fce-7c28-4395-b25e-919797eb153d",
   "metadata": {},
   "source": [
    "## Self-Attention with no mask\n",
    "\n",
    "Let us consider $Q, K, V \\in \\mathbb{R}^{B\\times H \\times L \\times N}$. For each batch element $b$ and each head $h$ we want to compute $$A = \\frac{Q_{bh}K_{bh}^{\\intercal}}{\\sqrt{N}}$$\n",
    "\n",
    "The attention weights of the $i$-th row are defined by $$ \n",
    "softmax(A_i) = \\begin{pmatrix}\n",
    "\\frac{e^{A_{i1}}}{\\sum\\limits_{j=1}^L  e^{A_{ij}}} & \n",
    "\\frac{e^{A_{i2}}}{\\sum\\limits_{j=1}^L  e^{A_{ij}}} & \n",
    "\\dots &\n",
    "\\frac{e^{A_{iL}}}{\\sum\\limits_{j=1}^L  e^{A_{ij}}}\n",
    "\\end{pmatrix}$$ \n",
    "and the full attention weight matrix is $$ \n",
    "Softmax(A) = \\begin{pmatrix}\n",
    "\\frac{e^{A_{11}}}{\\sum\\limits_{j=1}^L  e^{A_{1j}}} & \n",
    "\\frac{e^{A_{12}}}{\\sum\\limits_{j=1}^L  e^{A_{1j}}} & \n",
    "\\dots &\n",
    "\\frac{e^{A_{1L}}}{\\sum\\limits_{j=1}^L  e^{A_{1j}}} \\\\\n",
    "\\frac{e^{A_{21}}}{\\sum\\limits_{j=1}^L  e^{A_{2j}}} & \n",
    "\\frac{e^{A_{22}}}{\\sum\\limits_{j=1}^L  e^{A_{2j}}} & \n",
    "\\dots &\n",
    "\\frac{e^{A_{2L}}}{\\sum\\limits_{j=1}^L  e^{A_{2j}}} \\\\\n",
    "\\vdots & \\vdots & \\ddots & \\vdots \\\\\n",
    "\\frac{e^{A_{L1}}}{\\sum\\limits_{j=1}^L  e^{A_{Lj}}} & \n",
    "\\frac{e^{A_{L2}}}{\\sum\\limits_{j=1}^L  e^{A_{Lj}}} & \n",
    "\\dots &\n",
    "\\frac{e^{A_{LL}}}{\\sum\\limits_{j=1}^L  e^{A_{Lj}}}\n",
    "\\end{pmatrix}$$ \n",
    "\n",
    "\n",
    "The whole matrix can be computed with [`torch.softmax`](https://docs.pytorch.org/docs/stable/generated/torch.softmax.html#torch.softmax) applied along the last dimension (`dim=-1`), which normalizes each row independently.\n",
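    "\n",
    "A minimal sketch of this unmasked computation for a single batch element and a single head, assuming $Q$, $K$ and $V$ are given as `L x N` tensors (the function name `unmasked_attention` is illustrative and not part of the exercise skeleton):\n",
    "\n",
    "```python\n",
    "import math\n",
    "import torch\n",
    "\n",
    "def unmasked_attention(q, k, v):\n",
    "    # q, k, v: (L, N) tensors for one batch element and one head\n",
    "    n = q.shape[-1]\n",
    "    scores = q @ k.transpose(-2, -1) / math.sqrt(n)  # A = Q K^T / sqrt(N), shape (L, L)\n",
    "    weights = torch.softmax(scores, dim=-1)          # row-wise softmax: each row sums to 1\n",
    "    return weights, weights @ v                      # weights (L, L), output (L, N)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "q, k, v = torch.randn(3, 4), torch.randn(3, 4), torch.randn(3, 4)\n",
    "weights, out = unmasked_attention(q, k, v)\n",
    "print(weights.sum(dim=-1))  # each entry is 1 up to floating-point rounding\n",
    "```\n",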
    "\n",
    "\n",
    "## Masked Attention\n",
    "In many cases, some tokens must not attend to other tokens (padding tokens, or future tokens during autoregressive decoding). Typically, in decoder-only architectures each token only attends to itself and to the previous tokens, so the attention weight matrix is lower triangular:\n",
    "\n",
    "$$ \n",
    "Softmax(A) = \\begin{pmatrix}\n",
    "\\frac{e^{A_{11}}}{\\sum\\limits_{j=1}^1  e^{A_{1j}}} & \n",
    "0 & \n",
    "\\dots & 0 &\n",
    "0 \\\\\n",
    "\\frac{e^{A_{21}}}{\\sum\\limits_{j=1}^2  e^{A_{2j}}} & \n",
    "\\frac{e^{A_{22}}}{\\sum\\limits_{j=1}^2  e^{A_{2j}}} & \n",
    "\\dots & 0 &\n",
    "0 \\\\\n",
    "\\vdots & \\vdots & \\ddots & \\vdots & \\vdots \\\\\n",
    "\\frac{e^{A_{L-1,1}}}{\\sum\\limits_{j=1}^{L-1}  e^{A_{L-1,j}}} & \n",
    "\\frac{e^{A_{L-1,2}}}{\\sum\\limits_{j=1}^{L-1}  e^{A_{L-1,j}}} & \n",
    "\\dots &\n",
    "\\frac{e^{A_{L-1,L-1}}}{\\sum\\limits_{j=1}^{L-1}  e^{A_{L-1,j}}}& 0 \\\\\n",
    "\\frac{e^{A_{L1}}}{\\sum\\limits_{j=1}^L  e^{A_{Lj}}} & \n",
    "\\frac{e^{A_{L2}}}{\\sum\\limits_{j=1}^L  e^{A_{Lj}}} & \n",
    "\\dots &\n",
    "\\frac{e^{A_{L,L-1}}}{\\sum\\limits_{j=1}^L  e^{A_{Lj}}} & \n",
    "\\frac{e^{A_{LL}}}{\\sum\\limits_{j=1}^L  e^{A_{Lj}}} \\\\\n",
    "\\end{pmatrix}$$ \n",
    "\n",
    "A simple implementation consists in setting the entries of $A$ above the diagonal to $-\\infty$ before applying the softmax, since $e^{-\\infty} = 0$. We thus define a matrix $M$ (`mask` in the code) whose entries above the diagonal are $-\\infty$ and whose other entries are $0$: \n",
    "$$ M = \n",
    "\\begin{pmatrix}\n",
    "0& \n",
    "-\\infty & \n",
    "\\dots & -\\infty&\n",
    "-\\infty \\\\\n",
    "0 & \n",
    "0 & \n",
    "\\dots & -\\infty &\n",
    "-\\infty \\\\\n",
    "\\vdots & \\vdots & \\ddots & \\vdots & \\vdots \\\\\n",
    "0 & \n",
    "0 & \n",
    "\\dots &\n",
    "0& -\\infty \\\\\n",
    "0 & \n",
    "0 & \n",
    "\\dots &\n",
    "0 & \n",
    "0\\\\\n",
    "\\end{pmatrix}\n",
    "$$\n",
    "\n",
    "and apply the softmax to $A + M$."
   ]
  },
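  {
   "cell_type": "markdown",
   "id": "sketch-masked-attention",
   "metadata": {},
   "source": [
    "A minimal sketch of the masked version on batched tensors, assuming the shapes used by `multi_head_attention` below ($Q$: `BxHxQxN`, $K$ and $V$: `BxHxKxN`, mask: `BxQxK`). The mask has no head dimension, so it is unsqueezed to broadcast over the heads. The names `masked_attention` and `causal` are illustrative only:\n",
    "\n",
    "```python\n",
    "import math\n",
    "import torch\n",
    "\n",
    "def masked_attention(q, k, v, mask):\n",
    "    # q: (B, H, Q, N); k, v: (B, H, K, N); mask: (B, Q, K) holding 0 or -inf\n",
    "    n = q.shape[-1]\n",
    "    scores = q @ k.transpose(-2, -1) / math.sqrt(n)  # (B, H, Q, K)\n",
    "    scores = scores + mask.unsqueeze(1)              # (B, 1, Q, K) broadcasts over heads\n",
    "    weights = torch.softmax(scores, dim=-1)          # exp(-inf) = 0: masked entries vanish\n",
    "    return weights, weights @ v                      # (B, H, Q, K) and (B, H, Q, N)\n",
    "\n",
    "# Causal mask M: -inf strictly above the diagonal, 0 elsewhere\n",
    "L = 3\n",
    "causal = torch.triu(torch.full((L, L), float('-inf')), diagonal=1)\n",
    "mask = causal.expand(2, L, L)  # the same mask for a batch of 2 sequences\n",
    "q = k = v = torch.ones(2, 1, L, 4)\n",
    "weights, out = masked_attention(q, k, v, mask)\n",
    "print(weights[0, 0])  # rows [1, 0, 0], [1/2, 1/2, 0], [1/3, 1/3, 1/3]\n",
    "```\n",
    "\n",
    "With all-ones inputs every unmasked score is equal, so each row spreads its weight uniformly over the tokens it is allowed to see. Note that a row in which every entry is $-\\infty$ would produce `NaN` after the softmax, since it divides $0$ by $0$."
   ]
  },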
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbf3ecaa-8ee7-4940-a225-81f8e96a09a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def multi_head_attention(\n",
    "    q : torch.FloatTensor,\n",
    "    k : torch.FloatTensor,\n",
    "    v : torch.FloatTensor,\n",
    "    mask : torch.FloatTensor\n",
    "):\n",
    "    \"\"\"\n",
    "    Given Q, K, V and the attention mask, compute (masked) multi-head attention\n",
    "\n",
    "    The dimensions are the following:\n",
    "    B: The batch size (number of examples)\n",
    "    Q: The length of the query (decoder) sequence\n",
    "    K: The length of the key/value (encoder) sequence\n",
    "    H: The number of attention heads\n",
    "    N: The embedding size of each head\n",
    "    \n",
    "    Parameters:\n",
    "        q : torch.FloatTensor\n",
    "            The query matrix of shape BxHxQxN\n",
    "        k : torch.FloatTensor\n",
    "            The key matrix of shape BxHxKxN\n",
    "        v : torch.FloatTensor\n",
    "            The value matrix of shape BxHxKxN\n",
    "        mask : torch.FloatTensor\n",
    "            The attention mask of shape BxQxK (shared by all heads):\n",
    "            masked elements are set to -inf,\n",
    "            other elements are set to 0\n",
    "    Return : (torch.Tensor, torch.Tensor)\n",
    "        two tensors, the first containing the attention\n",
    "        weights softmax(QK^t / sqrt(N) + mask) of shape BxHxQxK\n",
    "        and the second the attention output of shape BxHxQxN\n",
    "    \"\"\"\n",
    "    raise NotImplementedError\n",
    "    return attention_weights, attention_output\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2849459-0fa5-4508-9052-c641e8fd3322",
   "metadata": {},
   "outputs": [],
   "source": [
    "q, k, v = (torch.ones(1, 1, 3, 4), torch.ones(1, 1, 3, 4), torch.ones(1, 1, 3, 4))\n",
    "mask = torch.triu(-q.new_ones(1, 3, 3) *  torch.inf, diagonal =  1)\n",
    "attention_weights, attention_output = multi_head_attention(q, k, v, mask)\n",
    "\n",
    "# Causal weights expected: row i attends uniformly to the first i tokens\n",
    "assert torch.allclose(attention_weights, torch.tensor([\n",
    "          [[[1., 0., 0.],\n",
    "            [1/2, 1/2, 0.],\n",
    "            [1/3, 1/3, 1/3]]]]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4d27a7a",
   "metadata": {},
   "source": [
    "##  Create the attention module\n",
    "\n",
    "In PyTorch, modules are neural network blocks that inherit from the class `nn.Module`. In the following you will be asked to complete the attention code. The attention module computes the queries $Q$, the keys $K$ and the values $V$ for each head and then computes attention. The `forward` method is the method implementing these operations. In this lab we use multi-head attention, meaning that we compute attention for each head; this can be done with a for loop or with batched operations (it is up to you). Also, as we follow the Llama 2 architecture, we use the RoPE (rotary) positional encoding, which is applied to the two matrices $Q$ and $K$. The forward function should compute, for an input $X \\in \\mathbb{R}^{L \\times N}$ (in the decoder-only case), the following:\n",
    "\n",
    "1. $Q_h = RoPE(W^q_{h}X^{\\intercal})$ (for all heads h) with $W^q_h \\in \\mathbb{R}^{N \\times N}$\n",
    "2. $K_h = RoPE(W^k_{h}X^{\\intercal})$ (for all heads h)\n",
    "3. $V_h = W^v_{h}X^{\\intercal}$ (for all heads h)\n",
    "4. $A_h =  Attention(Q_h, K_h, V_h)$ (for all heads h)\n",
    "5. $A = [A_1, A_2, \\dots, A_H]$ (the concatenation along the last dimension)\n",
    "6. $O = W^oA^{\\intercal}$ with $W^{o} \\in \\mathbb{R}^{N \\times NH}$ ($NH$ being the product of $N$ and $H$)\n",
    "\n",
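    "Notice that $Q$, $K$ and $V$ can be computed without a loop using a single matrix $W^q \\in \\mathbb{R}^{NH \\times N}$ stacking all the $W^q_h$ (and likewise $W^k$ and $W^v$); this is what the `nn.Linear(input_dim, output_dim * num_heads)` layers of the skeleton hold.\n"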
   ]
  },
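  {
   "cell_type": "markdown",
   "id": "sketch-multihead-reshape",
   "metadata": {},
   "source": [
    "A minimal sketch of the batched alternative to a loop over heads, assuming an input of shape `BxLxN` and leaving RoPE out (it would be applied to the per-head $Q$ and $K$ right after the split). Variable names are illustrative only:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "B, L, N, H = 2, 5, 8, 4  # batch size, sequence length, per-head size, number of heads\n",
    "x = torch.randn(B, L, N)\n",
    "\n",
    "# A single linear layer holds the H stacked matrices W^q_h (output size N * H)\n",
    "wq = nn.Linear(N, N * H)\n",
    "\n",
    "# Split the heads: (B, L, N*H) -> (B, L, H, N) -> (B, H, L, N)\n",
    "q = wq(x).view(B, L, H, N).transpose(1, 2)\n",
    "\n",
    "# After attention, concatenate the heads back: (B, H, L, N) -> (B, L, H, N) -> (B, L, N*H)\n",
    "a = q  # stand-in for the per-head attention outputs A_h\n",
    "a_concat = a.transpose(1, 2).reshape(B, L, H * N)\n",
    "\n",
    "wo = nn.Linear(N * H, N)  # W^o maps the concatenated heads back to size N\n",
    "o = wo(a_concat)\n",
    "print(q.shape, a_concat.shape, o.shape)\n",
    "```\n",
    "\n",
    "Head $h$ receives the output columns `h*N` to `(h+1)*N` of the projection, which is $W^q_h X^{\\intercal}$ (plus bias). In general the transposed attention output is not contiguous, so `reshape` (which copies when needed) is safer than `view` for the concatenation."
   ]
  },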
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7564526-ad70-4764-86d2-a0984af9223f",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class RoPEAttentionModule(nn.Module):\n",
    "    \"\"\" A cross/self attention pytorch module.\n",
    "    \n",
    "    \"\"\"\n",
    "    def __init__(self, input_dim, output_dim, num_heads=1):\n",
    "        super().__init__()\n",
    "        \n",
    "        self.Wq, self.Wk, self.Wv  =\\\n",
    "            (nn.Linear(input_dim, output_dim * num_heads) for _ in range(3))\n",
    "\n",
    "        self.output_dim = output_dim\n",
    "        self.num_heads = num_heads\n",
    "        self.Wo = nn.Linear(output_dim * num_heads, input_dim)\n",
    "        self.rope_func = RoPE(output_dim)\n",
    "    \n",
    "    def forward(\n",
    "            self,\n",
    "            x : torch.Tensor,\n",
    "            attention_mask : torch.BoolTensor = None, \n",
    "            y : torch.Tensor = None,\n",
    "            decoder_mask = True,\n",
    "            k_cache = None,\n",
    "            v_cache = None,\n",
    "            start_cache = 0\n",
    "        ):\n",
    "        ''' \n",
    "            Parameters:\n",
    "                x : torch.Tensor\n",
    "                    The input of the Attention module\n",
    "                    used at least for K, V computation\n",
    "                    (for Q if decoder only)\n",
    "                attention_mask : torch.BoolTensor\n",
    "                    The mask for attention\n",
    "                y : torch.Tensor or None\n",
    "                    The query input in the case of\n",
    "                    cross-attention\n",
    "        '''\n",
    "        raise NotImplementedError\n",
    "\n",
    "        return attention_weights, output\n",
    "\n",
    "    \n",
    "\n",
    "\n",
    "    \n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c085a7a",
   "metadata": {},
   "source": [
    "## Transformer FeedForward\n",
    "\n",
    "Following the Llama 2 architecture, the FeedForward network is defined as follows for $X \\in \\mathbb{R}^{L\\times N}$:\n",
    "\n",
    "\n",
    "1. $G = SiLU(W^gX^\\intercal)$ with $W^g \\in \\mathbb{R}^{M \\times N}$ (see SiLU in pytorch [documentation](https://docs.pytorch.org/docs/stable/generated/torch.nn.SiLU.html))\n",
    "2. $U = W^u X^\\intercal$ with $W^u  \\in \\mathbb{R}^{M \\times N}$\n",
    "3. $I = G \\odot U$ (Hadamard, i.e. element-wise, product)\n",
    "4. $O = W^{o} I^{\\intercal}$ with $W^{o} \\in \\mathbb{R}^{N \\times M}$\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
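  {
   "cell_type": "markdown",
   "id": "sketch-swiglu-feedforward",
   "metadata": {},
   "source": [
    "A minimal sketch mapping the four steps above onto PyTorch calls, working on inputs of shape `(..., N)` so that `nn.Linear` plays the role of $W X^{\\intercal}$. The class name is hypothetical, and the layers are bias-free here as an assumption (the skeleton below keeps the default biases of `nn.Linear`):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "\n",
    "class SwiGLUFeedForwardSketch(nn.Module):\n",
    "    def __init__(self, n, m):\n",
    "        super().__init__()\n",
    "        self.w_g = nn.Linear(n, m, bias=False)  # W^g\n",
    "        self.w_u = nn.Linear(n, m, bias=False)  # W^u\n",
    "        self.w_o = nn.Linear(m, n, bias=False)  # W^o\n",
    "\n",
    "    def forward(self, x):\n",
    "        g = F.silu(self.w_g(x))  # 1. G = SiLU(W^g X^T)\n",
    "        u = self.w_u(x)          # 2. U = W^u X^T\n",
    "        i = g * u                # 3. I = G (element-wise product) U\n",
    "        return self.w_o(i)       # 4. O = W^o I^T, back to size N\n",
    "\n",
    "ff = SwiGLUFeedForwardSketch(n=8, m=16)\n",
    "print(ff(torch.randn(2, 5, 8)).shape)  # torch.Size([2, 5, 8])\n",
    "```"
   ]
  },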
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9c95d49",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerFeedForward(nn.Module):\n",
    "    def __init__(self, embed_size, intermediate_size):\n",
    "        super().__init__()\n",
    "        self.gate_proj = nn.Linear(embed_size, intermediate_size)\n",
    "        self.up_proj = nn.Linear(embed_size, intermediate_size)\n",
    "        self.down_proj = nn.Linear(intermediate_size, embed_size)\n",
    "        self.gate_func = nn.SiLU()\n",
    "        \n",
    "    def forward(self, x):\n",
    "        raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48bd9e1c",
   "metadata": {},
   "source": [
    "## Transformer block\n",
    "We can now create the decoder block.\n",
    "\n",
    "Let us consider $x_0 \\in \\mathbb{R}^{B \\times L \\times N}$; the decoder block is given by: \n",
    "\n",
    "* $x_1 = LN_1(x_0)$ (apply normalization; the code uses RMSNorm as in Llama 2)\n",
    "* $x_2 = Attention(x_1)$ (apply attention)\n",
    "* $x_3 = x_0 + x_2$ (add the residual)\n",
    "* $x_4 = LN_2(x_3)$\n",
    "* $y = FF(x_4) + x_3$\n"
   ]
  },
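  {
   "cell_type": "markdown",
   "id": "sketch-prenorm-block",
   "metadata": {},
   "source": [
    "A minimal sketch of the pre-norm residual wiring used by Llama 2, in which the second residual connection adds $x_3$ (the un-normalized stream) to the feed-forward output. The sub-layers are passed in as callables; the stand-ins below (`nn.LayerNorm` instead of RMSNorm, `nn.Linear` instead of attention) are assumptions made only to keep the example self-contained, and the function name is hypothetical:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "def decoder_block_sketch(x0, attention, feed_forward, norm1, norm2):\n",
    "    x1 = norm1(x0)                # normalize before attention\n",
    "    x2 = attention(x1)            # attention sub-layer\n",
    "    x3 = x0 + x2                  # first residual connection\n",
    "    x4 = norm2(x3)                # normalize before the feed-forward network\n",
    "    return x3 + feed_forward(x4)  # second residual connection adds x3\n",
    "\n",
    "N = 8\n",
    "x0 = torch.randn(2, 5, N)\n",
    "y = decoder_block_sketch(x0, nn.Linear(N, N), nn.Linear(N, N), nn.LayerNorm(N), nn.LayerNorm(N))\n",
    "print(y.shape)  # torch.Size([2, 5, 8])\n",
    "```"
   ]
  },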
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66e08611",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerDecoderBlock(nn.Module):\n",
    "    def __init__(self, embed_size, intermediate_size, num_heads):\n",
    "        super().__init__()\n",
    "        self.attention_module =\\\n",
    "            RoPEAttentionModule(embed_size, embed_size, num_heads=num_heads)\n",
    "        self.feed_forward = TransformerFeedForward(embed_size, intermediate_size)\n",
    "        \n",
    "        self.attention_layer_norm = nn.RMSNorm(embed_size)\n",
    "        self.feed_forward_layer_norm = nn.RMSNorm(embed_size)\n",
    "        \n",
    "    def forward(\n",
    "        self,\n",
    "        x,\n",
    "        attention_mask = None,\n",
    "        k_cache=None,\n",
    "        v_cache=None,\n",
    "        start_cache=0,\n",
    "        layer=0,\n",
    "    ):\n",
    "        \n",
    "        raise NotImplementedError\n",
    "        \n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0eed92e",
   "metadata": {},
   "source": [
    "## The Model\n",
    "\n",
    "The model applies a stack of blocks to the input embeddings and should return both the output embeddings and the logits (one score per vocabulary entry).\n",
    "\n",
    "It should return $o_{e} \\in \\mathbb{R}^{B\\times L \\times N}$ and $\\hat{y}\\in \\mathbb{R}^{B\\times L \\times V}$ with: \n",
    "\n",
    "* $o_e = RMSNorm \\left( b_n\\circ \\dots \\circ b_2 \\circ b_1 (W_{word\\_embedding}x^\\intercal) \\right)$ with $b_i$ the $i^{th}$ block and $n$ the number of blocks\n",
    "* $\\hat{y} = W_{lm\\_head} o_e^{\\intercal}$\n",
    "\n"
   ]
  },
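  {
   "cell_type": "markdown",
   "id": "sketch-embedding-lookup",
   "metadata": {},
   "source": [
    "The notation $W_{word\\_embedding}x^{\\intercal}$ treats each token id as a one-hot vector; `nn.Embedding` computes the same product as a row lookup. A small sketch with toy sizes (assumed, not the notebook's), where the blocks and the final norm are omitted:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "\n",
    "V, N = 10, 4                                     # toy vocabulary and embedding sizes\n",
    "wte = nn.Embedding(V, N)\n",
    "ids = torch.tensor([[1, 3, 3]])                  # (B=1, L=3) token ids\n",
    "one_hot = F.one_hot(ids, num_classes=V).float()  # (1, 3, V)\n",
    "# Looking up rows of wte.weight equals multiplying the one-hot vectors by it\n",
    "print(torch.allclose(wte(ids), one_hot @ wte.weight))  # True\n",
    "lm_head = nn.Linear(N, V)\n",
    "logits = lm_head(wte(ids))                       # (1, 3, V): one score per vocabulary entry\n",
    "```"
   ]
  },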
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8307f7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerDecoder(nn.Module):\n",
    "    def __init__(self, vocabulary_size, embed_size, intermediate_size, num_heads, hidden_layers=5, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.nh = num_heads\n",
    "        self.nhl = hidden_layers\n",
    "        self.es = embed_size \n",
    "\n",
    "        self.wte = nn.Embedding(vocabulary_size, embed_size)\n",
    "\n",
    "        self.blocks = nn.ModuleList([TransformerDecoderBlock(embed_size, intermediate_size, num_heads) for _ in range(hidden_layers)])\n",
    "        self.ln_f = nn.RMSNorm(embed_size)\n",
    "        self.lm_head = nn.Linear(embed_size, vocabulary_size)\n",
    "\n",
    "\n",
    "    def forward(\n",
    "            self,\n",
    "            input_ids,\n",
    "            attention_mask = None,\n",
    "            k_cache = None,\n",
    "            v_cache = None,\n",
    "            start_cache = 0\n",
    "        ):\n",
    "        \n",
    "\n",
    "        input_embed = self.wte(input_ids) \n",
    "        intermediate = input_embed \n",
    "\n",
    "        for i, block in enumerate(self.blocks):\n",
    "            raise NotImplementedError\n",
    "\n",
    "        \n",
    "        output_embed = self.ln_f(intermediate)\n",
    "        output_lm = self.lm_head(output_embed)  \n",
    "\n",
    "        return output_embed, output_lm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a6d31f0",
   "metadata": {},
   "source": [
    "## Training the transformer\n",
    "\n",
    "### I. Dataset\n",
    "For simplicity we will use a small dataset named TinyStories, which we can load with the Hugging Face Datasets library as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8315ae28",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "tinystories_dataset = load_dataset(\"roneneldan/TinyStories\")\n",
    "training_set = tinystories_dataset['train']\n",
    "print(training_set[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1626f71d",
   "metadata": {},
   "source": [
    "### II. The DataLoader\n",
    "\n",
    "The `DataLoader` object lets us set the batch size we will use by default. Here we assume that the training machine has enough memory for a batch of 16 samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9e8bdc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "training_dl = DataLoader(training_set, batch_size=16, shuffle=True) "
   ]
  },
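  {
   "cell_type": "markdown",
   "id": "sketch-dataloader-collate",
   "metadata": {},
   "source": [
    "With a Hugging Face dataset each sample is a dictionary, and the default collate function of `DataLoader` groups the strings of a batch into a list under the same key; this is why `data['text']` in the training loop is a list of strings (16 stories per batch). A toy illustration, using a plain list of dictionaries as an assumed stand-in for `training_set`:\n",
    "\n",
    "```python\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "toy = [{'text': 'a'}, {'text': 'b'}, {'text': 'c'}]\n",
    "loader = DataLoader(toy, batch_size=2)  # no shuffling, so batches keep the list order\n",
    "print(next(iter(loader)))               # {'text': ['a', 'b']}\n",
    "print(len(loader))                      # 2 batches: the last one holds a single sample\n",
    "```"
   ]
  },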
  {
   "cell_type": "markdown",
   "id": "191921cb",
   "metadata": {},
   "source": [
    "### III. Tokenizer\n",
    "\n",
    "In this exercise we will use a pretrained Llama-style tokenizer (here the one of `mistralai/Mistral-7B-v0.3`); however, you can train your own tokenizer if you prefer (see lab 2)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9dd7f477",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "\n",
    "model_hg_id = \"mistralai/Mistral-7B-v0.3\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_hg_id)\n",
    "tokenizer.pad_token = '<pad>'\n",
    "vocabulary_size = tokenizer.vocab_size"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2bbecf6",
   "metadata": {},
   "source": [
    "### IV. Create the model and the optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc3d5266",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "model = TransformerDecoder(vocabulary_size, 128, 256, 4, 2)\n",
    "model = model.train()\n",
    "model = model.to(device)\n",
    "loss_function = nn.CrossEntropyLoss()\n",
    "optimizer = optim.AdamW(model.parameters(), lr=6e-4)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e635ccd4",
   "metadata": {},
   "outputs": [],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f49f86e6",
   "metadata": {},
   "source": [
    "### V. Training the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d5ca404",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.train()\n",
    "avg_loss = []\n",
    "start_time = time.time()\n",
    "\n",
    "for i, data in enumerate(training_dl):\n",
    "    res = tokenizer(data['text'], return_tensors=\"pt\", padding=True, padding_side='right', max_length=128, truncation=True)\n",
    "    x = res.input_ids[:, :-1] + 0\n",
    "    y = res.input_ids[:, 1:] + 0\n",
    "    # attention_mask = (res.attention_mask == 0)\n",
    "    optimizer.zero_grad()\n",
    "    oe, oy = model(x.to(device))\n",
    "    loss = loss_function(oy.to(device).view(-1, vocabulary_size), y.to(device).view(-1))\n",
    "    loss.backward()\n",
    "    loss_value = loss.item()\n",
    "    avg_loss.append(loss_value)\n",
    "    if( i%500 == 0):\n",
    "        elapsed_time = time.time() - start_time\n",
    "        remaining_time = int((elapsed_time/(i+1)) * (len(training_dl) - i))\n",
    "        loop_sec = ((i+1)/elapsed_time)\n",
    "        print(f\"The loss at iteration {i+1} is {sum(avg_loss)/len(avg_loss):3.4f} remaining_time is {strftime('%H:%M:%S', gmtime(remaining_time))}s ({loop_sec:4.1f} it/s)\", flush=True)\n",
    "        avg_loss = []\n",
    "    if( i%1000 == 0):\n",
    "        torch.save(optimizer.state_dict() ,\"optimizer_state_dict_llama_mini.pth.temp\")\n",
    "        torch.save(model.state_dict() ,\"transformer_state_dict_llama_mini.pth.temp\")\n",
    "        shutil.copyfile(\"optimizer_state_dict_llama_mini.pth.temp\", \"optimizer_state_dict_llama_mini.pth\")\n",
    "        shutil.copyfile(\"transformer_state_dict_llama_mini.pth.temp\", \"transformer_state_dict_llama_mini.pth\")\n",
    "    optimizer.step()"
   ]
  },
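  {
   "cell_type": "markdown",
   "id": "sketch-padding-targets",
   "metadata": {},
   "source": [
    "In the loop above the targets `y` also contain padding positions, so the model is trained to predict padding (the `attention_mask` line is commented out). One possible fix, sketched here on toy tensors, is to set padded targets to `-100`, the default `ignore_index` of `nn.CrossEntropyLoss`, using the tokenizer's attention mask (right padding assumed; values are illustrative):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "\n",
    "input_ids = torch.tensor([[5, 6, 7, 8],\n",
    "                          [5, 9, 0, 0]])       # second sequence right-padded with id 0\n",
    "attention_mask = torch.tensor([[1, 1, 1, 1],\n",
    "                               [1, 1, 0, 0]])  # as returned by the tokenizer\n",
    "\n",
    "x = input_ids[:, :-1]                # inputs: positions 0 .. L-2\n",
    "y = input_ids[:, 1:].clone()         # targets: positions 1 .. L-1\n",
    "y[attention_mask[:, 1:] == 0] = -100 # padded targets are ignored by the loss\n",
    "\n",
    "V = 10\n",
    "logits = torch.randn(2, 3, V)        # stand-in for the model output on x\n",
    "loss = nn.CrossEntropyLoss()(logits.view(-1, V), y.view(-1))\n",
    "print(y)  # values [[6, 7, 8], [9, -100, -100]]\n",
    "```\n",
    "\n",
    "The loss is then averaged over the 4 real targets only. In the notebook this would mean shifting `res.attention_mask` in the same way and masking `y` before calling `loss_function`."
   ]
  },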
  {
   "cell_type": "markdown",
   "id": "df38645a",
   "metadata": {},
   "source": [
    "## Decoding time\n",
    "\n",
    "There are different ways to decode. The simplest one is greedy decoding: at each step the model is run again on the sequence extended with the previously produced token, and the most likely next token is chosen (complete `nonefficient_greedy_decoding`). We can also cache the keys and values, as shown in the course; in practice this only accelerates generation and does not change its result (complete `greedy_decoding`). Finally, we can also sample new tokens instead of always taking the most likely one (complete `sampling_decoding`)."
   ]
  },
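  {
   "cell_type": "markdown",
   "id": "sketch-decoding-rules",
   "metadata": {},
   "source": [
    "A minimal sketch of the two token-selection rules and of the non-cached greedy loop, assuming `model(ids)` returns `(embeddings, logits)` as `TransformerDecoder` does. The function names and `ToyModel` are hypothetical, and the key/value cache version is not shown:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def next_token(logits, temperature=None):\n",
    "    # logits: (B, V) scores for the last position\n",
    "    if temperature is None:\n",
    "        return logits.argmax(dim=-1, keepdim=True)       # greedy: most likely token\n",
    "    probs = torch.softmax(logits / temperature, dim=-1)  # temperature < 1 sharpens the distribution\n",
    "    return torch.multinomial(probs, num_samples=1)       # sample one token per row\n",
    "\n",
    "@torch.no_grad()\n",
    "def naive_decode_sketch(model, x, max_new_tokens=8, temperature=None):\n",
    "    # Re-runs the model on the whole sequence at every step (no K/V cache)\n",
    "    for _ in range(max_new_tokens):\n",
    "        _, logits = model(x)\n",
    "        x = torch.cat([x, next_token(logits[:, -1, :], temperature)], dim=1)\n",
    "    return x\n",
    "\n",
    "class ToyModel(torch.nn.Module):\n",
    "    # Stand-in that always predicts (last token + 1) modulo the vocabulary size\n",
    "    def __init__(self, vocab=10):\n",
    "        super().__init__()\n",
    "        self.vocab = vocab\n",
    "\n",
    "    def forward(self, ids):\n",
    "        logits = torch.nn.functional.one_hot((ids + 1) % self.vocab, self.vocab).float()\n",
    "        return None, logits\n",
    "\n",
    "print(naive_decode_sketch(ToyModel(), torch.tensor([[3]]), max_new_tokens=4))\n",
    "# tensor([[3, 4, 5, 6, 7]])\n",
    "```"
   ]
  },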
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "899b529b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def nonefficient_greedy_decoding(model, x, max_new_tokens=64):\n",
    "    previous_gen = x\n",
    "    for _ in range(max_new_tokens):\n",
    "        raise NotImplementedError\n",
    "    \n",
    "    return previous_gen\n",
    "\n",
    "def greedy_decoding(model, x, max_new_tokens=64, max_cache_size=512, tokenizer=None):\n",
    "    raise NotImplementedError\n",
    "\n",
    "def sampling_decoding(model, x, max_new_tokens=64, max_cache_size=512, tokenizer=None, temperature=.7):\n",
    "\n",
    "    raise NotImplementedError\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4f2087e",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = TransformerDecoder(vocabulary_size, 128, 256, 4, 2)\n",
    "model.load_state_dict(torch.load(\"transformer_state_dict_llama_mini.pth\", map_location=\"cpu\"))\n",
    "model = model.eval()\n",
    "\n",
    "decoding_method = greedy_decoding\n",
    "\n",
    "text = \"Alice\"\n",
    "model.to(\"cpu\")\n",
    "tokenized_text = tokenizer(text, return_tensors='pt')\n",
    "\n",
    "output_ids = decoding_method(model, tokenized_text.input_ids.to(\"cpu\"), max_new_tokens=64)[0]\n",
    "print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f063be3",
   "metadata": {},
   "source": [
    "## Evaluation?\n",
    "\n",
    "It remains difficult to evaluate the model. You can propose here an evaluation based on perplexity computed with another model, or compare results to the ground truth using the test set (here only a validation set is available), but neither will be fully informative."
   ]
  },
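  {
   "cell_type": "markdown",
   "id": "sketch-perplexity",
   "metadata": {},
   "source": [
    "Perplexity is the exponential of the average per-token cross-entropy: a value of $V$ means the model is as uncertain as a uniform choice among $V$ tokens. A minimal sketch, assuming logits and targets prepared as in training (shifted, with padded targets set to `-100`); the function name is hypothetical:\n",
    "\n",
    "```python\n",
    "import math\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "\n",
    "def perplexity(logits, targets, ignore_index=-100):\n",
    "    # logits: (B, L, V), targets: (B, L); exp of the mean cross-entropy over non-ignored tokens\n",
    "    nll = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1),\n",
    "                          ignore_index=ignore_index)\n",
    "    return math.exp(nll.item())\n",
    "\n",
    "V = 50\n",
    "uniform = torch.zeros(2, 7, V)       # equal scores for every vocabulary entry\n",
    "targets = torch.randint(0, V, (2, 7))\n",
    "print(perplexity(uniform, targets))  # about 50, the vocabulary size\n",
    "```\n",
    "\n",
    "Lower is better. Scoring generated stories with a larger pretrained model, as suggested above, applies the same formula to that model's logits."
   ]
  },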
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14e34481",
   "metadata": {},
   "outputs": [],
   "source": [
    "validation_dataset = tinystories_dataset['validation']"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
