Merge pull request #3 from BerserkerMother/dev
- .gitignore +8 -0
- .pylintrc +2 -1
- MLproject +6 -0
- Makefile +4 -0
- elise/data_generation/data_generation_prompts.txt +21 -0
- elise/data_generation/prompt_generation.txt +27 -0
- elise/src/app.py +3 -1
- elise/src/configs/__init__.py +4 -0
- elise/src/configs/logging_config.yaml +3 -1
- elise/src/configs/train_t5.py +20 -0
- elise/src/data/__init__.py +5 -0
- elise/src/data/mit_seq2seq_dataset.py +125 -0
- elise/src/excutors/__init__.py +0 -0
- elise/src/models/__init__.py +0 -0
- elise/src/notebooks/flant_t5_playground.ipynb +72 -0
- elise/src/notebooks/play.ipynb +443 -0
- elise/src/notebooks/playground_prompt/version1.py +22 -0
- elise/src/notebooks/t5 funetinung.ipynb +449 -0
- elise/src/train_t5_seq2seq.py +192 -0
- elise/src/utils/logger.py +3 -1
- requirements.txt +2 -0
.gitignore
CHANGED
@@ -3,6 +3,14 @@ __pycache__/
 *.py[cod]
 *$py.class
 
+# Project files
+mlruns/
+mlartifacts/
+experiment_models/
+
+# mlflow databases
+mlflow.db
+
 # IDEs file
 .idea
 
.pylintrc
CHANGED
@@ -428,7 +428,8 @@ disable=raw-checker-failed,
         suppressed-message,
         useless-suppression,
         deprecated-pragma,
-        use-symbolic-message-instead
+        use-symbolic-message-instead,
+        R0902
 
 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
MLproject
ADDED
@@ -0,0 +1,6 @@
+name: Afterhours Elise Model pipeline
+
+entry_points:
+  train_t5:
+    command: "python elise/src/train_t5_seq2seq.py"
+
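The MLproject file turns the training script into an MLflow Projects entry point. A minimal sketch of invoking it from Python, run from the repo root; mlflow is assumed installed, and env_manager="local" is an assumption to reuse the current environment rather than building a new one:

    import mlflow

    # Run the "train_t5" entry point declared in MLproject;
    # "." is the directory containing the MLproject file.
    mlflow.projects.run(".", entry_point="train_t5", env_manager="local")

The CLI equivalent is mlflow run . -e train_t5.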
Makefile
CHANGED
@@ -1,10 +1,14 @@
 setup:
 	pip install -r requirements.txt
+test:
+	pytest
 format:
 	black elise/src/
 lint:
 	pylint elise/src/
 gradio:
 	python elise/src/app.py
+train_t5:
+	python elise/src/train_t5_seq2seq.py
 dev:
 	make format lint gradio
elise/data_generation/data_generation_prompts.txt
ADDED
@@ -0,0 +1,21 @@
+Image you are assisting me generating data for training a T5 language model. Each record contains a user prompts where the user describes a place they want to dine, and user intentions and intention category which is label for training model. the labels are user intentions.
+Intentions categories are:
+- Cuisine
+- Location
+- Price
+- Atmosphere
+- Service
+- Reviews
+- Accessibility
+- Amenity & Special features
+- Offerings
+- Recommendations
+- Crowd
+- Payment
+- Category
+
+Here is one example:
+Prompt: I have a gluten allergy and need to find a restaurant with gluten-free options. Do you know any good ones in this area?
+Label: { "Location": "in this area", "Dietary restrictions": "gluten-free" }
+
+Write 5 random records in json format containing user's prompts and user's intentions.
elise/data_generation/prompt_generation.txt
ADDED
@@ -0,0 +1,27 @@
+Your task is to parse an unstructured job posting and turn it into a JSON containing the most important information. The job posting can describe one or more jobs at the same company. The JSON should consist of the following information:
+- The company name (field name: "companyName", field type: string)
+- the location of the company (field name: "companyLocation", field type: string); if not explictily stated, you can try to infer the company's actual location from other clues, e.g., something like "Remote (US)" usually means that the company is located in the US; if the location cannot be inferred, set it to null
+- a short description of what the company is doing or building (field name: "companyDescription", field type: string); try to keep it short (max length: ca. 300 characters)
+- a list of advertised jobs (field name: "jobs", field type: array).
+Each element of the "jobs" array should contain the following fields:
+- The job title (field name: "jobTitle", field type: string); the job title should be given in the singular form (i.e., Frontend Developer instead of Frontend Developers)
+- the salary range (field name: "salary", field type: string); only include explictly stated salary amounts, otherwise set to null
+- whether equity is part of the compensation (field name: "equity", field type: boolean)
+- the benefits (field name: "benefits", field type: string); include things like 401k, insurance, equipment, child care, etc. if stated, otherwise set to null
+- the location of the job (field name: "location", field type: string)
+- whether this is a job for senior/experienced candidates (field name: "senior", field type: boolean); typically senior, staff, lead, principal, vp, cto, etc. positions are all regarded as senior level
+- whether it is a remote opportunity (field name: "remote", field type: boolean)
+- whether it can be done onsite from an office (field name: "onsite", field type: boolean)
+- whether it can be done part-time (field name: "partTime", field type: boolean)
+- whether it can be done full-time (field name: "fullTime", field type: boolean)
+- the URL to the specific job description (field name: "jobUrl", field type: string)
+- and any specific requirements/skills that might be stated (field name: "requirements", field type: string).
+In general, if certain information is not stated, set the respective field to null. If the company seeks more than one person for the same role, include the role only once.
+
+This is the job posting:
+
+%s
+
+The structured JSON representation is:
+```json
+{"companyName":
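Since the template ends with a %s placeholder followed by the opening of a JSON object, it is presumably filled via old-style string formatting before being sent to a model. A small sketch of that assumed usage (the posting text is illustrative):

    # Fill the %s placeholder with a raw job posting using old-style formatting.
    with open("elise/data_generation/prompt_generation.txt") as f:
        template = f.read()

    posting = "Acme Corp is hiring a Frontend Developer (Remote, US)..."  # illustrative
    prompt = template % posting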
elise/src/app.py
CHANGED
@@ -10,7 +10,9 @@ from utils import df_to_json
 
 
 # prep models
-MODEL_CHECKPOINT =
+MODEL_CHECKPOINT = (
+    "tner/roberta-large-mit-restaurant"  # "BerserkerMother/restaurant_ner"
+)
 parser = SentenceParser.from_huggingface(MODEL_CHECKPOINT)
 
 embedder = SentenceTransformer("all-MiniLM-L6-v2")
elise/src/configs/__init__.py
ADDED
@@ -0,0 +1,4 @@
+"""
+All configs for ML project
+"""
+from .train_t5 import T5TrainingConfig
elise/src/configs/logging_config.yaml
CHANGED
@@ -1,12 +1,14 @@
 version: 1
+disable_existing_loggers: False
 formatters:
   simple:
     format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
 handlers:
   console:
+    level: DEBUG
     class: logging.StreamHandler
    formatter: simple
    stream: ext://sys.stdout
-
+root:
  Level: DEBUG
  handlers: [console]
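A sketch of how such a YAML config is typically wired into Python's logging (PyYAML assumed available; the path is the one from this PR). Note that dictConfig expects a lowercase "level" key, so the capitalized "Level" left under root may still not take effect:

    import logging.config

    import yaml  # PyYAML, assumed available

    # Parse the YAML into a dict and apply it as a dict-based logging config.
    with open("elise/src/configs/logging_config.yaml") as f:
        logging.config.dictConfig(yaml.safe_load(f))

    logging.getLogger(__name__).debug("logging configured")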
elise/src/configs/train_t5.py
ADDED
@@ -0,0 +1,20 @@
+"""
+Training config for T5 Seq2Seq training
+"""
+
+from dataclasses import dataclass
+
+
+@dataclass
+class T5TrainingConfig:
+    """Training configs for T5 finetuing"""
+
+    train_batch_size: int = 32
+    eval_batch_size: int = 32
+    epochs: int = 10
+    max_length: int = 512
+    learning_rate: float = 3e-4
+    num_warmup_steps: int = 200
+    mixed_precision: str = "bf16"
+    gradient_accumulation_steps: int = 4
+    output_dir: str = "FlanT5_MIT_ner"
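Because T5TrainingConfig is a plain dataclass, overriding defaults and exporting it for experiment tracking is straightforward; a minimal sketch (the import path is an assumption based on the package layout added in this PR):

    from dataclasses import asdict

    from configs import T5TrainingConfig  # assumes elise/src is on sys.path

    # Override a default; asdict() yields a flat dict, e.g. for mlflow.log_params.
    config = T5TrainingConfig(train_batch_size=16, epochs=5)
    print(asdict(config))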
elise/src/data/__init__.py
ADDED
@@ -0,0 +1,5 @@
+"""
+Contians datasets and their connectors for model training
+"""
+
+from .mit_seq2seq_dataset import MITRestaurants, get_default_transforms
elise/src/data/mit_seq2seq_dataset.py
ADDED
@@ -0,0 +1,125 @@
+"""
+seq2seq models datasets
+
+Classes:
+    MITRestaurants: tner/mit_restaurant dataset to seq2seq
+
+Functions:
+    get_default_transforms: default transforms for mit dataset
+"""
+import datasets
+
+
+class MITRestaurants:
+    """
+    tner/mit_restaurants for seq2seq
+
+    Atrributes
+    ----------
+    ner_tags: ner tags and ids of mit restaurant
+    dataset: hf dataset
+    transforms: transforms to apply
+    """
+
+    ner_tags = {
+        "O": 0,
+        "B-Rating": 1,
+        "I-Rating": 2,
+        "B-Amenity": 3,
+        "I-Amenity": 4,
+        "B-Location": 5,
+        "I-Location": 6,
+        "B-Restaurant_Name": 7,
+        "I-Restaurant_Name": 8,
+        "B-Price": 9,
+        "B-Hours": 10,
+        "I-Hours": 11,
+        "B-Dish": 12,
+        "I-Dish": 13,
+        "B-Cuisine": 14,
+        "I-Price": 15,
+        "I-Cuisine": 16,
+    }
+
+    def __init__(self, dataset: datasets.DatasetDict, transforms=None):
+        """
+        Constructs mit datasets
+
+        Parameters:
+            dataset: huggingface mit dataset
+            transforms: dataset transform functions
+        """
+        self.dataset = dataset
+        self.transforms = transforms
+
+    def hf_training(self):
+        """
+        Returns dataset for huggingface training ecosystem
+        """
+        dataset_ = self.dataset
+        if self.transforms:
+            for transfrom in self.transforms:
+                dataset_ = dataset_.map(transfrom)
+        return dataset_
+
+    def set_transforms(self, transforms):
+        """
+        Set tranfroms fn
+
+        Parameters:
+            transforms: transforms functions
+        """
+        if self.transforms:
+            self.transforms += transforms
+        else:
+            self.transforms = transforms
+        return self
+
+    @classmethod
+    def from_hf(cls, hf_path: str):
+        """
+        Constructs dataset from huggingface
+
+        Parameters:
+            hf_path: path to dataset hf repo
+        """
+        return cls(datasets.load_dataset(hf_path))
+
+
+def get_default_transforms():
+    """
+    Default transformation to convert ner dataset to seq2seq
+    """
+    label_names = {v: k for k, v in MITRestaurants.ner_tags.items()}
+
+    def decode_tags(tags, words):
+        dict_out = {}
+        word_ = ""
+        for tag, word in zip(tags[::-1], words[::-1]):
+            if tag == 0:
+                continue
+            word_ = word + " " + word_
+            if label_names[tag].startswith("B"):
+                tag_name = label_names[tag][2:]
+                word_ = word_.strip()
+                if tag_name not in dict_out:
+                    dict_out[tag_name] = [word_]
+                else:
+                    dict_out[tag_name].append(word_)
+                word_ = ""
+        return dict_out
+
+    def format_to_text(decoded):
+        text = ""
+        for key, value in decoded.items():
+            text += f"{key}: {', '.join(value)}\n"
+        return text
+
+    def generate_seq2seq_data(example):
+        decoded = decode_tags(example["tags"], example["tokens"])
+        return {
+            "tokens": " ".join(example["tokens"]),
+            "labels": format_to_text(decoded),
+        }
+
+    return [generate_seq2seq_data]
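A sketch of how the class and its default transforms compose, using only the API added above (import path assumed, with elise/src on sys.path as the notebooks do):

    from data import MITRestaurants, get_default_transforms  # path assumed

    # Load tner/mit_restaurant, attach the NER-to-seq2seq transform, and
    # materialize a DatasetDict with joined "tokens" text and "labels" text.
    dataset = (
        MITRestaurants.from_hf("tner/mit_restaurant")
        .set_transforms(get_default_transforms())
        .hf_training()
    )
    print(dataset["train"][0]["labels"])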
elise/src/excutors/__init__.py
ADDED
File without changes

elise/src/models/__init__.py
ADDED
File without changes
elise/src/notebooks/flant_t5_playground.ipynb
ADDED
@@ -0,0 +1,72 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import transformers"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pipe = transformers.pipeline(\n",
+    "    \"text2text-generation\", model=\"/home/kave/work/Elise/output_dir/\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[{'generated_text': ''}]"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "pipe(\"What are you?\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
elise/src/notebooks/play.ipynb
ADDED
@@ -0,0 +1,443 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from pprint import pprint"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from playground_prompt.version1 import prompts\n",
+    "\n",
+    "test_p = prompts[\"test_prompts\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "\n",
+    "sys.path.append(\"/home/kave/work/Elise/elise/src\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Using /home/kave/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...\n",
+      "Detected CUDA files, patching ldflags\n",
+      "Emitting ninja build file /home/kave/.cache/torch_extensions/py310_cu117/cuda_kernel/build.ninja...\n",
+      "Building extension module cuda_kernel...\n",
+      "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n",
+      "Loading extension module cuda_kernel...\n",
+      "Failed to load CUDA kernels. Mra requires custom CUDA kernels. Please verify that compatible versions of PyTorch and CUDA Toolkit are installed: /home/kave/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/zmq/backend/cython/../../../../.././libstdc++.so.6: version `GLIBCXX_3.4.32' not found (required by /home/kave/.cache/torch_extensions/py310_cu117/cuda_kernel/cuda_kernel.so)\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "ninja: no work to do.\n"
+     ]
+    }
+   ],
+   "source": [
+    "from parser import SentenceParser\n",
+    "\n",
+    "parser = SentenceParser.from_huggingface(\"BerserkerMother/restaurant_ner\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[{'Cuisine': ['fast -']},\n",
+      " {'Amenity': ['lively']},\n",
+      " {'Amenity': ['nice', 'fun'],\n",
+      "  'Hours': ['dinner'],\n",
+      "  'Rating': ['good', 'good'],\n",
+      "  'Services': ['ambian']},\n",
+      " {'Cuisine': ['authentic indian']},\n",
+      " {'Amenity': ['romantic', 'nice']},\n",
+      " {},\n",
+      " {},\n",
+      " {'Amenity': ['live music']},\n",
+      " {'DS': ['gluten allergy', 'gluten - free']},\n",
+      " {'Amenity': ['cozy']}]\n"
+     ]
+    }
+   ],
+   "source": [
+    "ners = parser.get_ner(test_p)\n",
+    "pprint(ners)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[{'Semantic': ['fast -']},\n",
+      " {'Semantic': ['lively']},\n",
+      " {'Hours': ['dinner'], 'Semantic': ['nice', 'fun']},\n",
+      " {'Semantic': ['authentic indian']},\n",
+      " {'Semantic': ['romantic', 'nice']},\n",
+      " {},\n",
+      " {},\n",
+      " {'Semantic': ['live music']},\n",
+      " {},\n",
+      " {'Semantic': ['cozy']}]\n"
+     ]
+    }
+   ],
+   "source": [
+    "parsed_prompts = parser.parse(ners)\n",
+    "pprint(parsed_prompts)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sentence_transformers import SentenceTransformer\n",
+    "\n",
+    "embedder = SentenceTransformer(\"all-MiniLM-L6-v2\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      " Accessibility \\\n",
+      "0 Wheelchair accessible entrance, Wheelchair acc... \n",
+      "1 Wheelchair accessible entrance, Wheelchair acc... \n",
+      "2 Wheelchair accessible entrance, Wheelchair acc... \n",
+      "3 Wheelchair accessible entrance, Wheelchair acc... \n",
+      "4 Wheelchair accessible entrance, Wheelchair acc... \n",
+      ".. ... \n",
+      "274 Wheelchair accessible entrance, Wheelchair acc... \n",
+      "275 Wheelchair accessible elevator, Wheelchair acc... \n",
+      "276 Wheelchair accessible entrance, Wheelchair acc... \n",
+      "277 Wheelchair accessible seating, Wheelchair acce... \n",
+      "278 Wheelchair accessible parking lot, Wheelchair ... \n",
+      "\n",
+      " Amenities & Special featrures \\\n",
+      "0 Bar onsite, Good for kids, High chairs, Restro... \n",
+      "1 Good for kids, High chairs, Restroom \n",
+      "2 Bar onsite, Dogs allowed, Good for kids, High ... \n",
+      "3 Restroom \n",
+      "4 Restroom \n",
+      ".. ... \n",
+      "274 Bar onsite, Good for kids, High chairs, Restro... \n",
+      "275 Bar onsite, High chairs, Restroom, Wi-Fi \n",
+      "276 Bar onsite, Dogs allowed, Good for kids, High ... \n",
+      "277 Bar onsite, High chairs, Restroom \n",
+      "278 Bar onsite, High chairs, Restroom \n",
+      "\n",
+      " Atmosphere \\\n",
+      "0 Casual, Cozy \n",
+      "1 Casual, Cozy \n",
+      "2 Casual, Cozy \n",
+      "3 Casual, Cozy \n",
+      "4 Casual, Cozy, Romantic \n",
+      ".. ... \n",
+      "274 Casual, Cozy \n",
+      "275 Casual, Cozy \n",
+      "276 Casual, Cozy, Romantic \n",
+      "277 Casual, Cozy \n",
+      "278 Casual, Cozy \n",
+      "\n",
+      " Crowd \\\n",
+      "0 Family friendly, Groups, LGBTQ+ friendly, Tran... \n",
+      "1 Family friendly, Groups \n",
+      "2 Family friendly, Groups \n",
+      "3 NaN \n",
+      "4 Groups \n",
+      ".. ... \n",
+      "274 Family friendly, Groups, LGBTQ+ friendly \n",
+      "275 Groups \n",
+      "276 Groups, LGBTQ+ friendly, Transgender safespace \n",
+      "277 Groups \n",
+      "278 Groups \n",
+      "\n",
+      " Dining options \\\n",
+      "0 Lunch, Dinner, Dessert, Seating \n",
+      "1 Lunch, Dinner, Dessert, Seating \n",
+      "2 Lunch, Dinner, Catering, Dessert, Seating \n",
+      "3 Breakfast, Lunch, Dessert, Seating \n",
+      "4 Dinner, Dessert, Seating \n",
+      ".. ... \n",
+      "274 Dinner, Counter service, Dessert, Seating \n",
+      "275 Breakfast, Brunch, Lunch, Dinner, Dessert, Sea... \n",
+      "276 Breakfast, Brunch, Lunch, Dinner, Dessert, Sea... \n",
+      "277 Brunch, Lunch, Dinner, Dessert, Seating \n",
+      "278 Brunch, Lunch, Dinner, Dessert, Seating \n",
+      "\n",
+      " Offerings \\\n",
+      "0 Alcohol, All you can eat, Beer, Cocktails, Cof... \n",
+      "1 Alcohol, All you can eat, Beer, Coffee, Halal,... \n",
+      "2 Alcohol, Beer, Coffee, Halal, Healthy options,... \n",
+      "3 Coffee \n",
+      "4 Alcohol, Beer, Cocktails, Coffee, Hard liquor,... \n",
+      ".. ... \n",
+      "274 Alcohol, Beer, Coffee, Healthy options, Kids' ... \n",
+      "275 Alcohol, Beer, Cocktails, Coffee, Hard liquor,... \n",
+      "276 Alcohol, Beer, Cocktails, Coffee, Hard liquor,... \n",
+      "277 Alcohol, Beer, Cocktails, Coffee, Hard liquor,... \n",
+      "278 Alcohol, Beer, Cocktails, Coffee, Hard liquor,... \n",
+      "\n",
+      " Payment Planning \\\n",
+      "0 Debit cards, NFC mobile payments, Credit cards Accepts reservations \n",
+      "1 Cash-only, Debit cards, NFC mobile payments Accepts reservations \n",
+      "2 Debit cards, NFC mobile payments, Credit cards Accepts reservations \n",
+      "3 Debit cards NaN \n",
+      "4 NaN Accepts reservations \n",
+      ".. ... ... \n",
+      "274 Debit cards, NFC mobile payments, Credit cards NaN \n",
+      "275 Debit cards, NFC mobile payments, Credit cards Accepts reservations \n",
+      "276 Debit cards, NFC mobile payments, Credit cards Accepts reservations \n",
+      "277 Debit cards, NFC mobile payments, Credit cards Accepts reservations \n",
+      "278 Debit cards, NFC mobile payments Accepts reservations \n",
+      "\n",
+      " Service \\\n",
+      "0 Outdoor seating, Delivery, Takeout, Dine-in \n",
+      "1 Outdoor seating, Dine-in \n",
+      "2 Outdoor seating, Curbside pickup, No-contact d... \n",
+      "3 Outdoor seating, Takeout, Dine-in \n",
+      "4 Outdoor seating, Delivery, Takeout, Dine-in \n",
+      ".. ... \n",
+      "274 Outdoor seating, Takeout, Dine-in, Delivery \n",
+      "275 Outdoor seating, Dine-in, Delivery, Takeout \n",
+      "276 Outdoor seating, Curbside pickup, Takeout, Din... \n",
+      "277 Outdoor seating, Takeout, Dine-in, Delivery \n",
+      "278 Outdoor seating, Dine-in, Delivery, Takeout \n",
+      "\n",
+      " categories \\\n",
+      "0 NaN \n",
+      "1 Korean barbecue restaurant \n",
+      "2 Indian restaurant, Asian restaurant, Health fo... \n",
+      "3 Cafe, Breakfast restaurant, Brunch restaurant,... \n",
+      "4 Thai restaurant \n",
+      ".. ... \n",
+      "274 NaN \n",
+      "275 , Cocktail bar, Coffee shop, Coworking space, ... \n",
+      "276 NaN \n",
+      "277 , Cafe \n",
+      "278 , Culinary school, Event venue, Function room ... \n",
+      "\n",
+      " Category \\\n",
+      "0 NaN \n",
+      "1 Korean barbecue restaurant \n",
+      "2 Indian restaurant \n",
+      "3 Cafe \n",
+      "4 Thai restaurant \n",
+      ".. ... \n",
+      "274 NaN \n",
+      "275 NaN \n",
+      "276 NaN \n",
+      "277 NaN \n",
+      "278 NaN \n",
+      "\n",
+      " description \n",
+      "0 NaN \n",
+      "1 NaN \n",
+      "2 Light-filled restaurant with colorful seating ... \n",
+      "3 Seattle-based coffeehouse chain known for its ... \n",
+      "4 NaN \n",
+      ".. ... \n",
+      "274 NaN \n",
+      "275 NaN \n",
+      "276 NaN \n",
+      "277 NaN \n",
+      "278 NaN \n",
+      "\n",
+      "[279 rows x 12 columns]\n"
+     ]
+    }
+   ],
+   "source": [
+    "from matcher import Matcher\n",
+    "\n",
+    "matcher = Matcher.from_path(\"/home/kave/work/Elise/elise/data/final_data.csv\", embedder)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[ 0.0409, -0.0148, 0.0419, ..., 0.0347, -0.0333, 0.0381],\n",
+       " [ 0.0090, 0.0329, 0.0421, ..., 0.0387, -0.0885, 0.0523],\n",
+       " [ 0.0155, -0.0316, 0.0217, ..., 0.0344, -0.0364, 0.0223],\n",
+       " ...,\n",
+       " [ 0.0363, -0.0182, 0.0596, ..., 0.0276, -0.0117, 0.0335],\n",
+       " [ 0.0415, -0.0292, 0.0533, ..., 0.0515, -0.0225, 0.0368],\n",
+       " [ 0.0540, -0.0334, 0.0415, ..., 0.0756, -0.0274, 0.0454]],\n",
+       " device='cuda:0')"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "matcher.semantics"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CPU times: user 331 ms, sys: 0 ns, total: 331 ms\n",
+      "Wall time: 91.5 ms\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%time\n",
+    "ners = parser.get_ner(test_p)\n",
+    "parsed_prompts = parser.parse(ners)\n",
+    "kk = matcher.handle_jobs(parsed_prompts)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from utils import df_to_json"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{\"I don't feel like cooking tonight. Where's a good place to get fast-food?\": [{'Name': 'Cafetaria Edison',\n",
+       " 'Score': 20.519733428955078},\n",
+       " {'Name': 'Five Guys', 'Score': 20.5447940826416},\n",
+       " {'Name': 'Five Guys', 'Score': 20.5447940826416}],\n",
+       " \"I'm planning a dinner with some friends. Any recommendations for a restaurant with a lively atmosphere?\": [{'Name': 'De Garde',\n",
+       " 'Score': 25.97490882873535},\n",
+       " {'Name': 'Hemel & Aarde', 'Score': 26.467164993286133},\n",
+       " {'Name': 'The Thai Orchid', 'Score': 36.23405838012695}],\n",
+       " \"I want to celebrate my graduation with a nice dinner out. What's a good restaurant with good food and a fun ambiance?\": [{'Name': 'Luc Utrecht',\n",
+       " 'Score': 20.497264862060547},\n",
+       " {'Name': 'Spice Monkey', 'Score': 20.64722442626953},\n",
+       " {'Name': 'Ethiopian Sunshine', 'Score': 20.84449005126953}],\n",
+       " \"I want to try some new cuisines I've never had before. Can you recommend a restaurant with authentic Indian food?\": [{'Name': 'India Port',\n",
+       " 'Score': 35.02317428588867},\n",
+       " {'Name': 'Kashmir Kitchen Utrecht', 'Score': 35.351844787597656},\n",
+       " {'Name': 'Surya Utrecht | Indiaas & Nepalees restaurant & bar',\n",
+       " 'Score': 40.53645706176758}],\n",
+       " \"I'm planning a special date night and want to go somewhere romantic. What's a good restaurant with a nice view?\": [{'Name': 'Sevilla',\n",
+       " 'Score': 20.176685333251953},\n",
+       " {'Name': 'Pand 33 Utrecht', 'Score': 20.325525283813477},\n",
+       " {'Name': 'Hemel & Aarde', 'Score': 23.112428665161133}],\n",
+       " \"I'm meeting a client for lunch. Can you recommend a good restaurant for a business meeting?\": [{'Name': 'Sevilla',\n",
+       " 'Score': 20.176685333251953},\n",
+       " {'Name': 'Pand 33 Utrecht', 'Score': 20.325525283813477},\n",
+       " {'Name': 'Hemel & Aarde', 'Score': 23.112428665161133}],\n",
+       " \"I'm traveling through this city and need to find a good place to eat. Any recommendations near the airport?\": [{'Name': 'Sevilla',\n",
+       " 'Score': 20.176685333251953},\n",
+       " {'Name': 'Pand 33 Utrecht', 'Score': 20.325525283813477},\n",
+       " {'Name': 'Hemel & Aarde', 'Score': 23.112428665161133}],\n",
+       " \"I'm looking for a restaurant with live music or other entertainment. Any suggestions?\": [{'Name': 'Silk Road Utrecht',\n",
+       " 'Score': 19.077302932739258},\n",
+       " {'Name': 'Hemel & Aarde', 'Score': 20.864463806152344},\n",
+       " {'Name': 'The Thai Orchid', 'Score': 26.012712478637695}],\n",
+       " 'I have a gluten allergy and need to find a restaurant with gluten-free options. Do you know any good ones in this area?': [{'Name': 'Silk Road Utrecht',\n",
+       " 'Score': 19.077302932739258},\n",
+       " {'Name': 'Hemel & Aarde', 'Score': 20.864463806152344},\n",
+       " {'Name': 'The Thai Orchid', 'Score': 26.012712478637695}],\n",
+       " \"I just want to relax and have a nice meal out. What's a good restaurant with a cozy atmosphere?\": [{'Name': 'De Garde',\n",
+       " 'Score': 34.23313903808594},\n",
+       " {'Name': 'Asia Street Cooking', 'Score': 34.71759796142578},\n",
+       " {'Name': 'Hemel & Aarde', 'Score': 37.8213005065918}]}"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "df_to_json(kk, test_p)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "afterhours_dev",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.0"
+  },
+  "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
elise/src/notebooks/playground_prompt/version1.py
ADDED
@@ -0,0 +1,22 @@
+prompts = {
+    "instruction": """Identify user needs from the prompts.""",
+    "examples": """
+Prompt: I want to celebrate my graduation with my friend. Recommend me somewhere nice with live music near the Harvard campus.
+Identification:
+occasion: celebrating graduation
+features: somewhere nice, live music
+location: near Harvard campus
+""",
+    "test_prompts": [
+        "I don't feel like cooking tonight. Where's a good place to get fast-food?",
+        "I'm planning a dinner with some friends. Any recommendations for a restaurant with a lively atmosphere?",
+        "I want to celebrate my graduation with a nice dinner out. What's a good restaurant with good food and a fun ambiance?",
+        "I want to try some new cuisines I've never had before. Can you recommend a restaurant with authentic Indian food?",
+        "I'm planning a special date night and want to go somewhere romantic. What's a good restaurant with a nice view?",
+        "I'm meeting a client for lunch. Can you recommend a good restaurant for a business meeting?",
+        "I'm traveling through this city and need to find a good place to eat. Any recommendations near the airport?",
+        "I'm looking for a restaurant with live music or other entertainment. Any suggestions?",
+        "I have a gluten allergy and need to find a restaurant with gluten-free options. Do you know any good ones in this area?",
+        "I just want to relax and have a nice meal out. What's a good restaurant with a cozy atmosphere?",
+    ],
+}
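The instruction and worked example above are presumably concatenated with one of the test prompts to form a few-shot query; a sketch of that assembly (the file itself does not prescribe it, and the import path mirrors the one used in play.ipynb):

    from playground_prompt.version1 import prompts  # import path as in play.ipynb

    # Build a few-shot prompt: instruction, worked example, then the new query.
    query = prompts["test_prompts"][0]
    few_shot = (
        f"{prompts['instruction']}\n"
        f"{prompts['examples']}\n"
        f"Prompt: {query}\nIdentification:"
    )
    print(few_shot)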
elise/src/notebooks/t5 funetinung.ipynb
ADDED
@@ -0,0 +1,449 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "9510dd98",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
+    "from datasets import load_dataset\n",
+    "from transformers import get_scheduler\n",
+    "import torch\n",
+    "from torch.utils.data import DataLoader\n",
+    "from datasets import load_dataset\n",
+    "from transformers import AutoTokenizer, AutoModelForTokenClassification\n",
+    "from transformers import DataCollatorForSeq2Seq\n",
+    "from accelerate import Accelerator\n",
+    "import evaluate\n",
+    "import datasets\n",
+    "\n",
+    "from tqdm.auto import tqdm"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "f1da6c6c",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/kave/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/transformers/models/t5/tokenization_t5_fast.py:155: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
+      "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
+      "- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n",
+      "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
+      "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n",
+      " warnings.warn(\n"
+     ]
+    }
+   ],
+   "source": [
+    "tokenizer = AutoTokenizer.from_pretrained(\"t5-base\")\n",
+    "model = AutoModelForSeq2SeqLM.from_pretrained(\"t5-base\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "a6de1719",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# prep dataset\n",
+    "dataset = load_dataset(\"tner/mit_restaurant\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "8617d7d6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ner_tags = {\n",
+    "    \"O\": 0,\n",
+    "    \"B-Rating\": 1,\n",
+    "    \"I-Rating\": 2,\n",
+    "    \"B-Amenity\": 3,\n",
+    "    \"I-Amenity\": 4,\n",
+    "    \"B-Location\": 5,\n",
+    "    \"I-Location\": 6,\n",
+    "    \"B-Restaurant_Name\": 7,\n",
+    "    \"I-Restaurant_Name\": 8,\n",
+    "    \"B-Price\": 9,\n",
+    "    \"B-Hours\": 10,\n",
+    "    \"I-Hours\": 11,\n",
+    "    \"B-Dish\": 12,\n",
+    "    \"I-Dish\": 13,\n",
+    "    \"B-Cuisine\": 14,\n",
+    "    \"I-Price\": 15,\n",
+    "    \"I-Cuisine\": 16,\n",
+    "}\n",
+    "\n",
+    "\n",
+    "label_names = {v: k for k, v in ner_tags.items()}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "de52b597",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def decode_tags(tags, words):\n",
+    "    dict_out = {}\n",
+    "    word_ = \"\"\n",
+    "    for tag, word in zip(tags[::-1], words[::-1]):\n",
+    "        if tag == 0:\n",
+    "            continue\n",
+    "        word_ = word_ + \" \" + word\n",
+    "        if label_names[tag].startswith(\"B\"):\n",
+    "            tag_name = label_names[tag][2:]\n",
+    "            word_ = word_.strip()\n",
+    "            if tag_name not in dict_out:\n",
+    "                dict_out[tag_name] = [word_]\n",
+    "            else:\n",
+    "                dict_out[tag_name].append(word_)\n",
+    "            word_ = \"\"\n",
+    "    return dict_out\n",
+    "\n",
+    "\n",
+    "def format_to_text(decoded):\n",
+    "    text = \"\"\n",
+    "    for key, value in decoded.items():\n",
+    "        text += f\"{key}: {', '.join(value)}\\n\"\n",
+    "    return text"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "5da715a8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def generate_t5_data(example):\n",
+    "    decoded = decode_tags(example[\"tags\"], example[\"tokens\"])\n",
+    "    return {\"tokens\": \" \".join(example[\"tokens\"]), \"labels\": format_to_text(decoded)}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "57416e20",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
+    "import torch\n",
+    "\n",
+    "# the following 2 hyperparameters are task-specific\n",
+    "max_source_length = 512\n",
+    "max_target_length = 128\n",
+    "\n",
+    "# encode the inputs\n",
+    "task_prefix = \"What is the user intent?\"\n",
+    "\n",
+    "\n",
+    "def tokenize(example):\n",
+    "    tokenized = tokenizer(\n",
+    "        task_prefix + example[\"tokens\"],\n",
+    "        text_target=example[\"labels\"],\n",
+    "        max_length=512,\n",
+    "        truncation=True,\n",
+    "    )\n",
+    "    return tokenized"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "137905d7",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "23bafa0f97bc4d4da8a96397f0f3bd5a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Map: 0%| | 0/6900 [00:00<?, ? examples/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "tokenized_datasets = dataset.map(generate_t5_data)\n",
+    "tokenized_datasets = tokenized_datasets.remove_columns([\"tags\"])\n",
+    "tokenized_datasets = tokenized_datasets.map(tokenize)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "e2bdf1b0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import evaluate\n",
+    "\n",
+    "metric = evaluate.load(\"sacrebleu\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "cd9871bf",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "\n",
+    "def compute_metrics(eval_preds):\n",
+    "    preds, labels = eval_preds\n",
+    "    # In case the model returns more than the prediction logits\n",
+    "    if isinstance(preds, tuple):\n",
+    "        preds = preds[0]\n",
+    "\n",
+    "    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
+    "\n",
+    "    # Replace -100s in the labels as we can't decode them\n",
+    "    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
+    "    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
+    "\n",
+    "    # Some simple post-processing\n",
+    "    decoded_preds = [pred.strip() for pred in decoded_preds]\n",
+    "    decoded_labels = [[label.strip()] for label in decoded_labels]\n",
+    "\n",
+    "    result = metric.compute(predictions=decoded_preds, references=decoded_labels)\n",
+    "    return {\"bleu\": result[\"score\"]}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "09afe1d0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "58e84fd1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import Seq2SeqTrainingArguments\n",
+    "\n",
+    "args = Seq2SeqTrainingArguments(\n",
+    "    f\"T5 test\",\n",
+    "    evaluation_strategy=\"no\",\n",
+    "    save_strategy=\"epoch\",\n",
+    "    learning_rate=3e-4,\n",
+    "    per_device_train_batch_size=64,\n",
+    "    per_device_eval_batch_size=32,\n",
+    "    weight_decay=0.01,\n",
+    "    save_total_limit=3,\n",
+    "    num_train_epochs=20,\n",
+    "    predict_with_generate=True,\n",
+    "    fp16=True,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "edfcbac1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import Seq2SeqTrainer\n",
+    "\n",
+    "trainer = Seq2SeqTrainer(\n",
+    "    model,\n",
+    "    args,\n",
+    "    train_dataset=tokenized_datasets[\"train\"],\n",
+    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
+    "    data_collator=data_collator,\n",
+    "    tokenizer=tokenizer,\n",
+    "    compute_metrics=compute_metrics,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "e0065364",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       " <div>\n",
+       " \n",
+       " <progress value='12' max='12' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+       " [12/12 00:15]\n",
+       " </div>\n",
+       " "
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Trainer is attempting to log a value of \"{'summarization': {'early_stopping': True, 'length_penalty': 2.0, 'max_length': 200, 'min_length': 30, 'no_repeat_ngram_size': 3, 'num_beams': 4, 'prefix': 'summarize: '}, 'translation_en_to_de': {'early_stopping': True, 'max_length': 300, 'num_beams': 4, 'prefix': 'translate English to German: '}, 'translation_en_to_fr': {'early_stopping': True, 'max_length': 300, 'num_beams': 4, 'prefix': 'translate English to French: '}, 'translation_en_to_ro': {'early_stopping': True, 'max_length': 300, 'num_beams': 4, 'prefix': 'translate English to Romanian: '}}\" for key \"task_specific_params\" as a parameter. MLflow's log_param() only accepts values no longer than 250 characters so we dropped this attribute. You can use `MLFLOW_FLATTEN_PARAMS` environment variable to flatten the parameters and avoid this message.\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'eval_loss': 6.675447940826416,\n",
+       " 'eval_bleu': 0.006728795795564811,\n",
+       " 'eval_runtime': 17.5858,\n",
+       " 'eval_samples_per_second': 43.217,\n",
+       " 'eval_steps_per_second': 0.682}"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trainer.evaluate(max_length=512)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "64ad307b",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/kave/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
+      " warnings.warn(\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       " <div>\n",
+       " \n",
+       " <progress value='4' max='1080' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+       " [ 4/1080 00:01 < 09:22, 1.91 it/s, Epoch 0.06/20]\n",
+       " </div>\n",
+       " <table border=\"1\" class=\"dataframe\">\n",
+       " <thead>\n",
+       " <tr style=\"text-align: left;\">\n",
+       " <th>Step</th>\n",
+       " <th>Training Loss</th>\n",
+       " </tr>\n",
+       " </thead>\n",
+       " <tbody>\n",
+       " </tbody>\n",
+       "</table><p>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "ename": "OutOfMemoryError",
+     "evalue": "CUDA out of memory. Tried to allocate 456.00 MiB (GPU 0; 11.75 GiB total capacity; 10.26 GiB already allocated; 131.12 MiB free; 10.83 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[15], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/transformers/trainer.py:1539\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1534\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_wrapped \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\n\u001b[1;32m 1536\u001b[0m inner_training_loop \u001b[38;5;241m=\u001b[39m find_executable_batch_size(\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_inner_training_loop, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_train_batch_size, args\u001b[38;5;241m.\u001b[39mauto_find_batch_size\n\u001b[1;32m 1538\u001b[0m )\n\u001b[0;32m-> 1539\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1540\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1541\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1542\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1544\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/transformers/trainer.py:1809\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 1806\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m 1808\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 1809\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1811\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 1812\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 1813\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_tpu_available()\n\u001b[1;32m 1814\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 1815\u001b[0m ):\n\u001b[1;32m 1816\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 1817\u001b[0m tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
      "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/transformers/trainer.py:2654\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 2651\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss_mb\u001b[38;5;241m.\u001b[39mreduce_mean()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m 2653\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_loss_context_manager():\n\u001b[0;32m-> 2654\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2656\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mn_gpu \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 2657\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mmean() \u001b[38;5;66;03m# mean() to average on multi-gpu parallel training\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/transformers/trainer.py:2679\u001b[0m, in \u001b[0;36mTrainer.compute_loss\u001b[0;34m(self, model, inputs, return_outputs)\u001b[0m\n\u001b[1;32m 2677\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2678\u001b[0m labels \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 2679\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2680\u001b[0m \u001b[38;5;66;03m# Save past state if it exists\u001b[39;00m\n\u001b[1;32m 2681\u001b[0m \u001b[38;5;66;03m# TODO: this needs to be fixed and made cleaner later.\u001b[39;00m\n\u001b[1;32m 2682\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mpast_index \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
      "File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1501\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1496\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1497\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1498\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1499\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1500\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1501\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1502\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1503\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
401 |
+
"File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/accelerate/utils/operations.py:581\u001b[0m, in \u001b[0;36mconvert_outputs_to_fp32.<locals>.forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 580\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 581\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
402 |
+
"File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/accelerate/utils/operations.py:569\u001b[0m, in \u001b[0;36mConvertOutputsToFp32.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 568\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 569\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m convert_to_fp32(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n",
|
403 |
+
"File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/torch/amp/autocast_mode.py:14\u001b[0m, in \u001b[0;36mautocast_decorator.<locals>.decorate_autocast\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_autocast\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m autocast_instance:\n\u001b[0;32m---> 14\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
404 |
+
"File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/accelerate/utils/operations.py:581\u001b[0m, in \u001b[0;36mconvert_outputs_to_fp32.<locals>.forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 580\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 581\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
405 |
+
"File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/accelerate/utils/operations.py:569\u001b[0m, in \u001b[0;36mConvertOutputsToFp32.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 568\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 569\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mconvert_to_fp32\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
|
406 |
+
"File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/accelerate/utils/operations.py:548\u001b[0m, in \u001b[0;36mconvert_to_fp32\u001b[0;34m(tensor)\u001b[0m\n\u001b[1;32m 545\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_is_fp16_bf16_tensor\u001b[39m(tensor):\n\u001b[1;32m 546\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(tensor, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdtype\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m tensor\u001b[38;5;241m.\u001b[39mdtype \u001b[38;5;129;01min\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39mfloat16, torch\u001b[38;5;241m.\u001b[39mbfloat16)\n\u001b[0;32m--> 548\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mrecursively_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_convert_to_fp32\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtensor\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_is_fp16_bf16_tensor\u001b[49m\u001b[43m)\u001b[49m\n",
|
407 |
+
"File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/accelerate/utils/operations.py:120\u001b[0m, in \u001b[0;36mrecursively_apply\u001b[0;34m(func, data, test_type, error_on_other_type, *args, **kwargs)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m honor_type(\n\u001b[1;32m 110\u001b[0m data,\n\u001b[1;32m 111\u001b[0m (\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 116\u001b[0m ),\n\u001b[1;32m 117\u001b[0m )\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data, Mapping):\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(data)(\n\u001b[0;32m--> 120\u001b[0m {\n\u001b[1;32m 121\u001b[0m k: recursively_apply(\n\u001b[1;32m 122\u001b[0m func, v, \u001b[38;5;241m*\u001b[39margs, test_type\u001b[38;5;241m=\u001b[39mtest_type, error_on_other_type\u001b[38;5;241m=\u001b[39merror_on_other_type, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m 123\u001b[0m )\n\u001b[1;32m 124\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m data\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 125\u001b[0m }\n\u001b[1;32m 126\u001b[0m )\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m test_type(data):\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m func(data, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
|
408 |
+
"File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/accelerate/utils/operations.py:121\u001b[0m, in \u001b[0;36m<dictcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m honor_type(\n\u001b[1;32m 110\u001b[0m data,\n\u001b[1;32m 111\u001b[0m (\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 116\u001b[0m ),\n\u001b[1;32m 117\u001b[0m )\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(data, Mapping):\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(data)(\n\u001b[1;32m 120\u001b[0m {\n\u001b[0;32m--> 121\u001b[0m k: \u001b[43mrecursively_apply\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 122\u001b[0m \u001b[43m \u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtest_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43merror_on_other_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43merror_on_other_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 123\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 124\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m data\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 125\u001b[0m }\n\u001b[1;32m 126\u001b[0m )\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m test_type(data):\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m func(data, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
|
409 |
+
"File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/accelerate/utils/operations.py:128\u001b[0m, in \u001b[0;36mrecursively_apply\u001b[0;34m(func, data, test_type, error_on_other_type, *args, **kwargs)\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(data)(\n\u001b[1;32m 120\u001b[0m {\n\u001b[1;32m 121\u001b[0m k: recursively_apply(\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 125\u001b[0m }\n\u001b[1;32m 126\u001b[0m )\n\u001b[1;32m 127\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m test_type(data):\n\u001b[0;32m--> 128\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 129\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m error_on_other_type:\n\u001b[1;32m 130\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[1;32m 131\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnsupported types (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(data)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) passed to `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m`. Only nested list/tuple/dicts of \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 132\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mobjects that are valid for `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtest_type\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m` should be passed.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 133\u001b[0m )\n",
|
410 |
+
"File \u001b[0;32m~/miniconda3/envs/afterhours_dev/lib/python3.10/site-packages/accelerate/utils/operations.py:543\u001b[0m, in \u001b[0;36mconvert_to_fp32.<locals>._convert_to_fp32\u001b[0;34m(tensor)\u001b[0m\n\u001b[1;32m 542\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_convert_to_fp32\u001b[39m(tensor):\n\u001b[0;32m--> 543\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtensor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
411 |
+
"\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 456.00 MiB (GPU 0; 11.75 GiB total capacity; 10.26 GiB already allocated; 131.12 MiB free; 10.83 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
]
}
],
"source": [
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6672ff53",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
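The traceback above ends in a CUDA OutOfMemoryError on a roughly 12 GiB GPU during `trainer.train()`. A common first mitigation is to shrink the per-device batch size and compensate with gradient accumulation, and to set the allocator option the error message itself points at. A minimal sketch: the `Seq2SeqTrainingArguments` fields are real transformers API, but the concrete values here are illustrative guesses, not settings taken from this repo.

import os

# Must be set before CUDA is initialised; mirrors the hint in the error message.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

from transformers import Seq2SeqTrainingArguments

# Hypothetical values: a smaller micro-batch plus accumulation keeps the same
# effective batch size while lowering peak activation memory.
training_args = Seq2SeqTrainingArguments(
    output_dir="experiment_models",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    fp16=True,
)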
elise/src/train_t5_seq2seq.py
ADDED
@@ -0,0 +1,192 @@
"""
Train a Flan-T5 model on tner/mit_restaurant as a seq2seq task.
"""
from dataclasses import asdict

import torch
import evaluate
import datasets
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    get_scheduler,
)
from accelerate import Accelerator
import numpy as np
import mlflow

from tqdm.auto import tqdm

from utils.logger import get_logger
from configs import T5TrainingConfig
from data import MITRestaurants, get_default_transforms

log = get_logger("Flan_T5")

# get dataset
transforms = get_default_transforms()
dataset = (
    MITRestaurants.from_hf("tner/mit_restaurant")
    .set_transforms(transforms)
    .hf_training()
)
# fold the test split into training; the validation split is kept for evaluation
dataset["train"] = datasets.concatenate_datasets([dataset["train"], dataset["test"]])
log.info(dataset)

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")


def tokenize(example):
    """Tokenizes dataset for the seq2seq task."""
    tokenized = tokenizer(
        example["tokens"],
        text_target=example["labels"],
        max_length=512,
        truncation=True,
    )

    return tokenized


tokenized_datasets = dataset.map(
    tokenize,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

# BLEU metric
metric = evaluate.load("sacrebleu")


def postprocess(predictions, labels):
    """Converts model output into decoded strings for evaluation."""
    predictions = predictions.cpu().numpy()
    labels = labels.cpu().numpy()

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)

    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]
    return decoded_preds, decoded_labels


config = T5TrainingConfig()

# data collator
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# data loaders
tokenized_datasets.set_format("torch")
train_dataloader = DataLoader(
    tokenized_datasets["train"],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=config.train_batch_size,
)
eval_dataloader = DataLoader(
    tokenized_datasets["validation"],
    collate_fn=data_collator,
    batch_size=config.eval_batch_size,
)

# optimizer and learning-rate schedule
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = config.epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=config.num_warmup_steps,
    num_training_steps=num_training_steps,
)

# accelerator
accelerator = Accelerator(
    mixed_precision=config.mixed_precision,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
)
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

progress_bar = tqdm(range(num_training_steps))


def train():
    """Training loop for fine-tuning Flan-T5."""
    log.info("Starting training")
    for epoch in range(config.epochs):
        # Training
        model.train()
        for batch in train_dataloader:
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                accelerator.backward(loss)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)

        # Evaluation
        model.eval()
        for batch in tqdm(eval_dataloader):
            with torch.no_grad():
                generated_tokens = accelerator.unwrap_model(model).generate(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    max_length=128,
                )
            labels = batch["labels"]

            # Necessary to pad predictions and labels for being gathered
            generated_tokens = accelerator.pad_across_processes(
                generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
            )
            labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)

            predictions_gathered = accelerator.gather(generated_tokens)
            labels_gathered = accelerator.gather(labels)

            decoded_preds, decoded_labels = postprocess(
                predictions_gathered, labels_gathered
            )
            metric.add_batch(predictions=decoded_preds, references=decoded_labels)

        results = metric.compute()
        mlflow.log_metrics({"epoch": epoch, "BLEU score": results["score"]})
        log.info("epoch %d, BLEU score: %.2f", epoch, results["score"])

        # Save and upload
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            config.output_dir, save_function=accelerator.save
        )
        if accelerator.is_main_process:
            tokenizer.save_pretrained(config.output_dir)
            # save model with mlflow
            mlflow.transformers.log_model(
                transformers_model={"model": unwrapped_model, "tokenizer": tokenizer},
                task="text2text-generation",
                artifact_path="seq2seq_model",
                registered_model_name="FlanT5_MIT",
            )


mlflow.set_tracking_uri("http://127.0.0.1:5000")
with mlflow.start_run():
    mlflow.log_params(asdict(config))
    train()
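The script hard-codes its tracking URI, so it only works when an MLflow server is already listening on 127.0.0.1:5000 (for example one started with `mlflow server --host 127.0.0.1 --port 5000`). A more portable variant would consult the environment first; a small sketch of that refactor, hypothetical and not part of this commit:

import os

import mlflow

# Prefer the standard MLFLOW_TRACKING_URI variable; fall back to the local
# server address the script currently assumes.
mlflow.set_tracking_uri(os.environ.get("MLFLOW_TRACKING_URI", "http://127.0.0.1:5000"))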
elise/src/utils/logger.py
CHANGED
@@ -4,7 +4,9 @@ Logging helper module
 import logging.config
 import yaml

-with open(
+with open(
+    "/home/kave/work/Elise/elise/src/configs/logging_config.yaml", "r", encoding="utf-8"
+) as f:
     config = yaml.safe_load(f.read())
 logging.config.dictConfig(config)
 logging.captureWarnings(True)
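The new `with open(...)` block points at an absolute path under /home/kave, so the logger only configures itself on that one machine. A portable sketch, assuming `configs/logging_config.yaml` stays a sibling of the `utils/` package:

import logging.config
from pathlib import Path

import yaml

# Resolve the config relative to this file instead of a hard-coded home directory.
CONFIG_PATH = Path(__file__).resolve().parent.parent / "configs" / "logging_config.yaml"

with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f.read())
logging.config.dictConfig(config)
logging.captureWarnings(True)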
requirements.txt
CHANGED
@@ -11,3 +11,5 @@ transformers==4.31.0
 pylint==2.17.5
 gradio==3.39.0
 gradio_client==0.3.0
+accelerate==0.21.0
+evaluate==0.4.0
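A quick smoke test confirms the two new pins import and expose the calls the training script relies on (a sketch; `evaluate.load("sacrebleu")` additionally needs the `sacrebleu` package at runtime):

import accelerate
import evaluate

# Versions should match the pins above.
assert accelerate.__version__ == "0.21.0"
assert evaluate.__version__ == "0.4.0"

metric = evaluate.load("sacrebleu")
result = metric.compute(predictions=["hello there"], references=[["hello there"]])
print(result["score"])  # identical strings score 100.0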