{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "9510dd98", "metadata": {}, "outputs": [], "source": [ "# Imports, grouped per package; duplicate imports removed\n", "# (load_dataset and AutoTokenizer were each imported twice).\n", "import datasets\n", "import evaluate\n", "import torch\n", "from accelerate import Accelerator\n", "from datasets import load_dataset\n", "from torch.utils.data import DataLoader\n", "from tqdm.auto import tqdm\n", "from transformers import (\n", "    AutoModelForSeq2SeqLM,\n", "    AutoModelForTokenClassification,\n", "    AutoTokenizer,\n", "    DataCollatorForSeq2Seq,\n", "    get_scheduler,\n", ")" ] }, { "cell_type": "code", "execution_count": 2, "id": "f1da6c6c", "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/kave/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/transformers/models/t5/tokenization_t5_fast.py:155: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n", "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n", "- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n", "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n", "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n", " warnings.warn(\n" ] } ], "source": [ "# Load the pretrained T5 checkpoint: one tokenizer, one seq2seq model.\n", "tokenizer = AutoTokenizer.from_pretrained(\"t5-base\")\n", "model = AutoModelForSeq2SeqLM.from_pretrained(\"t5-base\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "a6de1719", "metadata": {}, "outputs": [], "source": [ "# prep dataset: MIT Restaurant NER corpus (BIO token/tag pairs)\n", "dataset = load_dataset(\"tner/mit_restaurant\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "8617d7d6", 
"metadata": {}, "outputs": [], "source": [ "# Tag vocabulary for tner/mit_restaurant (BIO scheme).\n", "ner_tags = {\n", "    \"O\": 0,\n", "    \"B-Rating\": 1,\n", "    \"I-Rating\": 2,\n", "    \"B-Amenity\": 3,\n", "    \"I-Amenity\": 4,\n", "    \"B-Location\": 5,\n", "    \"I-Location\": 6,\n", "    \"B-Restaurant_Name\": 7,\n", "    \"I-Restaurant_Name\": 8,\n", "    \"B-Price\": 9,\n", "    \"B-Hours\": 10,\n", "    \"I-Hours\": 11,\n", "    \"B-Dish\": 12,\n", "    \"I-Dish\": 13,\n", "    \"B-Cuisine\": 14,\n", "    \"I-Price\": 15,\n", "    \"I-Cuisine\": 16,\n", "}\n", "\n", "\n", "# Reverse mapping: tag id -> tag string.\n", "label_names = {v: k for k, v in ner_tags.items()}" ] }, { "cell_type": "code", "execution_count": 5, "id": "de52b597", "metadata": {}, "outputs": [], "source": [ "def decode_tags(tags, words):\n", "    \"\"\"Group BIO-tagged tokens into {entity_type: [entity strings]}.\n", "\n", "    Iterates in document order, so multi-word entities keep their word\n", "    order (the old reversed scan emitted e.g. 'rice fried' for 'fried\n", "    rice') and entity lists follow document order.\n", "    \"\"\"\n", "    dict_out = {}\n", "    current_tag = None\n", "    current_words = []\n", "    for tag, word in zip(tags, words):\n", "        if tag == 0:  # \"O\": outside any entity\n", "            continue\n", "        label = label_names[tag]\n", "        if label.startswith(\"B\"):\n", "            # A new entity begins: flush the one being built, if any.\n", "            if current_tag is not None:\n", "                dict_out.setdefault(current_tag, []).append(\" \".join(current_words))\n", "            current_tag = label[2:]\n", "            current_words = [word]\n", "        elif current_tag is not None:\n", "            # \"I-\" continuation of the current entity; orphan I-tags\n", "            # (no preceding B) are dropped as malformed.\n", "            current_words.append(word)\n", "    if current_tag is not None:\n", "        dict_out.setdefault(current_tag, []).append(\" \".join(current_words))\n", "    return dict_out\n", "\n", "\n", "def format_to_text(decoded):\n", "    # Render {entity: [values]} as one 'entity: v1, v2' line per entity.\n", "    text = \"\"\n", "    for key, value in decoded.items():\n", "        text += f\"{key}: {', '.join(value)}\\n\"\n", "    return text" ] }, { "cell_type": "code", "execution_count": 6, "id": "5da715a8", "metadata": {}, "outputs": [], "source": [ "def generate_t5_data(example):\n", "    # Plain-text input sentence and plain-text target for T5.\n", "    decoded = decode_tags(example[\"tags\"], example[\"tokens\"])\n", "    return {\"tokens\": \" \".join(example[\"tokens\"]), \"labels\": format_to_text(decoded)}" ] }, { "cell_type": "code", "execution_count": 7, "id": "57416e20", "metadata": {}, "outputs": [], "source": [ "from transformers import T5Tokenizer, T5ForConditionalGeneration\n", "import torch\n", "\n", "# the following 2 hyperparameters are task-specific\n", "max_source_length = 512\n", "max_target_length = 128\n", "\n", "# encode the inputs\n", "# NOTE: trailing space keeps the prefix from fusing with the first input token.\n", "task_prefix = \"What is the user intent? \"\n", 
"\n", "\n", "def tokenize(example):\n", "    \"\"\"Tokenize one example; source and target get separate length caps.\"\"\"\n", "    model_inputs = tokenizer(\n", "        task_prefix + example[\"tokens\"],\n", "        max_length=max_source_length,\n", "        truncation=True,\n", "    )\n", "    # Fix: targets previously shared max_length=512 with the source,\n", "    # leaving max_target_length (128) declared but unused.\n", "    targets = tokenizer(\n", "        text_target=example[\"labels\"],\n", "        max_length=max_target_length,\n", "        truncation=True,\n", "    )\n", "    model_inputs[\"labels\"] = targets[\"input_ids\"]\n", "    return model_inputs" ] }, { "cell_type": "code", "execution_count": 8, "id": "137905d7", "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "23bafa0f97bc4d4da8a96397f0f3bd5a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/6900 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Build text pairs, drop the now-redundant tag ids, then tokenize.\n", "tokenized_datasets = dataset.map(generate_t5_data)\n", "tokenized_datasets = tokenized_datasets.remove_columns([\"tags\"])\n", "tokenized_datasets = tokenized_datasets.map(tokenize)" ] }, { "cell_type": "code", "execution_count": 9, "id": "e2bdf1b0", "metadata": {}, "outputs": [], "source": [ "import evaluate\n", "\n", "metric = evaluate.load(\"sacrebleu\")" ] }, { "cell_type": "code", "execution_count": 10, "id": "cd9871bf", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "def compute_metrics(eval_preds):\n", "    preds, labels = eval_preds\n", "    # In case the model returns more than the prediction logits\n", "    if isinstance(preds, tuple):\n", "        preds = preds[0]\n", "\n", "    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n", "\n", "    # Replace -100s in the labels as we can't decode them\n", "    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n", "    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n", "\n", "    # Some simple post-processing\n", "    decoded_preds = [pred.strip() for pred in decoded_preds]\n", "    decoded_labels = [[label.strip()] for label in decoded_labels]\n", "\n", "    result = metric.compute(predictions=decoded_preds, references=decoded_labels)\n", "    return {\"bleu\": result[\"score\"]}" ] }, { "cell_type": "code", "execution_count": 11, "id": "09afe1d0", 
"metadata": {}, "outputs": [], "source": [ "data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)" ] }, { "cell_type": "code", "execution_count": 12, "id": "58e84fd1", "metadata": {}, "outputs": [], "source": [ "from transformers import Seq2SeqTrainingArguments\n", "\n", "args = Seq2SeqTrainingArguments(\n", " f\"T5 test\",\n", " evaluation_strategy=\"no\",\n", " save_strategy=\"epoch\",\n", " learning_rate=3e-4,\n", " per_device_train_batch_size=64,\n", " per_device_eval_batch_size=32,\n", " weight_decay=0.01,\n", " save_total_limit=3,\n", " num_train_epochs=20,\n", " predict_with_generate=True,\n", " fp16=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 13, "id": "edfcbac1", "metadata": {}, "outputs": [], "source": [ "from transformers import Seq2SeqTrainer\n", "\n", "trainer = Seq2SeqTrainer(\n", " model,\n", " args,\n", " train_dataset=tokenized_datasets[\"train\"],\n", " eval_dataset=tokenized_datasets[\"validation\"],\n", " data_collator=data_collator,\n", " tokenizer=tokenizer,\n", " compute_metrics=compute_metrics,\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "id": "e0065364", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" ] }, { "data": { "text/html": [ "\n", "
Step | \n", "Training Loss | \n", "
---|
"
],
"text/plain": [
"