Kave Bahraman committed on
Commit
9896b0f
·
unverified ·
2 Parent(s): 5f923af a99b495

Merge pull request #3 from BerserkerMother/dev

Browse files
.gitignore CHANGED
@@ -3,6 +3,14 @@ __pycache__/
3
  *.py[cod]
4
  *$py.class
5
 
 
 
 
 
 
 
 
 
6
  # IDEs file
7
  .idea
8
 
 
3
  *.py[cod]
4
  *$py.class
5
 
6
+ # Project files
7
+ mlruns/
8
+ mlartifacts/
9
+ experiment_models/
10
+
11
+ # mlflow databases
12
+ mlflow.db
13
+
14
  # IDEs file
15
  .idea
16
 
.pylintrc CHANGED
@@ -428,7 +428,8 @@ disable=raw-checker-failed,
428
  suppressed-message,
429
  useless-suppression,
430
  deprecated-pragma,
431
- use-symbolic-message-instead
 
432
 
433
  # Enable the message, report, category or checker with the given id(s). You can
434
  # either give multiple identifier separated by comma (,) or put this option
 
428
  suppressed-message,
429
  useless-suppression,
430
  deprecated-pragma,
431
+ use-symbolic-message-instead,
432
+ R0902
433
 
434
  # Enable the message, report, category or checker with the given id(s). You can
435
  # either give multiple identifier separated by comma (,) or put this option
MLproject ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ name: Afterhours Elise Model pipeline
2
+
3
+ entry_points:
4
+ train_t5:
5
+ command: "python elise/src/train_t5_seq2seq.py"
6
+
Makefile CHANGED
@@ -1,10 +1,14 @@
1
  setup:
2
  pip install -r requirements.txt
 
 
3
  format:
4
  black elise/src/
5
  lint:
6
  pylint elise/src/
7
  gradio:
8
  python elise/src/app.py
 
 
9
  dev:
10
  make format lint gradio
 
1
  setup:
2
  pip install -r requirements.txt
3
+ test:
4
+ pytest
5
  format:
6
  black elise/src/
7
  lint:
8
  pylint elise/src/
9
  gradio:
10
  python elise/src/app.py
11
+ train_t5:
12
+ python elise/src/train_t5_seq2seq.py
13
  dev:
14
  make format lint gradio
elise/data_generation/data_generation_prompts.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Imagine you are assisting me in generating data for training a T5 language model. Each record contains a user prompt, in which the user describes a place they want to dine, together with the user's intentions and their intention categories, which form the label for training the model. The labels are the user's intentions.
2
+ Intentions categories are:
3
+ - Cuisine
4
+ - Location
5
+ - Price
6
+ - Atmosphere
7
+ - Service
8
+ - Reviews
9
+ - Accessibility
10
+ - Amenity & Special features
11
+ - Offerings
12
+ - Recommendations
13
+ - Crowd
14
+ - Payment
15
+ - Category
16
+
17
+ Here is one example:
18
+ Prompt: I have a gluten allergy and need to find a restaurant with gluten-free options. Do you know any good ones in this area?
19
+ Label: { "Location": "in this area", "Dietary restrictions": "gluten-free" }
20
+
21
+ Write 5 random records in json format containing user's prompts and user's intentions.
elise/data_generation/prompt_generation.txt ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Your task is to parse an unstructured job posting and turn it into a JSON containing the most important information. The job posting can describe one or more jobs at the same company. The JSON should consist of the following information:
2
+ - The company name (field name: "companyName", field type: string)
3
+ - the location of the company (field name: "companyLocation", field type: string); if not explicitly stated, you can try to infer the company's actual location from other clues, e.g., something like "Remote (US)" usually means that the company is located in the US; if the location cannot be inferred, set it to null
4
+ - a short description of what the company is doing or building (field name: "companyDescription", field type: string); try to keep it short (max length: ca. 300 characters)
5
+ - a list of advertised jobs (field name: "jobs", field type: array).
6
+ Each element of the "jobs" array should contain the following fields:
7
+ - The job title (field name: "jobTitle", field type: string); the job title should be given in the singular form (i.e., Frontend Developer instead of Frontend Developers)
8
+ - the salary range (field name: "salary", field type: string); only include explicitly stated salary amounts, otherwise set to null
9
+ - whether equity is part of the compensation (field name: "equity", field type: boolean)
10
+ - the benefits (field name: "benefits", field type: string); include things like 401k, insurance, equipment, child care, etc. if stated, otherwise set to null
11
+ - the location of the job (field name: "location", field type: string)
12
+ - whether this is a job for senior/experienced candidates (field name: "senior", field type: boolean); typically senior, staff, lead, principal, vp, cto, etc. positions are all regarded as senior level
13
+ - whether it is a remote opportunity (field name: "remote", field type: boolean)
14
+ - whether it can be done onsite from an office (field name: "onsite", field type: boolean)
15
+ - whether it can be done part-time (field name: "partTime", field type: boolean)
16
+ - whether it can be done full-time (field name: "fullTime", field type: boolean)
17
+ - the URL to the specific job description (field name: "jobUrl", field type: string)
18
+ - and any specific requirements/skills that might be stated (field name: "requirements", field type: string).
19
+ In general, if certain information is not stated, set the respective field to null. If the company seeks more than one person for the same role, include the role only once.
20
+
21
+ This is the job posting:
22
+
23
+ %s
24
+
25
+ The structured JSON representation is:
26
+ ```json
27
+ {"companyName":
elise/src/app.py CHANGED
@@ -10,7 +10,9 @@ from utils import df_to_json
10
 
11
 
12
  # prep models
13
- MODEL_CHECKPOINT = "BerserkerMother/restaurant_ner"
 
 
14
  parser = SentenceParser.from_huggingface(MODEL_CHECKPOINT)
15
 
16
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
 
10
 
11
 
12
  # prep models
13
+ MODEL_CHECKPOINT = (
14
+ "tner/roberta-large-mit-restaurant" # "BerserkerMother/restaurant_ner"
15
+ )
16
  parser = SentenceParser.from_huggingface(MODEL_CHECKPOINT)
17
 
18
  embedder = SentenceTransformer("all-MiniLM-L6-v2")
elise/src/configs/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """
2
+ All configs for ML project
3
+ """
4
+ from .train_t5 import T5TrainingConfig
elise/src/configs/logging_config.yaml CHANGED
@@ -1,12 +1,14 @@
1
  version: 1
 
2
  formatters:
3
  simple:
4
  format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
5
  handlers:
6
  console:
 
7
  class: logging.StreamHandler
8
  formatter: simple
9
  stream: ext://sys.stdout
10
- Root:
11
  Level: DEBUG
12
  handlers: [console]
 
1
  version: 1
2
+ disable_existing_loggers: False
3
  formatters:
4
  simple:
5
  format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
6
  handlers:
7
  console:
8
+ level: DEBUG
9
  class: logging.StreamHandler
10
  formatter: simple
11
  stream: ext://sys.stdout
12
+ root:
13
  Level: DEBUG
14
  handlers: [console]
elise/src/configs/train_t5.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training config for T5 Seq2Seq training
3
+ """
4
+
5
+ from dataclasses import dataclass
6
+
7
+
8
+ @dataclass
9
+ class T5TrainingConfig:
10
+ """Training configs for T5 fine-tuning"""
11
+
12
+ train_batch_size: int = 32
13
+ eval_batch_size: int = 32
14
+ epochs: int = 10
15
+ max_length: int = 512
16
+ learning_rate: float = 3e-4
17
+ num_warmup_steps: int = 200
18
+ mixed_precision: str = "bf16"
19
+ gradient_accumulation_steps: int = 4
20
+ output_dir: str = "FlanT5_MIT_ner"
elise/src/data/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """
2
+ Contains datasets and their connectors for model training
3
+ """
4
+
5
+ from .mit_seq2seq_dataset import MITRestaurants, get_default_transforms
elise/src/data/mit_seq2seq_dataset.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ seq2seq models datasets
3
+
4
+ Classes:
5
+ MITRestaurants: tner/mit_restaurant dataset to seq2seq
6
+
7
+ Functions:
8
+ get_default_transforms: default transforms for mit dataset
9
+ """
10
+ import datasets
11
+
12
+
13
+ class MITRestaurants:
14
+ """
15
+ tner/mit_restaurants for seq2seq
16
+
17
+ Attributes
18
+ ----------
19
+ ner_tags: ner tags and ids of mit restaurant
20
+ dataset: hf dataset
21
+ transforms: transforms to apply
22
+ """
23
+
24
+ ner_tags = {
25
+ "O": 0,
26
+ "B-Rating": 1,
27
+ "I-Rating": 2,
28
+ "B-Amenity": 3,
29
+ "I-Amenity": 4,
30
+ "B-Location": 5,
31
+ "I-Location": 6,
32
+ "B-Restaurant_Name": 7,
33
+ "I-Restaurant_Name": 8,
34
+ "B-Price": 9,
35
+ "B-Hours": 10,
36
+ "I-Hours": 11,
37
+ "B-Dish": 12,
38
+ "I-Dish": 13,
39
+ "B-Cuisine": 14,
40
+ "I-Price": 15,
41
+ "I-Cuisine": 16,
42
+ }
43
+
44
+ def __init__(self, dataset: datasets.DatasetDict, transforms=None):
45
+ """
46
+ Constructs mit datasets
47
+
48
+ Parameters:
49
+ dataset: huggingface mit dataset
50
+ transforms: dataset transform functions
51
+ """
52
+ self.dataset = dataset
53
+ self.transforms = transforms
54
+
55
+ def hf_training(self):
56
+ """
57
+ Returns dataset for huggingface training ecosystem
58
+ """
59
+ dataset_ = self.dataset
60
+ if self.transforms:
61
+ for transfrom in self.transforms:
62
+ dataset_ = dataset_.map(transfrom)
63
+ return dataset_
64
+
65
+ def set_transforms(self, transforms):
66
+ """
67
+ Set transforms fn
68
+
69
+ Parameters:
70
+ transforms: transforms functions
71
+ """
72
+ if self.transforms:
73
+ self.transforms += transforms
74
+ else:
75
+ self.transforms = transforms
76
+ return self
77
+
78
+ @classmethod
79
+ def from_hf(cls, hf_path: str):
80
+ """
81
+ Constructs dataset from huggingface
82
+
83
+ Parameters:
84
+ hf_path: path to dataset hf repo
85
+ """
86
+ return cls(datasets.load_dataset(hf_path))
87
+
88
+
89
+ def get_default_transforms():
90
+ """
91
+ Default transformation to convert ner dataset to seq2seq
92
+ """
93
+ label_names = {v: k for k, v in MITRestaurants.ner_tags.items()}
94
+
95
+ def decode_tags(tags, words):
96
+ dict_out = {}
97
+ word_ = ""
98
+ for tag, word in zip(tags[::-1], words[::-1]):
99
+ if tag == 0:
100
+ continue
101
+ word_ = word + " " + word_
102
+ if label_names[tag].startswith("B"):
103
+ tag_name = label_names[tag][2:]
104
+ word_ = word_.strip()
105
+ if tag_name not in dict_out:
106
+ dict_out[tag_name] = [word_]
107
+ else:
108
+ dict_out[tag_name].append(word_)
109
+ word_ = ""
110
+ return dict_out
111
+
112
+ def format_to_text(decoded):
113
+ text = ""
114
+ for key, value in decoded.items():
115
+ text += f"{key}: {', '.join(value)}\n"
116
+ return text
117
+
118
+ def generate_seq2seq_data(example):
119
+ decoded = decode_tags(example["tags"], example["tokens"])
120
+ return {
121
+ "tokens": " ".join(example["tokens"]),
122
+ "labels": format_to_text(decoded),
123
+ }
124
+
125
+ return [generate_seq2seq_data]
elise/src/excutors/__init__.py ADDED
File without changes
elise/src/models/__init__.py ADDED
File without changes
elise/src/notebooks/flant_t5_playground.ipynb ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import transformers"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 2,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "pipe = transformers.pipeline(\n",
19
+ " \"text2text-generation\", model=\"/home/kave/work/Elise/output_dir/\"\n",
20
+ ")"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 7,
26
+ "metadata": {},
27
+ "outputs": [
28
+ {
29
+ "data": {
30
+ "text/plain": [
31
+ "[{'generated_text': ''}]"
32
+ ]
33
+ },
34
+ "execution_count": 7,
35
+ "metadata": {},
36
+ "output_type": "execute_result"
37
+ }
38
+ ],
39
+ "source": [
40
+ "pipe(\"What are you?\")"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": []
49
+ }
50
+ ],
51
+ "metadata": {
52
+ "kernelspec": {
53
+ "display_name": "Python 3 (ipykernel)",
54
+ "language": "python",
55
+ "name": "python3"
56
+ },
57
+ "language_info": {
58
+ "codemirror_mode": {
59
+ "name": "ipython",
60
+ "version": 3
61
+ },
62
+ "file_extension": ".py",
63
+ "mimetype": "text/x-python",
64
+ "name": "python",
65
+ "nbconvert_exporter": "python",
66
+ "pygments_lexer": "ipython3",
67
+ "version": "3.10.0"
68
+ }
69
+ },
70
+ "nbformat": 4,
71
+ "nbformat_minor": 2
72
+ }
elise/src/notebooks/play.ipynb ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from pprint import pprint"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 2,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "from playground_prompt.version1 import prompts\n",
19
+ "\n",
20
+ "test_p = prompts[\"test_prompts\"]"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": 3,
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": [
29
+ "import sys\n",
30
+ "\n",
31
+ "sys.path.append(\"/home/kave/work/Elise/elise/src\")"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": 4,
37
+ "metadata": {},
38
+ "outputs": [
39
+ {
40
+ "name": "stderr",
41
+ "output_type": "stream",
42
+ "text": [
43
+ "Using /home/kave/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...\n",
44
+ "Detected CUDA files, patching ldflags\n",
45
+ "Emitting ninja build file /home/kave/.cache/torch_extensions/py310_cu117/cuda_kernel/build.ninja...\n",
46
+ "Building extension module cuda_kernel...\n",
47
+ "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n",
48
+ "Loading extension module cuda_kernel...\n",
49
+ "Failed to load CUDA kernels. Mra requires custom CUDA kernels. Please verify that compatible versions of PyTorch and CUDA Toolkit are installed: /home/kave/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/zmq/backend/cython/../../../../.././libstdc++.so.6: version `GLIBCXX_3.4.32' not found (required by /home/kave/.cache/torch_extensions/py310_cu117/cuda_kernel/cuda_kernel.so)\n"
50
+ ]
51
+ },
52
+ {
53
+ "name": "stdout",
54
+ "output_type": "stream",
55
+ "text": [
56
+ "ninja: no work to do.\n"
57
+ ]
58
+ }
59
+ ],
60
+ "source": [
61
+ "from parser import SentenceParser\n",
62
+ "\n",
63
+ "parser = SentenceParser.from_huggingface(\"BerserkerMother/restaurant_ner\")"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": 5,
69
+ "metadata": {},
70
+ "outputs": [
71
+ {
72
+ "name": "stdout",
73
+ "output_type": "stream",
74
+ "text": [
75
+ "[{'Cuisine': ['fast -']},\n",
76
+ " {'Amenity': ['lively']},\n",
77
+ " {'Amenity': ['nice', 'fun'],\n",
78
+ " 'Hours': ['dinner'],\n",
79
+ " 'Rating': ['good', 'good'],\n",
80
+ " 'Services': ['ambian']},\n",
81
+ " {'Cuisine': ['authentic indian']},\n",
82
+ " {'Amenity': ['romantic', 'nice']},\n",
83
+ " {},\n",
84
+ " {},\n",
85
+ " {'Amenity': ['live music']},\n",
86
+ " {'DS': ['gluten allergy', 'gluten - free']},\n",
87
+ " {'Amenity': ['cozy']}]\n"
88
+ ]
89
+ }
90
+ ],
91
+ "source": [
92
+ "ners = parser.get_ner(test_p)\n",
93
+ "pprint(ners)"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": 6,
99
+ "metadata": {},
100
+ "outputs": [
101
+ {
102
+ "name": "stdout",
103
+ "output_type": "stream",
104
+ "text": [
105
+ "[{'Semantic': ['fast -']},\n",
106
+ " {'Semantic': ['lively']},\n",
107
+ " {'Hours': ['dinner'], 'Semantic': ['nice', 'fun']},\n",
108
+ " {'Semantic': ['authentic indian']},\n",
109
+ " {'Semantic': ['romantic', 'nice']},\n",
110
+ " {},\n",
111
+ " {},\n",
112
+ " {'Semantic': ['live music']},\n",
113
+ " {},\n",
114
+ " {'Semantic': ['cozy']}]\n"
115
+ ]
116
+ }
117
+ ],
118
+ "source": [
119
+ "parsed_prompts = parser.parse(ners)\n",
120
+ "pprint(parsed_prompts)"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": 7,
126
+ "metadata": {},
127
+ "outputs": [],
128
+ "source": [
129
+ "from sentence_transformers import SentenceTransformer\n",
130
+ "\n",
131
+ "embedder = SentenceTransformer(\"all-MiniLM-L6-v2\")"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": 8,
137
+ "metadata": {},
138
+ "outputs": [
139
+ {
140
+ "name": "stdout",
141
+ "output_type": "stream",
142
+ "text": [
143
+ " Accessibility \\\n",
144
+ "0 Wheelchair accessible entrance, Wheelchair acc... \n",
145
+ "1 Wheelchair accessible entrance, Wheelchair acc... \n",
146
+ "2 Wheelchair accessible entrance, Wheelchair acc... \n",
147
+ "3 Wheelchair accessible entrance, Wheelchair acc... \n",
148
+ "4 Wheelchair accessible entrance, Wheelchair acc... \n",
149
+ ".. ... \n",
150
+ "274 Wheelchair accessible entrance, Wheelchair acc... \n",
151
+ "275 Wheelchair accessible elevator, Wheelchair acc... \n",
152
+ "276 Wheelchair accessible entrance, Wheelchair acc... \n",
153
+ "277 Wheelchair accessible seating, Wheelchair acce... \n",
154
+ "278 Wheelchair accessible parking lot, Wheelchair ... \n",
155
+ "\n",
156
+ " Amenities & Special featrures \\\n",
157
+ "0 Bar onsite, Good for kids, High chairs, Restro... \n",
158
+ "1 Good for kids, High chairs, Restroom \n",
159
+ "2 Bar onsite, Dogs allowed, Good for kids, High ... \n",
160
+ "3 Restroom \n",
161
+ "4 Restroom \n",
162
+ ".. ... \n",
163
+ "274 Bar onsite, Good for kids, High chairs, Restro... \n",
164
+ "275 Bar onsite, High chairs, Restroom, Wi-Fi \n",
165
+ "276 Bar onsite, Dogs allowed, Good for kids, High ... \n",
166
+ "277 Bar onsite, High chairs, Restroom \n",
167
+ "278 Bar onsite, High chairs, Restroom \n",
168
+ "\n",
169
+ " Atmosphere \\\n",
170
+ "0 Casual, Cozy \n",
171
+ "1 Casual, Cozy \n",
172
+ "2 Casual, Cozy \n",
173
+ "3 Casual, Cozy \n",
174
+ "4 Casual, Cozy, Romantic \n",
175
+ ".. ... \n",
176
+ "274 Casual, Cozy \n",
177
+ "275 Casual, Cozy \n",
178
+ "276 Casual, Cozy, Romantic \n",
179
+ "277 Casual, Cozy \n",
180
+ "278 Casual, Cozy \n",
181
+ "\n",
182
+ " Crowd \\\n",
183
+ "0 Family friendly, Groups, LGBTQ+ friendly, Tran... \n",
184
+ "1 Family friendly, Groups \n",
185
+ "2 Family friendly, Groups \n",
186
+ "3 NaN \n",
187
+ "4 Groups \n",
188
+ ".. ... \n",
189
+ "274 Family friendly, Groups, LGBTQ+ friendly \n",
190
+ "275 Groups \n",
191
+ "276 Groups, LGBTQ+ friendly, Transgender safespace \n",
192
+ "277 Groups \n",
193
+ "278 Groups \n",
194
+ "\n",
195
+ " Dining options \\\n",
196
+ "0 Lunch, Dinner, Dessert, Seating \n",
197
+ "1 Lunch, Dinner, Dessert, Seating \n",
198
+ "2 Lunch, Dinner, Catering, Dessert, Seating \n",
199
+ "3 Breakfast, Lunch, Dessert, Seating \n",
200
+ "4 Dinner, Dessert, Seating \n",
201
+ ".. ... \n",
202
+ "274 Dinner, Counter service, Dessert, Seating \n",
203
+ "275 Breakfast, Brunch, Lunch, Dinner, Dessert, Sea... \n",
204
+ "276 Breakfast, Brunch, Lunch, Dinner, Dessert, Sea... \n",
205
+ "277 Brunch, Lunch, Dinner, Dessert, Seating \n",
206
+ "278 Brunch, Lunch, Dinner, Dessert, Seating \n",
207
+ "\n",
208
+ " Offerings \\\n",
209
+ "0 Alcohol, All you can eat, Beer, Cocktails, Cof... \n",
210
+ "1 Alcohol, All you can eat, Beer, Coffee, Halal,... \n",
211
+ "2 Alcohol, Beer, Coffee, Halal, Healthy options,... \n",
212
+ "3 Coffee \n",
213
+ "4 Alcohol, Beer, Cocktails, Coffee, Hard liquor,... \n",
214
+ ".. ... \n",
215
+ "274 Alcohol, Beer, Coffee, Healthy options, Kids' ... \n",
216
+ "275 Alcohol, Beer, Cocktails, Coffee, Hard liquor,... \n",
217
+ "276 Alcohol, Beer, Cocktails, Coffee, Hard liquor,... \n",
218
+ "277 Alcohol, Beer, Cocktails, Coffee, Hard liquor,... \n",
219
+ "278 Alcohol, Beer, Cocktails, Coffee, Hard liquor,... \n",
220
+ "\n",
221
+ " Payment Planning \\\n",
222
+ "0 Debit cards, NFC mobile payments, Credit cards Accepts reservations \n",
223
+ "1 Cash-only, Debit cards, NFC mobile payments Accepts reservations \n",
224
+ "2 Debit cards, NFC mobile payments, Credit cards Accepts reservations \n",
225
+ "3 Debit cards NaN \n",
226
+ "4 NaN Accepts reservations \n",
227
+ ".. ... ... \n",
228
+ "274 Debit cards, NFC mobile payments, Credit cards NaN \n",
229
+ "275 Debit cards, NFC mobile payments, Credit cards Accepts reservations \n",
230
+ "276 Debit cards, NFC mobile payments, Credit cards Accepts reservations \n",
231
+ "277 Debit cards, NFC mobile payments, Credit cards Accepts reservations \n",
232
+ "278 Debit cards, NFC mobile payments Accepts reservations \n",
233
+ "\n",
234
+ " Service \\\n",
235
+ "0 Outdoor seating, Delivery, Takeout, Dine-in \n",
236
+ "1 Outdoor seating, Dine-in \n",
237
+ "2 Outdoor seating, Curbside pickup, No-contact d... \n",
238
+ "3 Outdoor seating, Takeout, Dine-in \n",
239
+ "4 Outdoor seating, Delivery, Takeout, Dine-in \n",
240
+ ".. ... \n",
241
+ "274 Outdoor seating, Takeout, Dine-in, Delivery \n",
242
+ "275 Outdoor seating, Dine-in, Delivery, Takeout \n",
243
+ "276 Outdoor seating, Curbside pickup, Takeout, Din... \n",
244
+ "277 Outdoor seating, Takeout, Dine-in, Delivery \n",
245
+ "278 Outdoor seating, Dine-in, Delivery, Takeout \n",
246
+ "\n",
247
+ " categories \\\n",
248
+ "0 NaN \n",
249
+ "1 Korean barbecue restaurant \n",
250
+ "2 Indian restaurant, Asian restaurant, Health fo... \n",
251
+ "3 Cafe, Breakfast restaurant, Brunch restaurant,... \n",
252
+ "4 Thai restaurant \n",
253
+ ".. ... \n",
254
+ "274 NaN \n",
255
+ "275 , Cocktail bar, Coffee shop, Coworking space, ... \n",
256
+ "276 NaN \n",
257
+ "277 , Cafe \n",
258
+ "278 , Culinary school, Event venue, Function room ... \n",
259
+ "\n",
260
+ " Category \\\n",
261
+ "0 NaN \n",
262
+ "1 Korean barbecue restaurant \n",
263
+ "2 Indian restaurant \n",
264
+ "3 Cafe \n",
265
+ "4 Thai restaurant \n",
266
+ ".. ... \n",
267
+ "274 NaN \n",
268
+ "275 NaN \n",
269
+ "276 NaN \n",
270
+ "277 NaN \n",
271
+ "278 NaN \n",
272
+ "\n",
273
+ " description \n",
274
+ "0 NaN \n",
275
+ "1 NaN \n",
276
+ "2 Light-filled restaurant with colorful seating ... \n",
277
+ "3 Seattle-based coffeehouse chain known for its ... \n",
278
+ "4 NaN \n",
279
+ ".. ... \n",
280
+ "274 NaN \n",
281
+ "275 NaN \n",
282
+ "276 NaN \n",
283
+ "277 NaN \n",
284
+ "278 NaN \n",
285
+ "\n",
286
+ "[279 rows x 12 columns]\n"
287
+ ]
288
+ }
289
+ ],
290
+ "source": [
291
+ "from matcher import Matcher\n",
292
+ "\n",
293
+ "matcher = Matcher.from_path(\"/home/kave/work/Elise/elise/data/final_data.csv\", embedder)"
294
+ ]
295
+ },
296
+ {
297
+ "cell_type": "code",
298
+ "execution_count": 9,
299
+ "metadata": {},
300
+ "outputs": [
301
+ {
302
+ "data": {
303
+ "text/plain": [
304
+ "tensor([[ 0.0409, -0.0148, 0.0419, ..., 0.0347, -0.0333, 0.0381],\n",
305
+ " [ 0.0090, 0.0329, 0.0421, ..., 0.0387, -0.0885, 0.0523],\n",
306
+ " [ 0.0155, -0.0316, 0.0217, ..., 0.0344, -0.0364, 0.0223],\n",
307
+ " ...,\n",
308
+ " [ 0.0363, -0.0182, 0.0596, ..., 0.0276, -0.0117, 0.0335],\n",
309
+ " [ 0.0415, -0.0292, 0.0533, ..., 0.0515, -0.0225, 0.0368],\n",
310
+ " [ 0.0540, -0.0334, 0.0415, ..., 0.0756, -0.0274, 0.0454]],\n",
311
+ " device='cuda:0')"
312
+ ]
313
+ },
314
+ "execution_count": 9,
315
+ "metadata": {},
316
+ "output_type": "execute_result"
317
+ }
318
+ ],
319
+ "source": [
320
+ "matcher.semantics"
321
+ ]
322
+ },
323
+ {
324
+ "cell_type": "code",
325
+ "execution_count": 10,
326
+ "metadata": {},
327
+ "outputs": [
328
+ {
329
+ "name": "stdout",
330
+ "output_type": "stream",
331
+ "text": [
332
+ "CPU times: user 331 ms, sys: 0 ns, total: 331 ms\n",
333
+ "Wall time: 91.5 ms\n"
334
+ ]
335
+ }
336
+ ],
337
+ "source": [
338
+ "%%time\n",
339
+ "ners = parser.get_ner(test_p)\n",
340
+ "parsed_prompts = parser.parse(ners)\n",
341
+ "kk = matcher.handle_jobs(parsed_prompts)"
342
+ ]
343
+ },
344
+ {
345
+ "cell_type": "code",
346
+ "execution_count": 12,
347
+ "metadata": {},
348
+ "outputs": [],
349
+ "source": [
350
+ "from utils import df_to_json"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": 13,
356
+ "metadata": {},
357
+ "outputs": [
358
+ {
359
+ "data": {
360
+ "text/plain": [
361
+ "{\"I don't feel like cooking tonight. Where's a good place to get fast-food?\": [{'Name': 'Cafetaria Edison',\n",
362
+ " 'Score': 20.519733428955078},\n",
363
+ " {'Name': 'Five Guys', 'Score': 20.5447940826416},\n",
364
+ " {'Name': 'Five Guys', 'Score': 20.5447940826416}],\n",
365
+ " \"I'm planning a dinner with some friends. Any recommendations for a restaurant with a lively atmosphere?\": [{'Name': 'De Garde',\n",
366
+ " 'Score': 25.97490882873535},\n",
367
+ " {'Name': 'Hemel & Aarde', 'Score': 26.467164993286133},\n",
368
+ " {'Name': 'The Thai Orchid', 'Score': 36.23405838012695}],\n",
369
+ " \"I want to celebrate my graduation with a nice dinner out. What's a good restaurant with good food and a fun ambiance?\": [{'Name': 'Luc Utrecht',\n",
370
+ " 'Score': 20.497264862060547},\n",
371
+ " {'Name': 'Spice Monkey', 'Score': 20.64722442626953},\n",
372
+ " {'Name': 'Ethiopian Sunshine', 'Score': 20.84449005126953}],\n",
373
+ " \"I want to try some new cuisines I've never had before. Can you recommend a restaurant with authentic Indian food?\": [{'Name': 'India Port',\n",
374
+ " 'Score': 35.02317428588867},\n",
375
+ " {'Name': 'Kashmir Kitchen Utrecht', 'Score': 35.351844787597656},\n",
376
+ " {'Name': 'Surya Utrecht | Indiaas & Nepalees restaurant & bar',\n",
377
+ " 'Score': 40.53645706176758}],\n",
378
+ " \"I'm planning a special date night and want to go somewhere romantic. What's a good restaurant with a nice view?\": [{'Name': 'Sevilla',\n",
379
+ " 'Score': 20.176685333251953},\n",
380
+ " {'Name': 'Pand 33 Utrecht', 'Score': 20.325525283813477},\n",
381
+ " {'Name': 'Hemel & Aarde', 'Score': 23.112428665161133}],\n",
382
+ " \"I'm meeting a client for lunch. Can you recommend a good restaurant for a business meeting?\": [{'Name': 'Sevilla',\n",
383
+ " 'Score': 20.176685333251953},\n",
384
+ " {'Name': 'Pand 33 Utrecht', 'Score': 20.325525283813477},\n",
385
+ " {'Name': 'Hemel & Aarde', 'Score': 23.112428665161133}],\n",
386
+ " \"I'm traveling through this city and need to find a good place to eat. Any recommendations near the airport?\": [{'Name': 'Sevilla',\n",
387
+ " 'Score': 20.176685333251953},\n",
388
+ " {'Name': 'Pand 33 Utrecht', 'Score': 20.325525283813477},\n",
389
+ " {'Name': 'Hemel & Aarde', 'Score': 23.112428665161133}],\n",
390
+ " \"I'm looking for a restaurant with live music or other entertainment. Any suggestions?\": [{'Name': 'Silk Road Utrecht',\n",
391
+ " 'Score': 19.077302932739258},\n",
392
+ " {'Name': 'Hemel & Aarde', 'Score': 20.864463806152344},\n",
393
+ " {'Name': 'The Thai Orchid', 'Score': 26.012712478637695}],\n",
394
+ " 'I have a gluten allergy and need to find a restaurant with gluten-free options. Do you know any good ones in this area?': [{'Name': 'Silk Road Utrecht',\n",
395
+ " 'Score': 19.077302932739258},\n",
396
+ " {'Name': 'Hemel & Aarde', 'Score': 20.864463806152344},\n",
397
+ " {'Name': 'The Thai Orchid', 'Score': 26.012712478637695}],\n",
398
+ " \"I just want to relax and have a nice meal out. What's a good restaurant with a cozy atmosphere?\": [{'Name': 'De Garde',\n",
399
+ " 'Score': 34.23313903808594},\n",
400
+ " {'Name': 'Asia Street Cooking', 'Score': 34.71759796142578},\n",
401
+ " {'Name': 'Hemel & Aarde', 'Score': 37.8213005065918}]}"
402
+ ]
403
+ },
404
+ "execution_count": 13,
405
+ "metadata": {},
406
+ "output_type": "execute_result"
407
+ }
408
+ ],
409
+ "source": [
410
+ "df_to_json(kk, test_p)"
411
+ ]
412
+ },
413
+ {
414
+ "cell_type": "code",
415
+ "execution_count": null,
416
+ "metadata": {},
417
+ "outputs": [],
418
+ "source": []
419
+ }
420
+ ],
421
+ "metadata": {
422
+ "kernelspec": {
423
+ "display_name": "afterhours_dev",
424
+ "language": "python",
425
+ "name": "python3"
426
+ },
427
+ "language_info": {
428
+ "codemirror_mode": {
429
+ "name": "ipython",
430
+ "version": 3
431
+ },
432
+ "file_extension": ".py",
433
+ "mimetype": "text/x-python",
434
+ "name": "python",
435
+ "nbconvert_exporter": "python",
436
+ "pygments_lexer": "ipython3",
437
+ "version": "3.10.0"
438
+ },
439
+ "orig_nbformat": 4
440
+ },
441
+ "nbformat": 4,
442
+ "nbformat_minor": 2
443
+ }
elise/src/notebooks/playground_prompt/version1.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ prompts = {
2
+ "instruction": """Identify user needs from the prompts.""",
3
+ "examples": """
4
+ Prompt: I want to celebrate my graduation with my friend. Recommend me somewhere nice with live music near the Harvard campus.
5
+ Identification:
6
+ occasion: celebrating graduation
7
+ features: somewhere nice, live music
8
+ location: near Harvard campus
9
+ """,
10
+ "test_prompts": [
11
+ "I don't feel like cooking tonight. Where's a good place to get fast-food?",
12
+ "I'm planning a dinner with some friends. Any recommendations for a restaurant with a lively atmosphere?",
13
+ "I want to celebrate my graduation with a nice dinner out. What's a good restaurant with good food and a fun ambiance?",
14
+ "I want to try some new cuisines I've never had before. Can you recommend a restaurant with authentic Indian food?",
15
+ "I'm planning a special date night and want to go somewhere romantic. What's a good restaurant with a nice view?",
16
+ "I'm meeting a client for lunch. Can you recommend a good restaurant for a business meeting?",
17
+ "I'm traveling through this city and need to find a good place to eat. Any recommendations near the airport?",
18
+ "I'm looking for a restaurant with live music or other entertainment. Any suggestions?",
19
+ "I have a gluten allergy and need to find a restaurant with gluten-free options. Do you know any good ones in this area?",
20
+ "I just want to relax and have a nice meal out. What's a good restaurant with a cozy atmosphere?",
21
+ ],
22
+ }
elise/src/notebooks/t5 funetinung.ipynb ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "9510dd98",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
11
+ "from datasets import load_dataset\n",
12
+ "from transformers import get_scheduler\n",
13
+ "import torch\n",
14
+ "from torch.utils.data import DataLoader\n",
15
+ "from datasets import load_dataset\n",
16
+ "from transformers import AutoTokenizer, AutoModelForTokenClassification\n",
17
+ "from transformers import DataCollatorForSeq2Seq\n",
18
+ "from accelerate import Accelerator\n",
19
+ "import evaluate\n",
20
+ "import datasets\n",
21
+ "\n",
22
+ "from tqdm.auto import tqdm"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": 2,
28
+ "id": "f1da6c6c",
29
+ "metadata": {
30
+ "scrolled": true
31
+ },
32
+ "outputs": [
33
+ {
34
+ "name": "stderr",
35
+ "output_type": "stream",
36
+ "text": [
37
+ "/home/kave/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/transformers/models/t5/tokenization_t5_fast.py:155: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
38
+ "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
39
+ "- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n",
40
+ "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
41
+ "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n",
42
+ " warnings.warn(\n"
43
+ ]
44
+ }
45
+ ],
46
+ "source": [
47
+ "tokenizer = AutoTokenizer.from_pretrained(\"t5-base\")\n",
48
+ "model = AutoModelForSeq2SeqLM.from_pretrained(\"t5-base\")"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 3,
54
+ "id": "a6de1719",
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "# prep dataset\n",
59
+ "dataset = load_dataset(\"tner/mit_restaurant\")"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": 4,
65
+ "id": "8617d7d6",
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "ner_tags = {\n",
70
+ " \"O\": 0,\n",
71
+ " \"B-Rating\": 1,\n",
72
+ " \"I-Rating\": 2,\n",
73
+ " \"B-Amenity\": 3,\n",
74
+ " \"I-Amenity\": 4,\n",
75
+ " \"B-Location\": 5,\n",
76
+ " \"I-Location\": 6,\n",
77
+ " \"B-Restaurant_Name\": 7,\n",
78
+ " \"I-Restaurant_Name\": 8,\n",
79
+ " \"B-Price\": 9,\n",
80
+ " \"B-Hours\": 10,\n",
81
+ " \"I-Hours\": 11,\n",
82
+ " \"B-Dish\": 12,\n",
83
+ " \"I-Dish\": 13,\n",
84
+ " \"B-Cuisine\": 14,\n",
85
+ " \"I-Price\": 15,\n",
86
+ " \"I-Cuisine\": 16,\n",
87
+ "}\n",
88
+ "\n",
89
+ "\n",
90
+ "label_names = {v: k for k, v in ner_tags.items()}"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": 5,
96
+ "id": "de52b597",
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": [
100
+ "def decode_tags(tags, words):\n",
101
+ " dict_out = {}\n",
102
+ " word_ = \"\"\n",
103
+ " for tag, word in zip(tags[::-1], words[::-1]):\n",
104
+ " if tag == 0:\n",
105
+ " continue\n",
106
+ " word_ = word_ + \" \" + word\n",
107
+ " if label_names[tag].startswith(\"B\"):\n",
108
+ " tag_name = label_names[tag][2:]\n",
109
+ " word_ = word_.strip()\n",
110
+ " if tag_name not in dict_out:\n",
111
+ " dict_out[tag_name] = [word_]\n",
112
+ " else:\n",
113
+ " dict_out[tag_name].append(word_)\n",
114
+ " word_ = \"\"\n",
115
+ " return dict_out\n",
116
+ "\n",
117
+ "\n",
118
+ "def format_to_text(decoded):\n",
119
+ " text = \"\"\n",
120
+ " for key, value in decoded.items():\n",
121
+ " text += f\"{key}: {', '.join(value)}\\n\"\n",
122
+ " return text"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": 6,
128
+ "id": "5da715a8",
129
+ "metadata": {},
130
+ "outputs": [],
131
+ "source": [
132
+ "def generate_t5_data(example):\n",
133
+ " decoded = decode_tags(example[\"tags\"], example[\"tokens\"])\n",
134
+ " return {\"tokens\": \" \".join(example[\"tokens\"]), \"labels\": format_to_text(decoded)}"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": 7,
140
+ "id": "57416e20",
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
145
+ "import torch\n",
146
+ "\n",
147
+ "# the following 2 hyperparameters are task-specific\n",
148
+ "max_source_length = 512\n",
149
+ "max_target_length = 128\n",
150
+ "\n",
151
+ "# encode the inputs\n",
152
+ "task_prefix = \"What is the user intent?\"\n",
153
+ "\n",
154
+ "\n",
155
+ "def tokenize(example):\n",
156
+ " tokenized = tokenizer(\n",
157
+ " task_prefix + example[\"tokens\"],\n",
158
+ " text_target=example[\"labels\"],\n",
159
+ " max_length=512,\n",
160
+ " truncation=True,\n",
161
+ " )\n",
162
+ " return tokenized"
163
+ ]
164
+ },
165
+ {
166
+ "cell_type": "code",
167
+ "execution_count": 8,
168
+ "id": "137905d7",
169
+ "metadata": {
170
+ "scrolled": true
171
+ },
172
+ "outputs": [
173
+ {
174
+ "data": {
175
+ "application/vnd.jupyter.widget-view+json": {
176
+ "model_id": "23bafa0f97bc4d4da8a96397f0f3bd5a",
177
+ "version_major": 2,
178
+ "version_minor": 0
179
+ },
180
+ "text/plain": [
181
+ "Map: 0%| | 0/6900 [00:00<?, ? examples/s]"
182
+ ]
183
+ },
184
+ "metadata": {},
185
+ "output_type": "display_data"
186
+ }
187
+ ],
188
+ "source": [
189
+ "tokenized_datasets = dataset.map(generate_t5_data)\n",
190
+ "tokenized_datasets = tokenized_datasets.remove_columns([\"tags\"])\n",
191
+ "tokenized_datasets = tokenized_datasets.map(tokenize)"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "execution_count": 9,
197
+ "id": "e2bdf1b0",
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "import evaluate\n",
202
+ "\n",
203
+ "metric = evaluate.load(\"sacrebleu\")"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "execution_count": 10,
209
+ "id": "cd9871bf",
210
+ "metadata": {},
211
+ "outputs": [],
212
+ "source": [
213
+ "import numpy as np\n",
214
+ "\n",
215
+ "\n",
216
+ "def compute_metrics(eval_preds):\n",
217
+ " preds, labels = eval_preds\n",
218
+ " # In case the model returns more than the prediction logits\n",
219
+ " if isinstance(preds, tuple):\n",
220
+ " preds = preds[0]\n",
221
+ "\n",
222
+ " decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
223
+ "\n",
224
+ " # Replace -100s in the labels as we can't decode them\n",
225
+ " labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
226
+ " decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
227
+ "\n",
228
+ " # Some simple post-processing\n",
229
+ " decoded_preds = [pred.strip() for pred in decoded_preds]\n",
230
+ " decoded_labels = [[label.strip()] for label in decoded_labels]\n",
231
+ "\n",
232
+ " result = metric.compute(predictions=decoded_preds, references=decoded_labels)\n",
233
+ " return {\"bleu\": result[\"score\"]}"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": 11,
239
+ "id": "09afe1d0",
240
+ "metadata": {},
241
+ "outputs": [],
242
+ "source": [
243
+ "data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": 12,
249
+ "id": "58e84fd1",
250
+ "metadata": {},
251
+ "outputs": [],
252
+ "source": [
253
+ "from transformers import Seq2SeqTrainingArguments\n",
254
+ "\n",
255
+ "args = Seq2SeqTrainingArguments(\n",
256
+ " f\"T5 test\",\n",
257
+ " evaluation_strategy=\"no\",\n",
258
+ " save_strategy=\"epoch\",\n",
259
+ " learning_rate=3e-4,\n",
260
+ " per_device_train_batch_size=64,\n",
261
+ " per_device_eval_batch_size=32,\n",
262
+ " weight_decay=0.01,\n",
263
+ " save_total_limit=3,\n",
264
+ " num_train_epochs=20,\n",
265
+ " predict_with_generate=True,\n",
266
+ " fp16=True,\n",
267
+ ")"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": 13,
273
+ "id": "edfcbac1",
274
+ "metadata": {},
275
+ "outputs": [],
276
+ "source": [
277
+ "from transformers import Seq2SeqTrainer\n",
278
+ "\n",
279
+ "trainer = Seq2SeqTrainer(\n",
280
+ " model,\n",
281
+ " args,\n",
282
+ " train_dataset=tokenized_datasets[\"train\"],\n",
283
+ " eval_dataset=tokenized_datasets[\"validation\"],\n",
284
+ " data_collator=data_collator,\n",
285
+ " tokenizer=tokenizer,\n",
286
+ " compute_metrics=compute_metrics,\n",
287
+ ")"
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "execution_count": 14,
293
+ "id": "e0065364",
294
+ "metadata": {},
295
+ "outputs": [
296
+ {
297
+ "name": "stderr",
298
+ "output_type": "stream",
299
+ "text": [
300
+ "You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
301
+ ]
302
+ },
303
+ {
304
+ "data": {
305
+ "text/html": [
306
+ "\n",
307
+ " <div>\n",
308
+ " \n",
309
+ " <progress value='12' max='12' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
310
+ " [12/12 00:15]\n",
311
+ " </div>\n",
312
+ " "
313
+ ],
314
+ "text/plain": [
315
+ "<IPython.core.display.HTML object>"
316
+ ]
317
+ },
318
+ "metadata": {},
319
+ "output_type": "display_data"
320
+ },
321
+ {
322
+ "name": "stderr",
323
+ "output_type": "stream",
324
+ "text": [
325
+ "Trainer is attempting to log a value of \"{'summarization': {'early_stopping': True, 'length_penalty': 2.0, 'max_length': 200, 'min_length': 30, 'no_repeat_ngram_size': 3, 'num_beams': 4, 'prefix': 'summarize: '}, 'translation_en_to_de': {'early_stopping': True, 'max_length': 300, 'num_beams': 4, 'prefix': 'translate English to German: '}, 'translation_en_to_fr': {'early_stopping': True, 'max_length': 300, 'num_beams': 4, 'prefix': 'translate English to French: '}, 'translation_en_to_ro': {'early_stopping': True, 'max_length': 300, 'num_beams': 4, 'prefix': 'translate English to Romanian: '}}\" for key \"task_specific_params\" as a parameter. MLflow's log_param() only accepts values no longer than 250 characters so we dropped this attribute. You can use `MLFLOW_FLATTEN_PARAMS` environment variable to flatten the parameters and avoid this message.\n"
326
+ ]
327
+ },
328
+ {
329
+ "data": {
330
+ "text/plain": [
331
+ "{'eval_loss': 6.675447940826416,\n",
332
+ " 'eval_bleu': 0.006728795795564811,\n",
333
+ " 'eval_runtime': 17.5858,\n",
334
+ " 'eval_samples_per_second': 43.217,\n",
335
+ " 'eval_steps_per_second': 0.682}"
336
+ ]
337
+ },
338
+ "execution_count": 14,
339
+ "metadata": {},
340
+ "output_type": "execute_result"
341
+ }
342
+ ],
343
+ "source": [
344
+ "trainer.evaluate(max_length=512)"
345
+ ]
346
+ },
347
+ {
348
+ "cell_type": "code",
349
+ "execution_count": 15,
350
+ "id": "64ad307b",
351
+ "metadata": {},
352
+ "outputs": [
353
+ {
354
+ "name": "stderr",
355
+ "output_type": "stream",
356
+ "text": [
357
+ "/home/kave/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
358
+ " warnings.warn(\n"
359
+ ]
360
+ },
361
+ {
362
+ "data": {
363
+ "text/html": [
364
+ "\n",
365
+ " <div>\n",
366
+ " \n",
367
+ " <progress value='4' max='1080' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
368
+ " [ 4/1080 00:01 < 09:22, 1.91 it/s, Epoch 0.06/20]\n",
369
+ " </div>\n",
370
+ " <table border=\"1\" class=\"dataframe\">\n",
371
+ " <thead>\n",
372
+ " <tr style=\"text-align: left;\">\n",
373
+ " <th>Step</th>\n",
374
+ " <th>Training Loss</th>\n",
375
+ " </tr>\n",
376
+ " </thead>\n",
377
+ " <tbody>\n",
378
+ " </tbody>\n",
379
+ "</table><p>"
380
+ ],
381
+ "text/plain": [
382
+ "<IPython.core.display.HTML object>"
383
+ ]
384
+ },
385
+ "metadata": {},
386
+ "output_type": "display_data"
387
+ },
388
+ {
389
+ "ename": "OutOfMemoryError",
390
+ "evalue": "CUDA out of memory. Tried to allocate 456.00 MiB (GPU 0; 11.75 GiB total capacity; 10.26 GiB already allocated; 131.12 MiB free; 10.83 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
391
+ "output_type": "error",
392
+ "traceback": [
393
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
394
+ "\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)",
395
+ "Cell \u001b[0;32mIn[15], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
396
+ "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/transformers/trainer.py:1539\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1534\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_wrapped \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\n\u001b[1;32m 1536\u001b[0m inner_training_loop \u001b[38;5;241m=\u001b[39m find_executable_batch_size(\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_inner_training_loop, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_train_batch_size, args\u001b[38;5;241m.\u001b[39mauto_find_batch_size\n\u001b[1;32m 1538\u001b[0m )\n\u001b[0;32m-> 1539\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1540\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1541\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1542\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1544\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
397
+ "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/transformers/trainer.py:1809\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 1806\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m 1808\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 1809\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1811\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 1812\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 1813\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_tpu_available()\n\u001b[1;32m 1814\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 1815\u001b[0m ):\n\u001b[1;32m 1816\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 1817\u001b[0m tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
398
+ "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/transformers/trainer.py:2654\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 2651\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss_mb\u001b[38;5;241m.\u001b[39mreduce_mean()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m 2653\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_loss_context_manager():\n\u001b[0;32m-> 2654\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2656\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mn_gpu \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 2657\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mmean() \u001b[38;5;66;03m# mean() to average on multi-gpu parallel training\u001b[39;00m\n",
399
+ "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/transformers/trainer.py:2679\u001b[0m, in \u001b[0;36mTrainer.compute_loss\u001b[0;34m(self, model, inputs, return_outputs)\u001b[0m\n\u001b[1;32m 2677\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2678\u001b[0m labels \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 2679\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2680\u001b[0m \u001b[38;5;66;03m# Save past state if it exists\u001b[39;00m\n\u001b[1;32m 2681\u001b[0m \u001b[38;5;66;03m# TODO: this needs to be fixed and made cleaner later.\u001b[39;00m\n\u001b[1;32m 2682\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mpast_index \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
400
+ "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
401
+ "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/accelerate/utils/operations.py:581\u001b[0m, in \u001b[0;36mconvert_outputs_to_fp32.<locals>.forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 580\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 581\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
402
+ "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/accelerate/utils/operations.py:569\u001b[0m, in \u001b[0;36mConvertOutputsToFp32.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 568\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 569\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m convert_to_fp32(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n",
403
+ "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/torch/amp/autocast_mode.py:14\u001b[0m, in \u001b[0;36mautocast_decorator.<locals>.decorate_autocast\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_autocast\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m autocast_instance:\n\u001b[0;32m---> 14\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
404
+ "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/accelerate/utils/operations.py:581\u001b[0m, in \u001b[0;36mconvert_outputs_to_fp32.<locals>.forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 580\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 581\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
405
+ "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/accelerate/utils/operations.py:569\u001b[0m, in \u001b[0;36mConvertOutputsToFp32.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 568\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 569\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mconvert_to_fp32\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
406
+ "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/accelerate/utils/operations.py:548\u001b[0m, in \u001b[0;36mconvert_to_fp32\u001b[0;34m(tensor)\u001b[0m\n\u001b[1;32m 545\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_is_fp16_bf16_tensor\u001b[39m(tensor):\n\u001b[1;32m 546\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(tensor, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdtype\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m tensor\u001b[38;5;241m.\u001b[39mdtype \u001b[38;5;129;01min\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39mfloat16, torch\u001b[38;5;241m.\u001b[39mbfloat16)\n\u001b[0;32m--> 548\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mrecursively_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_convert_to_fp32\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtensor\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_is_fp16_bf16_tensor\u001b[49m\u001b[43m)\u001b[49m\n",
407
+ "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/accelerate/utils/operations.py:120\u001b[0m, in \u001b[0;36mrecursively_apply\u001b[0;34m(func, data, test_type, error_on_other_type, *args, **kwargs)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m honor_type(\n\u001b[1;32m 110\u001b[0m data,\n\u001b[1;32m 111\u001b[0m (\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 116\u001b[0m ),\n\u001b[1;32m 117\u001b[0m )\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data, Mapping):\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(data)(\n\u001b[0;32m--> 120\u001b[0m {\n\u001b[1;32m 121\u001b[0m k: recursively_apply(\n\u001b[1;32m 122\u001b[0m func, v, \u001b[38;5;241m*\u001b[39margs, test_type\u001b[38;5;241m=\u001b[39mtest_type, error_on_other_type\u001b[38;5;241m=\u001b[39merror_on_other_type, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m 123\u001b[0m )\n\u001b[1;32m 124\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m data\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 125\u001b[0m }\n\u001b[1;32m 126\u001b[0m )\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m test_type(data):\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m func(data, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
408
+ "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/accelerate/utils/operations.py:121\u001b[0m, in \u001b[0;36m<dictcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m honor_type(\n\u001b[1;32m 110\u001b[0m data,\n\u001b[1;32m 111\u001b[0m (\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 116\u001b[0m ),\n\u001b[1;32m 117\u001b[0m )\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data, Mapping):\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(data)(\n\u001b[1;32m 120\u001b[0m {\n\u001b[0;32m--> 121\u001b[0m k: \u001b[43mrecursively_apply\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 122\u001b[0m \u001b[43m \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtest_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merror_on_other_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merror_on_other_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 123\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 124\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m data\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 125\u001b[0m }\n\u001b[1;32m 126\u001b[0m )\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m test_type(data):\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m func(data, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
409
+ "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/accelerate/utils/operations.py:128\u001b[0m, in \u001b[0;36mrecursively_apply\u001b[0;34m(func, data, test_type, error_on_other_type, *args, **kwargs)\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(data)(\n\u001b[1;32m 120\u001b[0m {\n\u001b[1;32m 121\u001b[0m k: recursively_apply(\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 125\u001b[0m }\n\u001b[1;32m 126\u001b[0m )\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m test_type(data):\n\u001b[0;32m--> 128\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 129\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m error_on_other_type:\n\u001b[1;32m 130\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[1;32m 131\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnsupported types (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(data)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) passed to `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m`. 
Only nested list/tuple/dicts of \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 132\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mobjects that are valid for `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtest_type\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m` should be passed.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 133\u001b[0m )\n",
410
+ "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/accelerate/utils/operations.py:543\u001b[0m, in \u001b[0;36mconvert_to_fp32.<locals>._convert_to_fp32\u001b[0;34m(tensor)\u001b[0m\n\u001b[1;32m 542\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_convert_to_fp32\u001b[39m(tensor):\n\u001b[0;32m--> 543\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtensor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
411
+ "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 456.00 MiB (GPU 0; 11.75 GiB total capacity; 10.26 GiB already allocated; 131.12 MiB free; 10.83 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
412
+ ]
413
+ }
414
+ ],
415
+ "source": [
416
+ "trainer.train()"
417
+ ]
418
+ },
419
+ {
420
+ "cell_type": "code",
421
+ "execution_count": null,
422
+ "id": "6672ff53",
423
+ "metadata": {},
424
+ "outputs": [],
425
+ "source": []
426
+ }
427
+ ],
428
+ "metadata": {
429
+ "kernelspec": {
430
+ "display_name": "Python 3 (ipykernel)",
431
+ "language": "python",
432
+ "name": "python3"
433
+ },
434
+ "language_info": {
435
+ "codemirror_mode": {
436
+ "name": "ipython",
437
+ "version": 3
438
+ },
439
+ "file_extension": ".py",
440
+ "mimetype": "text/x-python",
441
+ "name": "python",
442
+ "nbconvert_exporter": "python",
443
+ "pygments_lexer": "ipython3",
444
+ "version": "3.10.0"
445
+ }
446
+ },
447
+ "nbformat": 4,
448
+ "nbformat_minor": 5
449
+ }
elise/src/train_t5_seq2seq.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fine-tune the Flan-T5 model on the tner/mit_restaurant dataset as a seq2seq task
3
+ """
4
+ from dataclasses import asdict
5
+
6
+ import torch
7
+ import evaluate
8
+ import datasets
9
+ from torch.utils.data import DataLoader
10
+ from transformers import (
11
+ AutoTokenizer,
12
+ AutoModelForSeq2SeqLM,
13
+ DataCollatorForSeq2Seq,
14
+ get_scheduler,
15
+ )
16
+ from accelerate import Accelerator
17
+ import numpy as np
18
+ import mlflow
19
+
20
+ from tqdm.auto import tqdm
21
+
22
+ from utils.logger import get_logger
23
+ from configs import T5TrainingConfig
24
+ from data import MITRestaurants, get_default_transforms
25
+
26
# Logger for this training script.
log = get_logger("Flan_T5")

# Single source of truth for the checkpoint so the tokenizer and the model
# are always loaded from the same place.
MODEL_CHECKPOINT = "google/flan-t5-base"

# Build the dataset: load tner/mit_restaurant, attach the project's default
# transforms, then call hf_training() (presumably converts to a Hugging Face
# dataset dict -- confirm against data.MITRestaurants).
transforms = get_default_transforms()
dataset = (
    MITRestaurants.from_hf("tner/mit_restaurant")
    .set_transforms(transforms)
    .hf_training()
)
# Fold the "test" split into "train"; evaluation later in this script reads
# the "validation" split.
dataset["train"] = datasets.concatenate_datasets([dataset["train"], dataset["test"]])
print(dataset)

tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)
41
+
42
+
43
def tokenize(example):
    """Encode a batch of examples for the seq2seq task.

    The ``tokens`` column is passed as the model input and the ``labels``
    column as ``text_target``; the tokenizer is called with truncation
    enabled and ``max_length=512``. Returns the tokenizer's output unchanged.
    """
    return tokenizer(
        example["tokens"],
        text_target=example["labels"],
        truncation=True,
        max_length=512,
    )
53
+
54
+
55
# Tokenize the dataset in batches and drop the raw columns so that only the
# tokenizer's output is kept.
tokenized_datasets = dataset.map(
    tokenize,
    remove_columns=dataset["train"].column_names,
    batched=True,
)

# sacreBLEU metric, fed batch by batch during evaluation.
metric = evaluate.load("sacrebleu")
63
+
64
+
65
def postprocess(predictions, labels):
    """Turn generated ids and label ids into text ready for the metric.

    Both arguments are moved to CPU and converted to NumPy arrays before
    decoding. Returns ``(decoded_preds, decoded_labels)``: a list of stripped
    prediction strings, and a list of one-element lists each holding a
    stripped reference string.
    """
    pred_ids = predictions.cpu().numpy()
    label_ids = labels.cpu().numpy()

    # -100 entries in the labels cannot be decoded; substitute the pad id.
    label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)

    pred_texts = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    return (
        [text.strip() for text in pred_texts],
        [[text.strip()] for text in label_texts],
    )
80
+
81
+
82
config = T5TrainingConfig()

# Collator built from the tokenizer and model; shared by both data loaders.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Data loaders: emit torch tensors; only the training split is shuffled.
tokenized_datasets.set_format("torch")
train_dataloader = DataLoader(
    tokenized_datasets["train"],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=config.train_batch_size,
)
eval_dataloader = DataLoader(
    tokenized_datasets["validation"],
    collate_fn=data_collator,
    batch_size=config.eval_batch_size,
)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

# Accelerator
accelerator = Accelerator(
    mixed_precision=config.mixed_precision,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
)
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

# Count steps on the *prepared* loader. The training loop steps the scheduler
# and the progress bar once per batch it actually iterates, and the prepared
# loader can be shorter than the original when Accelerate shards batches
# across processes. Counting before prepare() would leave the linear schedule
# and the progress bar unfinished in that case.
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = config.epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=config.num_warmup_steps,
    num_training_steps=num_training_steps,
)

progress_bar = tqdm(range(num_training_steps))
123
+
124
+
125
def train():
    """Fine-tune Flan-T5 and evaluate it after every epoch.

    For each of ``config.epochs`` epochs this function:

    1. runs one pass over ``train_dataloader`` with gradient accumulation
       handled by ``accelerator.accumulate``;
    2. generates predictions for ``eval_dataloader`` and scores them with the
       module-level ``metric``;
    3. logs the epoch and score to MLflow and prints them;
    4. saves the model and tokenizer to ``config.output_dir`` and logs the
       model to MLflow under the registered name ``FlanT5_MIT``.

    Relies on the module-level ``model``, ``optimizer``, ``lr_scheduler``,
    ``accelerator``, ``tokenizer``, ``metric`` and ``progress_bar`` objects.
    """
    print("Starting Training")
    for epoch in range(config.epochs):
        # Training
        model.train()
        for batch in train_dataloader:
            # Optimizer, scheduler and zero_grad are kept inside the
            # accumulate context, following Accelerate's documented
            # gradient-accumulation pattern.
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                accelerator.backward(loss)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            progress_bar.update(1)

        # Evaluation
        model.eval()
        for batch in tqdm(eval_dataloader):
            with torch.no_grad():
                generated_tokens = accelerator.unwrap_model(model).generate(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    max_length=128,
                )
            labels = batch["labels"]

            # Predictions and labels must be padded to a common length
            # before they can be gathered across processes.
            generated_tokens = accelerator.pad_across_processes(
                generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
            )
            labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)

            predictions_gathered = accelerator.gather(generated_tokens)
            labels_gathered = accelerator.gather(labels)

            decoded_preds, decoded_labels = postprocess(
                predictions_gathered, labels_gathered
            )
            metric.add_batch(predictions=decoded_preds, references=decoded_labels)

        results = metric.compute()
        # step=epoch indexes each logged value by its epoch in MLflow.
        mlflow.log_metrics(
            {"epoch": epoch, "BLEU score": results["score"]}, step=epoch
        )
        print(f"epoch {epoch}, BLEU score: {results['score']:.2f}")

        # Save and upload
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            config.output_dir, save_function=accelerator.save
        )
        if accelerator.is_main_process:
            tokenizer.save_pretrained(config.output_dir)
            # Log the model to MLflow from the main process only.
            # NOTE(review): this registers a "FlanT5_MIT" model on every
            # epoch -- confirm that is intended.
            mlflow.transformers.log_model(
                transformers_model={"model": unwrapped_model, "tokenizer": tokenizer},
                task="text2text-generation",
                artifact_path="seq2seq_model",
                registered_model_name="FlanT5_MIT",
            )
187
+
188
+
189
# Only start a run when executed as a script (as the Makefile and MLproject
# entry points do), so importing this module does not kick off training.
if __name__ == "__main__":
    # NOTE(review): the tracking server address is hard-coded to a local
    # MLflow instance.
    mlflow.set_tracking_uri("http://127.0.0.1:5000")
    with mlflow.start_run():
        # Record the training configuration as run parameters.
        mlflow.log_params(asdict(config))
        train()
elise/src/utils/logger.py CHANGED
@@ -4,7 +4,9 @@ Logging helper module
4
  import logging.config
5
  import yaml
6
 
7
- with open("elise/src/configs/logging_config.yaml", "r", encoding="utf-8") as f:
 
 
8
  config = yaml.safe_load(f.read())
9
  logging.config.dictConfig(config)
10
  logging.captureWarnings(True)
 
4
  import logging.config
5
  import yaml
6
 
7
+ with open(
8
+ "/home/kave/work/Elise/elise/src/configs/logging_config.yaml", "r", encoding="utf-8"
9
+ ) as f:
10
  config = yaml.safe_load(f.read())
11
  logging.config.dictConfig(config)
12
  logging.captureWarnings(True)
requirements.txt CHANGED
@@ -11,3 +11,5 @@ transformers==4.31.0
11
  pylint==2.17.5
12
  gradio==3.39.0
13
  gradio_client==0.3.0
 
 
 
11
  pylint==2.17.5
12
  gradio==3.39.0
13
  gradio_client==0.3.0
14
+ accelerate==0.21.0
15
+ evaluate==0.4.0