{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f396f201-d921-4fe4-95b5-ff7237a0fcf4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://download.pytorch.org/whl/cu121\n",
      "Requirement already satisfied: torch in /home/gerald/miniconda3/lib/python3.11/site-packages (2.5.1+cu121)\n",
      "Requirement already satisfied: torchvision in /home/gerald/miniconda3/lib/python3.11/site-packages (0.20.1+cu121)\n",
      "Requirement already satisfied: torchaudio in /home/gerald/miniconda3/lib/python3.11/site-packages (2.5.1+cu121)\n",
      "Requirement already satisfied: filelock in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch) (3.16.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch) (4.10.0)\n",
      "Requirement already satisfied: networkx in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch) (3.2.1)\n",
      "Requirement already satisfied: jinja2 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch) (3.1.3)\n",
      "Requirement already satisfied: fsspec in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch) (2024.9.0)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch) (9.1.0.70)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch) (2.21.5)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch) (12.1.105)\n",
      "Requirement already satisfied: triton==3.1.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch) (3.1.0)\n",
      "Requirement already satisfied: sympy==1.13.1 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch) (1.13.1)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /home/gerald/miniconda3/lib/python3.11/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.1.105)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from sympy==1.13.1->torch) (1.3.0)\n",
      "Requirement already satisfied: numpy in /home/gerald/miniconda3/lib/python3.11/site-packages (from torchvision) (1.26.4)\n",
      "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torchvision) (10.3.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from jinja2->torch) (2.1.5)\n",
      "Requirement already satisfied: transformers in /home/gerald/miniconda3/lib/python3.11/site-packages (4.46.1)\n",
      "Requirement already satisfied: datasets in /home/gerald/miniconda3/lib/python3.11/site-packages (3.0.2)\n",
      "Requirement already satisfied: filelock in /home/gerald/miniconda3/lib/python3.11/site-packages (from transformers) (3.16.1)\n",
      "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /home/gerald/miniconda3/lib/python3.11/site-packages (from transformers) (0.26.2)\n",
      "Requirement already satisfied: numpy>=1.17 in /home/gerald/miniconda3/lib/python3.11/site-packages (from transformers) (1.26.4)\n",
      "Requirement already satisfied: packaging>=20.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from transformers) (23.1)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /home/gerald/miniconda3/lib/python3.11/site-packages (from transformers) (6.0.1)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /home/gerald/miniconda3/lib/python3.11/site-packages (from transformers) (2024.9.11)\n",
      "Requirement already satisfied: requests in /home/gerald/miniconda3/lib/python3.11/site-packages (from transformers) (2.32.3)\n",
      "Requirement already satisfied: safetensors>=0.4.1 in /home/gerald/miniconda3/lib/python3.11/site-packages (from transformers) (0.4.5)\n",
      "Requirement already satisfied: tokenizers<0.21,>=0.20 in /home/gerald/miniconda3/lib/python3.11/site-packages (from transformers) (0.20.1)\n",
      "Requirement already satisfied: tqdm>=4.27 in /home/gerald/miniconda3/lib/python3.11/site-packages (from transformers) (4.66.6)\n",
      "Requirement already satisfied: pyarrow>=15.0.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from datasets) (18.0.0)\n",
      "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from datasets) (0.3.8)\n",
      "Requirement already satisfied: pandas in /home/gerald/miniconda3/lib/python3.11/site-packages (from datasets) (2.2.3)\n",
      "Requirement already satisfied: xxhash in /home/gerald/miniconda3/lib/python3.11/site-packages (from datasets) (3.5.0)\n",
      "Requirement already satisfied: multiprocess<0.70.17 in /home/gerald/miniconda3/lib/python3.11/site-packages (from datasets) (0.70.16)\n",
      "Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets) (2024.9.0)\n",
      "Requirement already satisfied: aiohttp in /home/gerald/miniconda3/lib/python3.11/site-packages (from datasets) (3.10.10)\n",
      "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from aiohttp->datasets) (2.4.3)\n",
      "Requirement already satisfied: aiosignal>=1.1.2 in /home/gerald/miniconda3/lib/python3.11/site-packages (from aiohttp->datasets) (1.3.1)\n",
      "Requirement already satisfied: attrs>=17.3.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from aiohttp->datasets) (23.2.0)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /home/gerald/miniconda3/lib/python3.11/site-packages (from aiohttp->datasets) (1.5.0)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /home/gerald/miniconda3/lib/python3.11/site-packages (from aiohttp->datasets) (6.1.0)\n",
      "Requirement already satisfied: yarl<2.0,>=1.12.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from aiohttp->datasets) (1.17.0)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/gerald/miniconda3/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (4.10.0)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /home/gerald/miniconda3/lib/python3.11/site-packages (from requests->transformers) (3.3.2)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /home/gerald/miniconda3/lib/python3.11/site-packages (from requests->transformers) (3.10)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/gerald/miniconda3/lib/python3.11/site-packages (from requests->transformers) (1.26.20)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /home/gerald/miniconda3/lib/python3.11/site-packages (from requests->transformers) (2024.8.30)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in /home/gerald/miniconda3/lib/python3.11/site-packages (from pandas->datasets) (2.9.0.post0)\n",
      "Requirement already satisfied: pytz>=2020.1 in /home/gerald/miniconda3/lib/python3.11/site-packages (from pandas->datasets) (2024.2)\n",
      "Requirement already satisfied: tzdata>=2022.7 in /home/gerald/miniconda3/lib/python3.11/site-packages (from pandas->datasets) (2024.2)\n",
      "Requirement already satisfied: six>=1.5 in /home/gerald/miniconda3/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
      "Requirement already satisfied: propcache>=0.2.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from yarl<2.0,>=1.12.0->aiohttp->datasets) (0.2.0)\n",
      "Collecting peft\n",
      "  Downloading peft-0.13.2-py3-none-any.whl.metadata (13 kB)\n",
      "Requirement already satisfied: numpy>=1.17 in /home/gerald/miniconda3/lib/python3.11/site-packages (from peft) (1.26.4)\n",
      "Requirement already satisfied: packaging>=20.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from peft) (23.1)\n",
      "Requirement already satisfied: psutil in /home/gerald/miniconda3/lib/python3.11/site-packages (from peft) (5.9.8)\n",
      "Requirement already satisfied: pyyaml in /home/gerald/miniconda3/lib/python3.11/site-packages (from peft) (6.0.1)\n",
      "Requirement already satisfied: torch>=1.13.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from peft) (2.5.1+cu121)\n",
      "Requirement already satisfied: transformers in /home/gerald/miniconda3/lib/python3.11/site-packages (from peft) (4.46.1)\n",
      "Requirement already satisfied: tqdm in /home/gerald/miniconda3/lib/python3.11/site-packages (from peft) (4.66.6)\n",
      "Collecting accelerate>=0.21.0 (from peft)\n",
      "  Downloading accelerate-1.0.1-py3-none-any.whl.metadata (19 kB)\n",
      "Requirement already satisfied: safetensors in /home/gerald/miniconda3/lib/python3.11/site-packages (from peft) (0.4.5)\n",
      "Requirement already satisfied: huggingface-hub>=0.17.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from peft) (0.26.2)\n",
      "Requirement already satisfied: filelock in /home/gerald/miniconda3/lib/python3.11/site-packages (from huggingface-hub>=0.17.0->peft) (3.16.1)\n",
      "Requirement already satisfied: fsspec>=2023.5.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from huggingface-hub>=0.17.0->peft) (2024.9.0)\n",
      "Requirement already satisfied: requests in /home/gerald/miniconda3/lib/python3.11/site-packages (from huggingface-hub>=0.17.0->peft) (2.32.3)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/gerald/miniconda3/lib/python3.11/site-packages (from huggingface-hub>=0.17.0->peft) (4.10.0)\n",
      "Requirement already satisfied: networkx in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch>=1.13.0->peft) (3.2.1)\n",
      "Requirement already satisfied: jinja2 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch>=1.13.0->peft) (3.1.3)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch>=1.13.0->peft) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch>=1.13.0->peft) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch>=1.13.0->peft) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch>=1.13.0->peft) (9.1.0.70)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch>=1.13.0->peft) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch>=1.13.0->peft) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch>=1.13.0->peft) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch>=1.13.0->peft) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch>=1.13.0->peft) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch>=1.13.0->peft) (2.21.5)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch>=1.13.0->peft) (12.1.105)\n",
      "Requirement already satisfied: triton==3.1.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch>=1.13.0->peft) (3.1.0)\n",
      "Requirement already satisfied: sympy==1.13.1 in /home/gerald/miniconda3/lib/python3.11/site-packages (from torch>=1.13.0->peft) (1.13.1)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /home/gerald/miniconda3/lib/python3.11/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.13.0->peft) (12.1.105)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from sympy==1.13.1->torch>=1.13.0->peft) (1.3.0)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /home/gerald/miniconda3/lib/python3.11/site-packages (from transformers->peft) (2024.9.11)\n",
      "Requirement already satisfied: tokenizers<0.21,>=0.20 in /home/gerald/miniconda3/lib/python3.11/site-packages (from transformers->peft) (0.20.1)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /home/gerald/miniconda3/lib/python3.11/site-packages (from jinja2->torch>=1.13.0->peft) (2.1.5)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /home/gerald/miniconda3/lib/python3.11/site-packages (from requests->huggingface-hub>=0.17.0->peft) (3.3.2)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /home/gerald/miniconda3/lib/python3.11/site-packages (from requests->huggingface-hub>=0.17.0->peft) (3.10)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/gerald/miniconda3/lib/python3.11/site-packages (from requests->huggingface-hub>=0.17.0->peft) (1.26.20)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /home/gerald/miniconda3/lib/python3.11/site-packages (from requests->huggingface-hub>=0.17.0->peft) (2024.8.30)\n",
      "Downloading peft-0.13.2-py3-none-any.whl (320 kB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m320.7/320.7 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mta \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hDownloading accelerate-1.0.1-py3-none-any.whl (330 kB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m330.9/330.9 kB\u001b[0m \u001b[31m10.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hInstalling collected packages: accelerate, peft\n",
      "Successfully installed accelerate-1.0.1 peft-0.13.2\n"
     ]
    }
   ],
   "source": [
    "!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121\n",
    "!pip install transformers datasets\n",
    "!pip install peft\n",
    "!pip install evaluate scikit-learn transformers[torch]\n",
    "!pip install accelerate>=0.26.0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50179772-8165-4a2e-9b51-c86942fa2721",
   "metadata": {},
   "source": [
    "# Prepare the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "8d1c7119-6624-4604-b4b0-f15c905390d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "imdb = load_dataset(\"stanfordnlp/imdb\")\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n",
    "\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n",
    "\n",
    "\n",
    "tokenized_datasets = imdb.map(tokenize_function, batched=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "6714ed0a-053f-41d5-b340-353a6d2b3f41",
   "metadata": {},
   "outputs": [],
   "source": [
    "training_set = tokenized_datasets['train'].shuffle(42).select(range(5000))\n",
    "shuffled_set =  tokenized_datasets['test'].shuffle(42)\n",
    "validation_set  = shuffled_set.select(range(1000))\n",
    "test_set = shuffled_set.select(range(1000, 2000))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9624c1dc-cbca-4dd0-959f-e482182d94aa",
   "metadata": {},
   "source": [
    "## First method : Using the transformers library and pytorch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25693f50-39a8-4d5f-97bb-c8f1a5ccefff",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/gerald/Documents/COURS/2023-2024/Formation-RI/.conda/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import BertTokenizer\n",
    "from transformers import BertForSequenceClassification\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0176c44",
   "metadata": {},
   "source": [
    "## Fine-tuning the whole model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cf32374",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "model = BertForSequenceClassification.from_pretrained('bert-base-cased', num_labels=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0a649a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Have a look to \n",
    "\n",
    "raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b3f0272",
   "metadata": {},
   "source": [
    "## LoRA implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c12fe280",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['embeddings.word_embeddings.weight',\n",
       " 'embeddings.position_embeddings.weight',\n",
       " 'embeddings.token_type_embeddings.weight',\n",
       " 'embeddings.LayerNorm.weight',\n",
       " 'embeddings.LayerNorm.bias',\n",
       " 'encoder.layer.0.attention.self.query.weight',\n",
       " 'encoder.layer.0.attention.self.query.bias',\n",
       " 'encoder.layer.0.attention.self.key.weight',\n",
       " 'encoder.layer.0.attention.self.key.bias',\n",
       " 'encoder.layer.0.attention.self.value.weight',\n",
       " 'encoder.layer.0.attention.self.value.bias',\n",
       " 'encoder.layer.0.attention.output.dense.weight',\n",
       " 'encoder.layer.0.attention.output.dense.bias',\n",
       " 'encoder.layer.0.attention.output.LayerNorm.weight',\n",
       " 'encoder.layer.0.attention.output.LayerNorm.bias',\n",
       " 'encoder.layer.0.intermediate.dense.weight',\n",
       " 'encoder.layer.0.intermediate.dense.bias',\n",
       " 'encoder.layer.0.output.dense.weight',\n",
       " 'encoder.layer.0.output.dense.bias',\n",
       " 'encoder.layer.0.output.LayerNorm.weight',\n",
       " 'encoder.layer.0.output.LayerNorm.bias',\n",
       " 'encoder.layer.1.attention.self.query.weight',\n",
       " 'encoder.layer.1.attention.self.query.bias',\n",
       " 'encoder.layer.1.attention.self.key.weight',\n",
       " 'encoder.layer.1.attention.self.key.bias',\n",
       " 'encoder.layer.1.attention.self.value.weight',\n",
       " 'encoder.layer.1.attention.self.value.bias',\n",
       " 'encoder.layer.1.attention.output.dense.weight',\n",
       " 'encoder.layer.1.attention.output.dense.bias',\n",
       " 'encoder.layer.1.attention.output.LayerNorm.weight',\n",
       " 'encoder.layer.1.attention.output.LayerNorm.bias',\n",
       " 'encoder.layer.1.intermediate.dense.weight',\n",
       " 'encoder.layer.1.intermediate.dense.bias',\n",
       " 'encoder.layer.1.output.dense.weight',\n",
       " 'encoder.layer.1.output.dense.bias',\n",
       " 'encoder.layer.1.output.LayerNorm.weight',\n",
       " 'encoder.layer.1.output.LayerNorm.bias',\n",
       " 'encoder.layer.2.attention.self.query.weight',\n",
       " 'encoder.layer.2.attention.self.query.bias',\n",
       " 'encoder.layer.2.attention.self.key.weight',\n",
       " 'encoder.layer.2.attention.self.key.bias',\n",
       " 'encoder.layer.2.attention.self.value.weight',\n",
       " 'encoder.layer.2.attention.self.value.bias',\n",
       " 'encoder.layer.2.attention.output.dense.weight',\n",
       " 'encoder.layer.2.attention.output.dense.bias',\n",
       " 'encoder.layer.2.attention.output.LayerNorm.weight',\n",
       " 'encoder.layer.2.attention.output.LayerNorm.bias',\n",
       " 'encoder.layer.2.intermediate.dense.weight',\n",
       " 'encoder.layer.2.intermediate.dense.bias',\n",
       " 'encoder.layer.2.output.dense.weight',\n",
       " 'encoder.layer.2.output.dense.bias',\n",
       " 'encoder.layer.2.output.LayerNorm.weight',\n",
       " 'encoder.layer.2.output.LayerNorm.bias',\n",
       " 'encoder.layer.3.attention.self.query.weight',\n",
       " 'encoder.layer.3.attention.self.query.bias',\n",
       " 'encoder.layer.3.attention.self.key.weight',\n",
       " 'encoder.layer.3.attention.self.key.bias',\n",
       " 'encoder.layer.3.attention.self.value.weight',\n",
       " 'encoder.layer.3.attention.self.value.bias',\n",
       " 'encoder.layer.3.attention.output.dense.weight',\n",
       " 'encoder.layer.3.attention.output.dense.bias',\n",
       " 'encoder.layer.3.attention.output.LayerNorm.weight',\n",
       " 'encoder.layer.3.attention.output.LayerNorm.bias',\n",
       " 'encoder.layer.3.intermediate.dense.weight',\n",
       " 'encoder.layer.3.intermediate.dense.bias',\n",
       " 'encoder.layer.3.output.dense.weight',\n",
       " 'encoder.layer.3.output.dense.bias',\n",
       " 'encoder.layer.3.output.LayerNorm.weight',\n",
       " 'encoder.layer.3.output.LayerNorm.bias',\n",
       " 'encoder.layer.4.attention.self.query.weight',\n",
       " 'encoder.layer.4.attention.self.query.bias',\n",
       " 'encoder.layer.4.attention.self.key.weight',\n",
       " 'encoder.layer.4.attention.self.key.bias',\n",
       " 'encoder.layer.4.attention.self.value.weight',\n",
       " 'encoder.layer.4.attention.self.value.bias',\n",
       " 'encoder.layer.4.attention.output.dense.weight',\n",
       " 'encoder.layer.4.attention.output.dense.bias',\n",
       " 'encoder.layer.4.attention.output.LayerNorm.weight',\n",
       " 'encoder.layer.4.attention.output.LayerNorm.bias',\n",
       " 'encoder.layer.4.intermediate.dense.weight',\n",
       " 'encoder.layer.4.intermediate.dense.bias',\n",
       " 'encoder.layer.4.output.dense.weight',\n",
       " 'encoder.layer.4.output.dense.bias',\n",
       " 'encoder.layer.4.output.LayerNorm.weight',\n",
       " 'encoder.layer.4.output.LayerNorm.bias',\n",
       " 'encoder.layer.5.attention.self.query.weight',\n",
       " 'encoder.layer.5.attention.self.query.bias',\n",
       " 'encoder.layer.5.attention.self.key.weight',\n",
       " 'encoder.layer.5.attention.self.key.bias',\n",
       " 'encoder.layer.5.attention.self.value.weight',\n",
       " 'encoder.layer.5.attention.self.value.bias',\n",
       " 'encoder.layer.5.attention.output.dense.weight',\n",
       " 'encoder.layer.5.attention.output.dense.bias',\n",
       " 'encoder.layer.5.attention.output.LayerNorm.weight',\n",
       " 'encoder.layer.5.attention.output.LayerNorm.bias',\n",
       " 'encoder.layer.5.intermediate.dense.weight',\n",
       " 'encoder.layer.5.intermediate.dense.bias',\n",
       " 'encoder.layer.5.output.dense.weight',\n",
       " 'encoder.layer.5.output.dense.bias',\n",
       " 'encoder.layer.5.output.LayerNorm.weight',\n",
       " 'encoder.layer.5.output.LayerNorm.bias',\n",
       " 'encoder.layer.6.attention.self.query.weight',\n",
       " 'encoder.layer.6.attention.self.query.bias',\n",
       " 'encoder.layer.6.attention.self.key.weight',\n",
       " 'encoder.layer.6.attention.self.key.bias',\n",
       " 'encoder.layer.6.attention.self.value.weight',\n",
       " 'encoder.layer.6.attention.self.value.bias',\n",
       " 'encoder.layer.6.attention.output.dense.weight',\n",
       " 'encoder.layer.6.attention.output.dense.bias',\n",
       " 'encoder.layer.6.attention.output.LayerNorm.weight',\n",
       " 'encoder.layer.6.attention.output.LayerNorm.bias',\n",
       " 'encoder.layer.6.intermediate.dense.weight',\n",
       " 'encoder.layer.6.intermediate.dense.bias',\n",
       " 'encoder.layer.6.output.dense.weight',\n",
       " 'encoder.layer.6.output.dense.bias',\n",
       " 'encoder.layer.6.output.LayerNorm.weight',\n",
       " 'encoder.layer.6.output.LayerNorm.bias',\n",
       " 'encoder.layer.7.attention.self.query.weight',\n",
       " 'encoder.layer.7.attention.self.query.bias',\n",
       " 'encoder.layer.7.attention.self.key.weight',\n",
       " 'encoder.layer.7.attention.self.key.bias',\n",
       " 'encoder.layer.7.attention.self.value.weight',\n",
       " 'encoder.layer.7.attention.self.value.bias',\n",
       " 'encoder.layer.7.attention.output.dense.weight',\n",
       " 'encoder.layer.7.attention.output.dense.bias',\n",
       " 'encoder.layer.7.attention.output.LayerNorm.weight',\n",
       " 'encoder.layer.7.attention.output.LayerNorm.bias',\n",
       " 'encoder.layer.7.intermediate.dense.weight',\n",
       " 'encoder.layer.7.intermediate.dense.bias',\n",
       " 'encoder.layer.7.output.dense.weight',\n",
       " 'encoder.layer.7.output.dense.bias',\n",
       " 'encoder.layer.7.output.LayerNorm.weight',\n",
       " 'encoder.layer.7.output.LayerNorm.bias',\n",
       " 'encoder.layer.8.attention.self.query.weight',\n",
       " 'encoder.layer.8.attention.self.query.bias',\n",
       " 'encoder.layer.8.attention.self.key.weight',\n",
       " 'encoder.layer.8.attention.self.key.bias',\n",
       " 'encoder.layer.8.attention.self.value.weight',\n",
       " 'encoder.layer.8.attention.self.value.bias',\n",
       " 'encoder.layer.8.attention.output.dense.weight',\n",
       " 'encoder.layer.8.attention.output.dense.bias',\n",
       " 'encoder.layer.8.attention.output.LayerNorm.weight',\n",
       " 'encoder.layer.8.attention.output.LayerNorm.bias',\n",
       " 'encoder.layer.8.intermediate.dense.weight',\n",
       " 'encoder.layer.8.intermediate.dense.bias',\n",
       " 'encoder.layer.8.output.dense.weight',\n",
       " 'encoder.layer.8.output.dense.bias',\n",
       " 'encoder.layer.8.output.LayerNorm.weight',\n",
       " 'encoder.layer.8.output.LayerNorm.bias',\n",
       " 'encoder.layer.9.attention.self.query.weight',\n",
       " 'encoder.layer.9.attention.self.query.bias',\n",
       " 'encoder.layer.9.attention.self.key.weight',\n",
       " 'encoder.layer.9.attention.self.key.bias',\n",
       " 'encoder.layer.9.attention.self.value.weight',\n",
       " 'encoder.layer.9.attention.self.value.bias',\n",
       " 'encoder.layer.9.attention.output.dense.weight',\n",
       " 'encoder.layer.9.attention.output.dense.bias',\n",
       " 'encoder.layer.9.attention.output.LayerNorm.weight',\n",
       " 'encoder.layer.9.attention.output.LayerNorm.bias',\n",
       " 'encoder.layer.9.intermediate.dense.weight',\n",
       " 'encoder.layer.9.intermediate.dense.bias',\n",
       " 'encoder.layer.9.output.dense.weight',\n",
       " 'encoder.layer.9.output.dense.bias',\n",
       " 'encoder.layer.9.output.LayerNorm.weight',\n",
       " 'encoder.layer.9.output.LayerNorm.bias',\n",
       " 'encoder.layer.10.attention.self.query.weight',\n",
       " 'encoder.layer.10.attention.self.query.bias',\n",
       " 'encoder.layer.10.attention.self.key.weight',\n",
       " 'encoder.layer.10.attention.self.key.bias',\n",
       " 'encoder.layer.10.attention.self.value.weight',\n",
       " 'encoder.layer.10.attention.self.value.bias',\n",
       " 'encoder.layer.10.attention.output.dense.weight',\n",
       " 'encoder.layer.10.attention.output.dense.bias',\n",
       " 'encoder.layer.10.attention.output.LayerNorm.weight',\n",
       " 'encoder.layer.10.attention.output.LayerNorm.bias',\n",
       " 'encoder.layer.10.intermediate.dense.weight',\n",
       " 'encoder.layer.10.intermediate.dense.bias',\n",
       " 'encoder.layer.10.output.dense.weight',\n",
       " 'encoder.layer.10.output.dense.bias',\n",
       " 'encoder.layer.10.output.LayerNorm.weight',\n",
       " 'encoder.layer.10.output.LayerNorm.bias',\n",
       " 'encoder.layer.11.attention.self.query.weight',\n",
       " 'encoder.layer.11.attention.self.query.bias',\n",
       " 'encoder.layer.11.attention.self.key.weight',\n",
       " 'encoder.layer.11.attention.self.key.bias',\n",
       " 'encoder.layer.11.attention.self.value.weight',\n",
       " 'encoder.layer.11.attention.self.value.bias',\n",
       " 'encoder.layer.11.attention.output.dense.weight',\n",
       " 'encoder.layer.11.attention.output.dense.bias',\n",
       " 'encoder.layer.11.attention.output.LayerNorm.weight',\n",
       " 'encoder.layer.11.attention.output.LayerNorm.bias',\n",
       " 'encoder.layer.11.intermediate.dense.weight',\n",
       " 'encoder.layer.11.intermediate.dense.bias',\n",
       " 'encoder.layer.11.output.dense.weight',\n",
       " 'encoder.layer.11.output.dense.bias',\n",
       " 'encoder.layer.11.output.LayerNorm.weight',\n",
       " 'encoder.layer.11.output.LayerNorm.bias',\n",
       " 'pooler.dense.weight',\n",
       " 'pooler.dense.bias']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[k for k,v  in model.bert.named_parameters()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "759b1ef3-25e9-44aa-82e0-af70ce76abfc",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "What is the model \n",
      "\n",
      " BertForSequenceClassification(\n",
      "  (bert): BertModel(\n",
      "    (embeddings): BertEmbeddings(\n",
      "      (word_embeddings): Embedding(28996, 768, padding_idx=0)\n",
      "      (position_embeddings): Embedding(512, 768)\n",
      "      (token_type_embeddings): Embedding(2, 768)\n",
      "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "      (dropout): Dropout(p=0.1, inplace=False)\n",
      "    )\n",
      "    (encoder): BertEncoder(\n",
      "      (layer): ModuleList(\n",
      "        (0-11): 12 x BertLayer(\n",
      "          (attention): BertAttention(\n",
      "            (self): BertSdpaSelfAttention(\n",
      "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
      "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
      "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "            (output): BertSelfOutput(\n",
      "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "          )\n",
      "          (intermediate): BertIntermediate(\n",
      "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "            (intermediate_act_fn): GELUActivation()\n",
      "          )\n",
      "          (output): BertOutput(\n",
      "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "            (dropout): Dropout(p=0.1, inplace=False)\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (pooler): BertPooler(\n",
      "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "      (activation): Tanh()\n",
      "    )\n",
      "  )\n",
      "  (dropout): Dropout(p=0.1, inplace=False)\n",
      "  (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(f\"What is the model \\n\\n {model}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b82be16-830b-474b-b5f6-a98df468391f",
   "metadata": {},
   "source": [
    "### Create the LoRA module\n",
    "\n",
    "This module holds three sub-layers: the original linear layer (`linear`), the low-rank down-projection (`lora_a`) and the up-projection (`lora_b`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ba190fe-401c-411c-9958-1e74d51bcaf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import copy\n",
    "from torch import nn\n",
    "\n",
    "class LoRALinear(nn.Module):\n",
    "  def __init__(\n",
    "    self, in_dim: int, out_dim: int, rank: int\n",
    "  ):\n",
    "    super().__init__()\n",
    "    raise NotImplementedError\n",
    "  def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "    raise NotImplementedError\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32301774-a17f-44dd-a7b3-e7aa8d15e655",
   "metadata": {},
   "source": [
    "### Create a function that replaces a linear layer with a LoRA module\n",
    "Given an original `nn.Linear` module, return the corresponding `LoRALinear` module.\n",
    "* The inner linear layer of the `LoRALinear` must be initialised with the pretrained weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "6cbf7015-fdd5-41ea-b865-5d2021b970d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "def linear_to_lora(linear):\n",
    "    linear_weight = linear.weight.data\n",
    "    has_bias = linear.bias is not None\n",
    "    if has_bias:\n",
    "        linear_bias = linear.bias.data\n",
    "    output_size, input_size =  linear_weight.shape\n",
    "    lora = LoRALinear(input_size, output_size, rank=8)\n",
    "    lora.linear.weight.data = linear_weight\n",
    "    if has_bias:\n",
    "        lora.linear.bias.data = linear_bias\n",
    "    return lora\n",
    "    \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c5e3fac-0662-4f41-937f-d9736800103c",
   "metadata": {},
   "source": [
    "We now replace the target linear layers (the `query`, `key` and `value` projections of each encoder block) with the `LoRALinear` module described above"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68d0fb24-cdaa-4cbd-ac75-a00a12c9a1b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "lora_model = copy.deepcopy(model)\n",
    "lora_parameters = []\n",
    "for block in lora_model.bert.encoder.layer:\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "8484560f-2d4f-4e77-a381-f3b6dbd581d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "What is the model modified with LoRA: \n",
      "\n",
      " BertForSequenceClassification(\n",
      "  (bert): BertModel(\n",
      "    (embeddings): BertEmbeddings(\n",
      "      (word_embeddings): Embedding(28996, 768, padding_idx=0)\n",
      "      (position_embeddings): Embedding(512, 768)\n",
      "      (token_type_embeddings): Embedding(2, 768)\n",
      "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "      (dropout): Dropout(p=0.1, inplace=False)\n",
      "    )\n",
      "    (encoder): BertEncoder(\n",
      "      (layer): ModuleList(\n",
      "        (0-11): 12 x BertLayer(\n",
      "          (attention): BertAttention(\n",
      "            (self): BertSdpaSelfAttention(\n",
      "              (query): LoRALinear(\n",
      "                (linear): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (lora_a): Linear(in_features=768, out_features=8, bias=False)\n",
      "                (lora_b): Linear(in_features=8, out_features=768, bias=False)\n",
      "              )\n",
      "              (key): LoRALinear(\n",
      "                (linear): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (lora_a): Linear(in_features=768, out_features=8, bias=False)\n",
      "                (lora_b): Linear(in_features=8, out_features=768, bias=False)\n",
      "              )\n",
      "              (value): LoRALinear(\n",
      "                (linear): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (lora_a): Linear(in_features=768, out_features=8, bias=False)\n",
      "                (lora_b): Linear(in_features=8, out_features=768, bias=False)\n",
      "              )\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "            (output): BertSelfOutput(\n",
      "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "          )\n",
      "          (intermediate): BertIntermediate(\n",
      "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "            (intermediate_act_fn): GELUActivation()\n",
      "          )\n",
      "          (output): BertOutput(\n",
      "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
      "            (dropout): Dropout(p=0.1, inplace=False)\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "    )\n",
      "    (pooler): BertPooler(\n",
      "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "      (activation): Tanh()\n",
      "    )\n",
      "  )\n",
      "  (dropout): Dropout(p=0.1, inplace=False)\n",
      "  (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "print(f\"What is the model modified with LoRA: \\n\\n {lora_model}\")\n"
   ]
  },
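  {
   "cell_type": "markdown",
   "id": "e0ad8e4a-2d59-47c9-87dc-3b26836dada0",
   "metadata": {},
   "source": [
    "### Defining which parameters require gradients\n",
    "\n",
    "Inside `lora_model.bert`, only the parameters whose name contains `lora` are left trainable; all the others are frozen."
   ]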
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00bed4e1-e5ef-4cfe-83a8-dc915bc3070f",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k,v  in lora_model.bert.named_parameters():\n",
    "    print(k)\n",
    "    if ('lora' in k):\n",
    "        v.requires_grad = True\n",
    "    else:\n",
    "        v.requires_grad = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1f724d4-b1bb-40c2-8902-ab9ea7fe9bd6",
   "metadata": {},
   "source": [
    "## Using transformers Trainer to fine-tune with LoRA "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "a6f2f95c-aff2-46a8-abc0-d21bdd663bfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import evaluate\n",
    "\n",
    "metric = evaluate.load(\"accuracy\")\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    logits, labels = eval_pred\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "    return metric.compute(predictions=predictions, references=labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "f7f4a058-fcfc-448c-b1c5-e945cc7f58ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "training_args = TrainingArguments(output_dir=\"test_custom_lora\", \n",
    "                                  eval_strategy=\"steps\",\n",
    "                                  eval_steps= 128,\n",
    "                                  num_train_epochs=2,)\n",
    "trainer = Trainer(\n",
    "    model=lora_model,\n",
    "    args=training_args,\n",
    "    train_dataset= training_set,\n",
    "    eval_dataset=validation_set,\n",
    "    compute_metrics=compute_metrics,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "8971eec6-b562-45c4-8c7a-8affe54a4364",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1250' max='1250' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1250/1250 02:51, Epoch 2/2]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>128</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.677891</td>\n",
       "      <td>0.566000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>256</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.614952</td>\n",
       "      <td>0.684000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>384</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.502475</td>\n",
       "      <td>0.774000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>512</td>\n",
       "      <td>0.619000</td>\n",
       "      <td>0.418735</td>\n",
       "      <td>0.815000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>640</td>\n",
       "      <td>0.619000</td>\n",
       "      <td>0.381323</td>\n",
       "      <td>0.836000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>768</td>\n",
       "      <td>0.619000</td>\n",
       "      <td>0.371402</td>\n",
       "      <td>0.835000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>896</td>\n",
       "      <td>0.619000</td>\n",
       "      <td>0.354594</td>\n",
       "      <td>0.848000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1024</td>\n",
       "      <td>0.388000</td>\n",
       "      <td>0.347619</td>\n",
       "      <td>0.851000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1152</td>\n",
       "      <td>0.388000</td>\n",
       "      <td>0.342413</td>\n",
       "      <td>0.855000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=1250, training_loss=0.47044661254882814, metrics={'train_runtime': 171.7994, 'train_samples_per_second': 58.207, 'train_steps_per_second': 7.276, 'total_flos': 2644700098560000.0, 'train_loss': 0.47044661254882814, 'epoch': 2.0})"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "f11f8f99-26f4-4573-a4d7-7ee87f43f587",
   "metadata": {},
   "source": [
    "## Using PEFT library to fine-tune with LoRA "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "94a30566-0231-420d-8709-1be4fb238f1e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from peft import LoraConfig, TaskType,  get_peft_model\n",
    "\n",
    "lora_config = LoraConfig(\n",
    "    task_type=TaskType.SEQ_CLS, r=1, lora_alpha=1, lora_dropout=0.1\n",
    ")\n",
    "\n",
    "model = BertForSequenceClassification.from_pretrained(\n",
    "    'bert-base-cased', \n",
    "    num_labels=2\n",
    ")\n",
    "peft_model = get_peft_model(model, lora_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "e0cfba65-25c7-455f-b3e4-4ecca8a76b14",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "5d361529-f3e8-44b3-8fcd-5640679405f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "training_args = TrainingArguments(output_dir=\"test_custom_lora\", \n",
    "                                  eval_strategy=\"steps\",\n",
    "                                  eval_steps= 128,\n",
    "                                  num_train_epochs=2,)\n",
    "trainer = Trainer(\n",
    "    model=peft_model,\n",
    "    args=training_args,\n",
    "    train_dataset= training_set,\n",
    "    eval_dataset=validation_set,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "5cf108d7-6ec0-4ee8-8c71-15accf8ceb51",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1250' max='1250' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1250/1250 02:50, Epoch 2/2]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>128</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.686927</td>\n",
       "      <td>0.548000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>256</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.677564</td>\n",
       "      <td>0.593000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>384</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.668519</td>\n",
       "      <td>0.616000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>512</td>\n",
       "      <td>0.689200</td>\n",
       "      <td>0.655116</td>\n",
       "      <td>0.635000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>640</td>\n",
       "      <td>0.689200</td>\n",
       "      <td>0.653552</td>\n",
       "      <td>0.616000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>768</td>\n",
       "      <td>0.689200</td>\n",
       "      <td>0.630284</td>\n",
       "      <td>0.650000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>896</td>\n",
       "      <td>0.689200</td>\n",
       "      <td>0.620816</td>\n",
       "      <td>0.659000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1024</td>\n",
       "      <td>0.652700</td>\n",
       "      <td>0.614406</td>\n",
       "      <td>0.658000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1152</td>\n",
       "      <td>0.652700</td>\n",
       "      <td>0.610755</td>\n",
       "      <td>0.665000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=1250, training_loss=0.6643443115234375, metrics={'train_runtime': 170.7475, 'train_samples_per_second': 58.566, 'train_steps_per_second': 7.321, 'total_flos': 2632290263040000.0, 'train_loss': 0.6643443115234375, 'epoch': 2.0})"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "111faf3a-272b-4dea-9f9a-8b00e7e194a5",
   "metadata": {},
   "source": [
    "## Compare the results of the different approaches"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7933b06b",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
