oracat commited on
Commit
73010ee
·
1 Parent(s): 14f225e

Upload finetuning-arxiv.ipynb

Browse files
Files changed (1) hide show
  1. finetuning-arxiv.ipynb +1256 -0
finetuning-arxiv.ipynb ADDED
@@ -0,0 +1,1256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "1c71aba7-c0f3-4378-9b63-55529e0994b4",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Data\n",
9
+ "\n",
10
+ "Мы используем следующий датасет для файнтюнинга:\n",
11
+ "\n",
12
+ "- [arXiv papers](https://www.kaggle.com/datasets/neelshah18/arxivdataset)\n",
13
+ "\n",
14
+ "Среди статей на arXiv есть также статьи по вычислительной биологии, геномике, etc.\n",
15
+ "\n",
16
+ "Среди альтернатив — [датасет](https://zenodo.org/record/7695390) из [недавнего исследования](https://www.biorxiv.org/content/10.1101/2023.04.10.536208v1.full.pdf) с названиями и лейблами статей из PubMed. В нём 20 миллионов статей, но приведены только заголовки (без абстрактов).\n",
17
+ "\n",
18
+ "В данном ноутбуке мы используем данные и теги с arXiv."
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "markdown",
23
+ "id": "e9874f4a-3898-4c89-a0f7-04eeabf2b389",
24
+ "metadata": {
25
+ "tags": []
26
+ },
27
+ "source": [
28
+ "# Models\n",
29
+ "\n",
30
+ "В качестве базовой модели мы используем BERT, натренированный на биомедицинских данных (из PubMed). \n",
31
+ "\n",
32
+ "- [BiomedNLP-PubMedBERT](https://huggingface.co/microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract)"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "markdown",
37
+ "id": "991e48e7-897f-45a3-8a0b-539ea67b4eb5",
38
+ "metadata": {},
39
+ "source": [
40
+ "---"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "markdown",
45
+ "id": "2f130f05-21ee-46f9-889f-488e8c676aba",
46
+ "metadata": {},
47
+ "source": [
48
+ "# Imports"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 1,
54
+ "id": "757a0582-1b8c-4f1c-b26f-544688e391f4",
55
+ "metadata": {
56
+ "tags": []
57
+ },
58
+ "outputs": [],
59
+ "source": [
60
+ "import torch\n",
61
+ "import transformers\n",
62
+ "import numpy as np\n",
63
+ "import pandas as pd\n",
64
+ "from tqdm import tqdm\n",
65
+ "\n",
66
+ "import torch\n",
67
+ "from datasets import Dataset, ClassLabel\n",
68
+ "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForSequenceClassification\n",
69
+ "from transformers import TrainingArguments, Trainer\n",
70
+ "from transformers import pipeline\n",
71
+ "import evaluate"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "markdown",
76
+ "id": "03847b87-d096-49a5-b6e2-023fa08b94c2",
77
+ "metadata": {},
78
+ "source": [
79
+ "# Load data"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "markdown",
84
+ "id": "b3e902ea-4e0f-4d76-b27b-59e472b2b556",
85
+ "metadata": {},
86
+ "source": [
87
+ "Загрузим данные для файнтюнинга — в частности, нам понадобятся названия статей, их абстракты и теги."
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": 2,
93
+ "id": "1be8f69e-bd7d-4ca9-ba9f-044b8e7bc497",
94
+ "metadata": {
95
+ "tags": []
96
+ },
97
+ "outputs": [],
98
+ "source": [
99
+ "df = pd.read_json(\"arxivData.json\")"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "markdown",
104
+ "id": "791edb3c-a96d-4042-b35d-c8097bbbef79",
105
+ "metadata": {},
106
+ "source": [
107
+ " "
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "markdown",
112
+ "id": "d5b6158a-728e-4ada-bcdc-a4a49328f002",
113
+ "metadata": {},
114
+ "source": [
115
+ "Совместим заголовки и абстракты и сохраним текст в соответствующей колонке:"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": 3,
121
+ "id": "c8709a7b-becf-4f19-8b4f-8773cd5c60f1",
122
+ "metadata": {
123
+ "tags": []
124
+ },
125
+ "outputs": [],
126
+ "source": [
127
+ "df['text'] = df['title'] + \"\\n\" + df['summary']"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": 4,
133
+ "id": "ed0ed687-6439-494a-a5a8-c572bc2e4059",
134
+ "metadata": {
135
+ "tags": []
136
+ },
137
+ "outputs": [
138
+ {
139
+ "data": {
140
+ "text/html": [
141
+ "<div>\n",
142
+ "<style scoped>\n",
143
+ " .dataframe tbody tr th:only-of-type {\n",
144
+ " vertical-align: middle;\n",
145
+ " }\n",
146
+ "\n",
147
+ " .dataframe tbody tr th {\n",
148
+ " vertical-align: top;\n",
149
+ " }\n",
150
+ "\n",
151
+ " .dataframe thead th {\n",
152
+ " text-align: right;\n",
153
+ " }\n",
154
+ "</style>\n",
155
+ "<table border=\"1\" class=\"dataframe\">\n",
156
+ " <thead>\n",
157
+ " <tr style=\"text-align: right;\">\n",
158
+ " <th></th>\n",
159
+ " <th>author</th>\n",
160
+ " <th>day</th>\n",
161
+ " <th>id</th>\n",
162
+ " <th>link</th>\n",
163
+ " <th>month</th>\n",
164
+ " <th>summary</th>\n",
165
+ " <th>tag</th>\n",
166
+ " <th>title</th>\n",
167
+ " <th>year</th>\n",
168
+ " <th>text</th>\n",
169
+ " </tr>\n",
170
+ " </thead>\n",
171
+ " <tbody>\n",
172
+ " <tr>\n",
173
+ " <th>0</th>\n",
174
+ " <td>[{'name': 'Ahmed Osman'}, {'name': 'Wojciech S...</td>\n",
175
+ " <td>1</td>\n",
176
+ " <td>1802.00209v1</td>\n",
177
+ " <td>[{'rel': 'alternate', 'href': 'http://arxiv.or...</td>\n",
178
+ " <td>2</td>\n",
179
+ " <td>We propose an architecture for VQA which utili...</td>\n",
180
+ " <td>[{'term': 'cs.AI', 'scheme': 'http://arxiv.org...</td>\n",
181
+ " <td>Dual Recurrent Attention Units for Visual Ques...</td>\n",
182
+ " <td>2018</td>\n",
183
+ " <td>Dual Recurrent Attention Units for Visual Ques...</td>\n",
184
+ " </tr>\n",
185
+ " <tr>\n",
186
+ " <th>1</th>\n",
187
+ " <td>[{'name': 'Ji Young Lee'}, {'name': 'Franck De...</td>\n",
188
+ " <td>12</td>\n",
189
+ " <td>1603.03827v1</td>\n",
190
+ " <td>[{'rel': 'alternate', 'href': 'http://arxiv.or...</td>\n",
191
+ " <td>3</td>\n",
192
+ " <td>Recent approaches based on artificial neural n...</td>\n",
193
+ " <td>[{'term': 'cs.CL', 'scheme': 'http://arxiv.org...</td>\n",
194
+ " <td>Sequential Short-Text Classification with Recu...</td>\n",
195
+ " <td>2016</td>\n",
196
+ " <td>Sequential Short-Text Classification with Recu...</td>\n",
197
+ " </tr>\n",
198
+ " </tbody>\n",
199
+ "</table>\n",
200
+ "</div>"
201
+ ],
202
+ "text/plain": [
203
+ " author day id \\\n",
204
+ "0 [{'name': 'Ahmed Osman'}, {'name': 'Wojciech S... 1 1802.00209v1 \n",
205
+ "1 [{'name': 'Ji Young Lee'}, {'name': 'Franck De... 12 1603.03827v1 \n",
206
+ "\n",
207
+ " link month \\\n",
208
+ "0 [{'rel': 'alternate', 'href': 'http://arxiv.or... 2 \n",
209
+ "1 [{'rel': 'alternate', 'href': 'http://arxiv.or... 3 \n",
210
+ "\n",
211
+ " summary \\\n",
212
+ "0 We propose an architecture for VQA which utili... \n",
213
+ "1 Recent approaches based on artificial neural n... \n",
214
+ "\n",
215
+ " tag \\\n",
216
+ "0 [{'term': 'cs.AI', 'scheme': 'http://arxiv.org... \n",
217
+ "1 [{'term': 'cs.CL', 'scheme': 'http://arxiv.org... \n",
218
+ "\n",
219
+ " title year \\\n",
220
+ "0 Dual Recurrent Attention Units for Visual Ques... 2018 \n",
221
+ "1 Sequential Short-Text Classification with Recu... 2016 \n",
222
+ "\n",
223
+ " text \n",
224
+ "0 Dual Recurrent Attention Units for Visual Ques... \n",
225
+ "1 Sequential Short-Text Classification with Recu... "
226
+ ]
227
+ },
228
+ "execution_count": 4,
229
+ "metadata": {},
230
+ "output_type": "execute_result"
231
+ }
232
+ ],
233
+ "source": [
234
+ "df.head(2)"
235
+ ]
236
+ },
237
+ {
238
+ "cell_type": "markdown",
239
+ "id": "ce1de806-a4d2-4e58-a3a8-f3542392f22e",
240
+ "metadata": {},
241
+ "source": [
242
+ "## Labels"
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "markdown",
247
+ "id": "b5183517-8b02-47bc-812a-415b5651e07d",
248
+ "metadata": {},
249
+ "source": [
250
+ "Будем использовать категории из arXiv'а, такие как `astro-ph` для статей по астрофизике или `cs.CV` для computer vision (computer science)."
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": 5,
256
+ "id": "ba4e7197-23b6-4cb4-9b44-620c6b730eb7",
257
+ "metadata": {
258
+ "tags": []
259
+ },
260
+ "outputs": [
261
+ {
262
+ "name": "stdout",
263
+ "output_type": "stream",
264
+ "text": [
265
+ "Total: 126 labels such as adap-org, astro-ph, ..., stat.OT\n"
266
+ ]
267
+ }
268
+ ],
269
+ "source": [
270
+ "df['category'] = [eval(i)[0]['term'].strip() for i in df['tag']]\n",
271
+ "categories = np.unique(df['category'])\n",
272
+ "num_labels = len(categories)\n",
273
+ "print(f\"Total: {num_labels} labels such as {categories[0]}, {categories[1]}, ..., {categories[-1]}\")"
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": 6,
279
+ "id": "1508a6d9-856d-4ecf-a0f3-895d3ffbe99b",
280
+ "metadata": {
281
+ "tags": []
282
+ },
283
+ "outputs": [
284
+ {
285
+ "data": {
286
+ "text/html": [
287
+ "<div>\n",
288
+ "<style scoped>\n",
289
+ " .dataframe tbody tr th:only-of-type {\n",
290
+ " vertical-align: middle;\n",
291
+ " }\n",
292
+ "\n",
293
+ " .dataframe tbody tr th {\n",
294
+ " vertical-align: top;\n",
295
+ " }\n",
296
+ "\n",
297
+ " .dataframe thead th {\n",
298
+ " text-align: right;\n",
299
+ " }\n",
300
+ "</style>\n",
301
+ "<table border=\"1\" class=\"dataframe\">\n",
302
+ " <thead>\n",
303
+ " <tr style=\"text-align: right;\">\n",
304
+ " <th></th>\n",
305
+ " <th>category</th>\n",
306
+ " <th>category_index</th>\n",
307
+ " </tr>\n",
308
+ " </thead>\n",
309
+ " <tbody>\n",
310
+ " <tr>\n",
311
+ " <th>0</th>\n",
312
+ " <td>adap-org</td>\n",
313
+ " <td>0</td>\n",
314
+ " </tr>\n",
315
+ " <tr>\n",
316
+ " <th>1</th>\n",
317
+ " <td>astro-ph</td>\n",
318
+ " <td>1</td>\n",
319
+ " </tr>\n",
320
+ " <tr>\n",
321
+ " <th>2</th>\n",
322
+ " <td>astro-ph.CO</td>\n",
323
+ " <td>2</td>\n",
324
+ " </tr>\n",
325
+ " <tr>\n",
326
+ " <th>3</th>\n",
327
+ " <td>astro-ph.EP</td>\n",
328
+ " <td>3</td>\n",
329
+ " </tr>\n",
330
+ " <tr>\n",
331
+ " <th>4</th>\n",
332
+ " <td>astro-ph.GA</td>\n",
333
+ " <td>4</td>\n",
334
+ " </tr>\n",
335
+ " </tbody>\n",
336
+ "</table>\n",
337
+ "</div>"
338
+ ],
339
+ "text/plain": [
340
+ " category category_index\n",
341
+ "0 adap-org 0\n",
342
+ "1 astro-ph 1\n",
343
+ "2 astro-ph.CO 2\n",
344
+ "3 astro-ph.EP 3\n",
345
+ "4 astro-ph.GA 4"
346
+ ]
347
+ },
348
+ "execution_count": 6,
349
+ "metadata": {},
350
+ "output_type": "execute_result"
351
+ }
352
+ ],
353
+ "source": [
354
+ "pd.DataFrame({\n",
355
+ " \"category\": categories,\n",
356
+ " \"category_index\": np.arange(num_labels),\n",
357
+ "}).head()"
358
+ ]
359
+ },
360
+ {
361
+ "cell_type": "code",
362
+ "execution_count": 7,
363
+ "id": "5c082c3a-7b0e-4320-b62d-f75a6c9f2398",
364
+ "metadata": {
365
+ "tags": []
366
+ },
367
+ "outputs": [],
368
+ "source": [
369
+ "df = pd.DataFrame({\n",
370
+ " \"category\": categories,\n",
371
+ " \"category_index\": np.arange(num_labels),\n",
372
+ "}).set_index(\"category\").join(df.set_index(\"category\"), how=\"right\", sort=False).reset_index()"
373
+ ]
374
+ },
375
+ {
376
+ "cell_type": "markdown",
377
+ "id": "76d8ccb9-a993-4d82-9dd3-689380e92e55",
378
+ "metadata": {},
379
+ "source": [
380
+ "# Model"
381
+ ]
382
+ },
383
+ {
384
+ "cell_type": "code",
385
+ "execution_count": 8,
386
+ "id": "a0c154f7-d2fa-46a1-8b69-57174bf00632",
387
+ "metadata": {
388
+ "tags": []
389
+ },
390
+ "outputs": [],
391
+ "source": [
392
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
393
+ "print(device)"
394
+ ]
395
+ },
396
+ {
397
+ "cell_type": "markdown",
398
+ "id": "2bf6513d-664d-4b94-8b05-7e8df205e3ec",
399
+ "metadata": {},
400
+ "source": [
401
+ "Токенайзер (название + абстракт -> токены):"
402
+ ]
403
+ },
404
+ {
405
+ "cell_type": "code",
406
+ "execution_count": 9,
407
+ "id": "12fa49a7-2ac5-4f78-84fe-93305926692e",
408
+ "metadata": {
409
+ "tags": []
410
+ },
411
+ "outputs": [],
412
+ "source": [
413
+ "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\")"
414
+ ]
415
+ },
416
+ {
417
+ "cell_type": "markdown",
418
+ "id": "0ea1b4e5-9067-4292-ba12-8f560bbf26fd",
419
+ "metadata": {},
420
+ "source": [
421
+ "Сама модель, в которой `AutoModelForSequenceClassification` заменит голову для задачи классификации:"
422
+ ]
423
+ },
424
+ {
425
+ "cell_type": "code",
426
+ "execution_count": 10,
427
+ "id": "d6eb92bc-c293-47ad-b9cc-2a63e8f1de69",
428
+ "metadata": {
429
+ "tags": []
430
+ },
431
+ "outputs": [
432
+ {
433
+ "name": "stderr",
434
+ "output_type": "stream",
435
+ "text": [
436
+ "Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']\n",
437
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
438
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
439
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
440
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
441
+ ]
442
+ }
443
+ ],
444
+ "source": [
445
+ "model = AutoModelForSequenceClassification.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\", num_labels=num_labels).to(device)"
446
+ ]
447
+ },
448
+ {
449
+ "cell_type": "code",
450
+ "execution_count": 11,
451
+ "id": "f5c79846-e6fc-42c0-bb8d-949678f5e60a",
452
+ "metadata": {
453
+ "scrolled": true,
454
+ "tags": []
455
+ },
456
+ "outputs": [
457
+ {
458
+ "name": "stdout",
459
+ "output_type": "stream",
460
+ "text": [
461
+ "BertForSequenceClassification(\n",
462
+ " (bert): BertModel(\n",
463
+ " (embeddings): BertEmbeddings(\n",
464
+ " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
465
+ " (position_embeddings): Embedding(512, 768)\n",
466
+ " (token_type_embeddings): Embedding(2, 768)\n",
467
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
468
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
469
+ " )\n",
470
+ " (encoder): BertEncoder(\n",
471
+ " (layer): ModuleList(\n",
472
+ " (0-11): 12 x BertLayer(\n",
473
+ " (attention): BertAttention(\n",
474
+ " (self): BertSelfAttention(\n",
475
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
476
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
477
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
478
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
479
+ " )\n",
480
+ " (output): BertSelfOutput(\n",
481
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
482
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
483
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
484
+ " )\n",
485
+ " )\n",
486
+ " (intermediate): BertIntermediate(\n",
487
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
488
+ " (intermediate_act_fn): GELUActivation()\n",
489
+ " )\n",
490
+ " (output): BertOutput(\n",
491
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
492
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
493
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
494
+ " )\n",
495
+ " )\n",
496
+ " )\n",
497
+ " )\n",
498
+ " (pooler): BertPooler(\n",
499
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
500
+ " (activation): Tanh()\n",
501
+ " )\n",
502
+ " )\n",
503
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
504
+ " (classifier): Linear(in_features=768, out_features=126, bias=True)\n",
505
+ ")\n"
506
+ ]
507
+ }
508
+ ],
509
+ "source": [
510
+ "print(model)"
511
+ ]
512
+ },
513
+ {
514
+ "cell_type": "markdown",
515
+ "id": "5ce6eefc-91ce-4486-9568-b686d04adcc7",
516
+ "metadata": {},
517
+ "source": [
518
+ "# Training"
519
+ ]
520
+ },
521
+ {
522
+ "cell_type": "markdown",
523
+ "id": "71add72c-eafb-491a-8820-31ce7336524f",
524
+ "metadata": {},
525
+ "source": [
526
+ "## Data Loaders"
527
+ ]
528
+ },
529
+ {
530
+ "cell_type": "markdown",
531
+ "id": "2a0b579c-998a-4d2e-bf0e-d4c7406d22da",
532
+ "metadata": {},
533
+ "source": [
534
+ "Для работы с `transformers`, возможно, будет удобнее использовать библиотеку `datasets` для работы с данными."
535
+ ]
536
+ },
537
+ {
538
+ "cell_type": "markdown",
539
+ "id": "47b0e14a-866b-49ac-8b95-49a91a0bcc22",
540
+ "metadata": {},
541
+ "source": [
542
+ "Создадим (hugging face) [датасет](https://huggingface.co/docs/datasets/tabular_load#pandas-dataframes):"
543
+ ]
544
+ },
545
+ {
546
+ "cell_type": "code",
547
+ "execution_count": 13,
548
+ "id": "dc1a3f33-0ef9-43c9-ab5f-eb9ae304b897",
549
+ "metadata": {
550
+ "tags": []
551
+ },
552
+ "outputs": [],
553
+ "source": [
554
+ "np.random.seed(42)\n",
555
+ "train_indices = np.sort(np.random.choice(np.arange(len(df)), size=37_000, replace=False))\n",
556
+ "test_indices = np.array([i for i in np.arange(len(df)) if i not in train_indices])"
557
+ ]
558
+ },
559
+ {
560
+ "cell_type": "code",
561
+ "execution_count": 14,
562
+ "id": "d948f8a6-1a7a-4baa-88a0-418596a1f275",
563
+ "metadata": {
564
+ "tags": []
565
+ },
566
+ "outputs": [],
567
+ "source": [
568
+ "train_df = df.loc[:,[\"text\", \"category\"]].iloc[train_indices]\n",
569
+ "test_df = df.loc[:,[\"text\", \"category\"]].iloc[test_indices]\n",
570
+ "\n",
571
+ "train_ds = Dataset.from_pandas(train_df, split=\"train\")\n",
572
+ "test_ds = Dataset.from_pandas(test_df, split=\"test\")"
573
+ ]
574
+ },
575
+ {
576
+ "cell_type": "code",
577
+ "execution_count": 15,
578
+ "id": "50242a35-3067-41e5-8de8-f7e6a4fb6e9c",
579
+ "metadata": {
580
+ "tags": []
581
+ },
582
+ "outputs": [
583
+ {
584
+ "data": {
585
+ "application/vnd.jupyter.widget-view+json": {
586
+ "model_id": "",
587
+ "version_major": 2,
588
+ "version_minor": 0
589
+ },
590
+ "text/plain": [
591
+ "Map: 0%| | 0/37000 [00:00<?, ? examples/s]"
592
+ ]
593
+ },
594
+ "metadata": {},
595
+ "output_type": "display_data"
596
+ },
597
+ {
598
+ "data": {
599
+ "application/vnd.jupyter.widget-view+json": {
600
+ "model_id": "",
601
+ "version_major": 2,
602
+ "version_minor": 0
603
+ },
604
+ "text/plain": [
605
+ "Map: 0%| | 0/4000 [00:00<?, ? examples/s]"
606
+ ]
607
+ },
608
+ "metadata": {},
609
+ "output_type": "display_data"
610
+ }
611
+ ],
612
+ "source": [
613
+ "def tokenize_text(row):\n",
614
+ " return tokenizer(\n",
615
+ " row[\"text\"],\n",
616
+ " max_length=512,\n",
617
+ " truncation=True,\n",
618
+ " padding='max_length',\n",
619
+ " )\n",
620
+ "\n",
621
+ "train_ds = train_ds.map(tokenize_text, batched=True)\n",
622
+ "test_ds = test_ds.map(tokenize_text, batched=True)"
623
+ ]
624
+ },
625
+ {
626
+ "cell_type": "code",
627
+ "execution_count": 77,
628
+ "id": "35d454d1-fbdc-4847-8b60-4c6c442364b1",
629
+ "metadata": {
630
+ "tags": []
631
+ },
632
+ "outputs": [
633
+ {
634
+ "data": {
635
+ "application/vnd.jupyter.widget-view+json": {
636
+ "model_id": "",
637
+ "version_major": 2,
638
+ "version_minor": 0
639
+ },
640
+ "text/plain": [
641
+ "Map: 0%| | 0/37000 [00:00<?, ? examples/s]"
642
+ ]
643
+ },
644
+ "metadata": {},
645
+ "output_type": "display_data"
646
+ },
647
+ {
648
+ "data": {
649
+ "application/vnd.jupyter.widget-view+json": {
650
+ "model_id": "",
651
+ "version_major": 2,
652
+ "version_minor": 0
653
+ },
654
+ "text/plain": [
655
+ "Map: 0%| | 0/4000 [00:00<?, ? examples/s]"
656
+ ]
657
+ },
658
+ "metadata": {},
659
+ "output_type": "display_data"
660
+ },
661
+ {
662
+ "data": {
663
+ "application/vnd.jupyter.widget-view+json": {
664
+ "model_id": "",
665
+ "version_major": 2,
666
+ "version_minor": 0
667
+ },
668
+ "text/plain": [
669
+ "Casting the dataset: 0%| | 0/37000 [00:00<?, ? examples/s]"
670
+ ]
671
+ },
672
+ "metadata": {},
673
+ "output_type": "display_data"
674
+ },
675
+ {
676
+ "data": {
677
+ "application/vnd.jupyter.widget-view+json": {
678
+ "model_id": "",
679
+ "version_major": 2,
680
+ "version_minor": 0
681
+ },
682
+ "text/plain": [
683
+ "Casting the dataset: 0%| | 0/4000 [00:00<?, ? examples/s]"
684
+ ]
685
+ },
686
+ "metadata": {},
687
+ "output_type": "display_data"
688
+ }
689
+ ],
690
+ "source": [
691
+ "labels_map = ClassLabel(num_classes=num_labels, names=list(categories))\n",
692
+ "\n",
693
+ "def transform_labels(row):\n",
694
+ " # default name for a label (label or label_ids)\n",
695
+ " return {\"label\": labels_map.str2int(row[\"category\"])}\n",
696
+ "\n",
697
+ "# OR: \n",
698
+ "# \n",
699
+ "# labels_map = pd.Series(\n",
700
+ "# np.arange(num_labels),\n",
701
+ "# index=categories,\n",
702
+ "# )\n",
703
+ "# \n",
704
+ "# def transform_labels(row):\n",
705
+ "# return {\"label\": labels_map[row[\"category\"]]}\n",
706
+ "\n",
707
+ "train_ds = train_ds.map(transform_labels, batched=True)\n",
708
+ "test_ds = test_ds.map(transform_labels, batched=True)\n",
709
+ "\n",
710
+ "train_ds = train_ds.cast_column('label', labels_map)\n",
711
+ "test_ds = test_ds.cast_column('label', labels_map)"
712
+ ]
713
+ },
714
+ {
715
+ "cell_type": "markdown",
716
+ "id": "6f3862ef-ed78-461f-ba68-8f059f01d355",
717
+ "metadata": {},
718
+ "source": [
719
+ " "
720
+ ]
721
+ },
722
+ {
723
+ "cell_type": "markdown",
724
+ "id": "811c5fe3-218e-4187-878d-65abc157f802",
725
+ "metadata": {},
726
+ "source": [
727
+ "## Prepare training"
728
+ ]
729
+ },
730
+ {
731
+ "cell_type": "code",
732
+ "execution_count": 110,
733
+ "id": "d2160c7d-4130-47ae-9d6d-6684e4ba7e9b",
734
+ "metadata": {
735
+ "tags": []
736
+ },
737
+ "outputs": [
738
+ {
739
+ "name": "stderr",
740
+ "output_type": "stream",
741
+ "text": [
742
+ "Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']\n",
743
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
744
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
745
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
746
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
747
+ ]
748
+ }
749
+ ],
750
+ "source": [
751
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
752
+ " \"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\", \n",
753
+ " num_labels=num_labels,\n",
754
+ " id2label={i:labels_map.names[i] for i in range(len(categories))},\n",
755
+ " label2id={labels_map.names[i]:i for i in range(len(categories))},\n",
756
+ ").to(device)"
757
+ ]
758
+ },
759
+ {
760
+ "cell_type": "code",
761
+ "execution_count": 111,
762
+ "id": "72e74c2b-89d7-4c17-8df1-dcfd40ead01e",
763
+ "metadata": {
764
+ "tags": []
765
+ },
766
+ "outputs": [],
767
+ "source": [
768
+ "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract\")"
769
+ ]
770
+ },
771
+ {
772
+ "cell_type": "markdown",
773
+ "id": "ebb91037-fbdf-4453-87de-6da5eec3304f",
774
+ "metadata": {},
775
+ "source": [
776
+ "Будем вычислять accuracy:"
777
+ ]
778
+ },
779
+ {
780
+ "cell_type": "code",
781
+ "execution_count": 112,
782
+ "id": "630f6fa5-4c53-4962-b36d-5ee9aad6e29d",
783
+ "metadata": {
784
+ "tags": []
785
+ },
786
+ "outputs": [],
787
+ "source": [
788
+ "metric = evaluate.load(\"accuracy\")\n",
789
+ "\n",
790
+ "def compute_metrics(eval_pred):\n",
791
+ " logits, labels = eval_pred\n",
792
+ " predictions = np.argmax(logits, axis=-1)\n",
793
+ " return metric.compute(predictions=predictions, references=labels)"
794
+ ]
795
+ },
796
+ {
797
+ "cell_type": "code",
798
+ "execution_count": 113,
799
+ "id": "f64425b7-72b7-466a-8e3e-cd7624893139",
800
+ "metadata": {
801
+ "tags": []
802
+ },
803
+ "outputs": [],
804
+ "source": [
805
+ "training_args = TrainingArguments(\n",
806
+ " output_dir=\"bert-paper-classifier-arxiv\", \n",
807
+ " evaluation_strategy=\"epoch\",\n",
808
+ " per_device_train_batch_size=64,\n",
809
+ " num_train_epochs=10,\n",
810
+ " logging_steps=10,\n",
811
+ ")"
812
+ ]
813
+ },
814
+ {
815
+ "cell_type": "code",
816
+ "execution_count": 114,
817
+ "id": "b850cd9b-eb36-40ec-8cf2-26206fedcf27",
818
+ "metadata": {
819
+ "tags": []
820
+ },
821
+ "outputs": [],
822
+ "source": [
823
+ "trainer = Trainer(\n",
824
+ " model=model,\n",
825
+ " args=training_args,\n",
826
+ " train_dataset=train_ds,\n",
827
+ " eval_dataset=test_ds,\n",
828
+ " compute_metrics=compute_metrics,\n",
829
+ ")"
830
+ ]
831
+ },
832
+ {
833
+ "cell_type": "code",
834
+ "execution_count": null,
835
+ "id": "e6b88166-d82e-4502-acef-494fbb206d30",
836
+ "metadata": {},
837
+ "outputs": [],
838
+ "source": [
839
+ "trainer.train()"
840
+ ]
841
+ },
842
+ {
843
+ "cell_type": "code",
844
+ "execution_count": null,
845
+ "id": "7ed8c94a-e3ef-47f9-96a8-c112eb7f11bc",
846
+ "metadata": {},
847
+ "outputs": [],
848
+ "source": [
849
+ "# Convert to a python file and run training:\n",
850
+ "#! jupyter nbconvert finetuning-arxiv.ipynb --to python"
851
+ ]
852
+ },
853
+ {
854
+ "cell_type": "markdown",
855
+ "id": "cc8dad7d-8105-4f37-9087-615314c35afb",
856
+ "metadata": {},
857
+ "source": [
858
+ "# Save and share"
859
+ ]
860
+ },
861
+ {
862
+ "cell_type": "code",
863
+ "execution_count": 37,
864
+ "id": "38d24722-d5c6-40ac-b568-3cd7fd9f225e",
865
+ "metadata": {
866
+ "tags": []
867
+ },
868
+ "outputs": [],
869
+ "source": [
870
+ "trainer.args.hub_model_id = \"bert-paper-classifier-arxiv\""
871
+ ]
872
+ },
873
+ {
874
+ "cell_type": "code",
875
+ "execution_count": 50,
876
+ "id": "9530790c-bc63-48f4-9a01-8c534fa90e00",
877
+ "metadata": {
878
+ "tags": []
879
+ },
880
+ "outputs": [
881
+ {
882
+ "data": {
883
+ "text/plain": [
884
+ "('bert-paper-classifier/tokenizer_config.json',\n",
885
+ " 'bert-paper-classifier/special_tokens_map.json',\n",
886
+ " 'bert-paper-classifier/vocab.txt',\n",
887
+ " 'bert-paper-classifier/added_tokens.json',\n",
888
+ " 'bert-paper-classifier/tokenizer.json')"
889
+ ]
890
+ },
891
+ "execution_count": 50,
892
+ "metadata": {},
893
+ "output_type": "execute_result"
894
+ }
895
+ ],
896
+ "source": [
897
+ "tokenizer.save_pretrained(\"bert-paper-classifier-arxiv\")"
898
+ ]
899
+ },
900
+ {
901
+ "cell_type": "code",
902
+ "execution_count": 116,
903
+ "id": "0498df97-cd2c-4732-9d07-ee2013f8bd55",
904
+ "metadata": {
905
+ "tags": []
906
+ },
907
+ "outputs": [],
908
+ "source": [
909
+ "trainer.save_model(\"bert-paper-classifier-arxiv\")"
910
+ ]
911
+ },
912
+ {
913
+ "cell_type": "markdown",
914
+ "id": "7af12b9e-0d77-48ec-af6f-38556e13b067",
915
+ "metadata": {
916
+ "tags": []
917
+ },
918
+ "source": [
919
+ "Запушим модель на HF Hub:"
920
+ ]
921
+ },
922
+ {
923
+ "cell_type": "code",
924
+ "execution_count": 51,
925
+ "id": "5de0e91f-bc23-4413-b22e-5aa32b09ef12",
926
+ "metadata": {
927
+ "scrolled": true,
928
+ "tags": []
929
+ },
930
+ "outputs": [
931
+ {
932
+ "name": "stdout",
933
+ "output_type": "stream",
934
+ "text": [
935
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
936
+ "To disable this warning, you can either:\n",
937
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
938
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
939
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
940
+ "To disable this warning, you can either:\n",
941
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
942
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
943
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
944
+ "To disable this warning, you can either:\n",
945
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
946
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
947
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
948
+ "To disable this warning, you can either:\n",
949
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
950
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
951
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
952
+ "To disable this warning, you can either:\n",
953
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
954
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
955
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
956
+ "To disable this warning, you can either:\n",
957
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
958
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
959
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
960
+ "To disable this warning, you can either:\n",
961
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
962
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
963
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
964
+ "To disable this warning, you can either:\n",
965
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
966
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
967
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
968
+ "To disable this warning, you can either:\n",
969
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
970
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
971
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
972
+ "To disable this warning, you can either:\n",
973
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
974
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
975
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
976
+ "To disable this warning, you can either:\n",
977
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
978
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
979
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
980
+ "To disable this warning, you can either:\n",
981
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
982
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
983
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
984
+ "To disable this warning, you can either:\n",
985
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
986
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
987
+ ]
988
+ },
989
+ {
990
+ "name": "stderr",
991
+ "output_type": "stream",
992
+ "text": [
993
+ "To https://huggingface.co/oracat/bert-paper-classifier\n",
994
+ " 915ccf0..862abb7 main -> main\n",
995
+ "\n"
996
+ ]
997
+ },
998
+ {
999
+ "name": "stdout",
1000
+ "output_type": "stream",
1001
+ "text": [
1002
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
1003
+ "To disable this warning, you can either:\n",
1004
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
1005
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
1006
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
1007
+ "To disable this warning, you can either:\n",
1008
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
1009
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
1010
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
1011
+ "To disable this warning, you can either:\n",
1012
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
1013
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
1014
+ ]
1015
+ }
1016
+ ],
1017
+ "source": [
1018
+ "trainer.push_to_hub()"
1019
+ ]
1020
+ },
1021
+ {
1022
+ "cell_type": "markdown",
1023
+ "id": "5093aee3-106e-43e9-a9c7-413d059ebb27",
1024
+ "metadata": {},
1025
+ "source": [
1026
+ " "
1027
+ ]
1028
+ },
1029
+ {
1030
+ "cell_type": "markdown",
1031
+ "id": "b1a1029f-543c-409e-9aaf-35bcefe49988",
1032
+ "metadata": {},
1033
+ "source": [
1034
+ "# Inference"
1035
+ ]
1036
+ },
1037
+ {
1038
+ "cell_type": "markdown",
1039
+ "id": "e7b0cd5a-2e17-49f3-b2a9-5ae4e8511969",
1040
+ "metadata": {},
1041
+ "source": [
1042
+ "Теперь попробуем загрузить модель с HF Hub:"
1043
+ ]
1044
+ },
1045
+ {
1046
+ "cell_type": "code",
1047
+ "execution_count": 2,
1048
+ "id": "b7fe37b9-61a9-4796-af24-092f6722cd61",
1049
+ "metadata": {
1050
+ "tags": []
1051
+ },
1052
+ "outputs": [
1053
+ {
1054
+ "data": {
1055
+ "application/vnd.jupyter.widget-view+json": {
1056
+ "model_id": "36afc9d465f54c80ab01698f5a687388",
1057
+ "version_major": 2,
1058
+ "version_minor": 0
1059
+ },
1060
+ "text/plain": [
1061
+ "Downloading (…)okenizer_config.json: 0%| | 0.00/394 [00:00<?, ?B/s]"
1062
+ ]
1063
+ },
1064
+ "metadata": {},
1065
+ "output_type": "display_data"
1066
+ },
1067
+ {
1068
+ "data": {
1069
+ "application/vnd.jupyter.widget-view+json": {
1070
+ "model_id": "df18b9d22fc14a0c81e8cb557f88a848",
1071
+ "version_major": 2,
1072
+ "version_minor": 0
1073
+ },
1074
+ "text/plain": [
1075
+ "Downloading (…)solve/main/vocab.txt: 0%| | 0.00/225k [00:00<?, ?B/s]"
1076
+ ]
1077
+ },
1078
+ "metadata": {},
1079
+ "output_type": "display_data"
1080
+ },
1081
+ {
1082
+ "data": {
1083
+ "application/vnd.jupyter.widget-view+json": {
1084
+ "model_id": "4ba2236cf89d4159bcc9740d4654b16d",
1085
+ "version_major": 2,
1086
+ "version_minor": 0
1087
+ },
1088
+ "text/plain": [
1089
+ "Downloading (…)/main/tokenizer.json: 0%| | 0.00/679k [00:00<?, ?B/s]"
1090
+ ]
1091
+ },
1092
+ "metadata": {},
1093
+ "output_type": "display_data"
1094
+ },
1095
+ {
1096
+ "data": {
1097
+ "application/vnd.jupyter.widget-view+json": {
1098
+ "model_id": "cae249ea1c2946a89fffdb80ff1d7b7b",
1099
+ "version_major": 2,
1100
+ "version_minor": 0
1101
+ },
1102
+ "text/plain": [
1103
+ "Downloading (…)cial_tokens_map.json: 0%| | 0.00/125 [00:00<?, ?B/s]"
1104
+ ]
1105
+ },
1106
+ "metadata": {},
1107
+ "output_type": "display_data"
1108
+ },
1109
+ {
1110
+ "data": {
1111
+ "application/vnd.jupyter.widget-view+json": {
1112
+ "model_id": "b860284eb1ff4cb08b5c8d54ab1a33b9",
1113
+ "version_major": 2,
1114
+ "version_minor": 0
1115
+ },
1116
+ "text/plain": [
1117
+ "Downloading (…)lve/main/config.json: 0%| | 0.00/6.04k [00:00<?, ?B/s]"
1118
+ ]
1119
+ },
1120
+ "metadata": {},
1121
+ "output_type": "display_data"
1122
+ },
1123
+ {
1124
+ "data": {
1125
+ "application/vnd.jupyter.widget-view+json": {
1126
+ "model_id": "3607b2b6f85b49b0a03844df69077d7e",
1127
+ "version_major": 2,
1128
+ "version_minor": 0
1129
+ },
1130
+ "text/plain": [
1131
+ "Downloading pytorch_model.bin: 0%| | 0.00/438M [00:00<?, ?B/s]"
1132
+ ]
1133
+ },
1134
+ "metadata": {},
1135
+ "output_type": "display_data"
1136
+ }
1137
+ ],
1138
+ "source": [
1139
+ "inference_tokenizer = AutoTokenizer.from_pretrained(\"oracat/bert-paper-classifier-arxiv\")\n",
1140
+ "inference_model = AutoModelForSequenceClassification.from_pretrained(\"oracat/bert-paper-classifier-arxiv\")"
1141
+ ]
1142
+ },
1143
+ {
1144
+ "cell_type": "code",
1145
+ "execution_count": 3,
1146
+ "id": "34495235-4dca-4635-b468-5b15647a6682",
1147
+ "metadata": {
1148
+ "tags": []
1149
+ },
1150
+ "outputs": [],
1151
+ "source": [
1152
+ "pipe = pipeline(\"text-classification\", model=inference_model, tokenizer=inference_tokenizer, top_k=None)"
1153
+ ]
1154
+ },
1155
+ {
1156
+ "cell_type": "code",
1157
+ "execution_count": 4,
1158
+ "id": "052b5070-c1ee-4419-8a6d-127925c95cce",
1159
+ "metadata": {
1160
+ "tags": []
1161
+ },
1162
+ "outputs": [],
1163
+ "source": [
1164
+ "def top_pct(preds, threshold=.95):\n",
1165
+ " preds = sorted(preds, key=lambda x: -x[\"score\"])\n",
1166
+ " \n",
1167
+ " cum_score = 0\n",
1168
+ " for i, item in enumerate(preds):\n",
1169
+ " cum_score += item[\"score\"]\n",
1170
+ " if cum_score >= threshold:\n",
1171
+ " break\n",
1172
+ "\n",
1173
+ " preds = preds[:(i+1)]\n",
1174
+ " \n",
1175
+ " return preds"
1176
+ ]
1177
+ },
1178
+ {
1179
+ "cell_type": "code",
1180
+ "execution_count": 5,
1181
+ "id": "ed3545b6-e043-4dfb-aeb2-7559eac37f7c",
1182
+ "metadata": {
1183
+ "tags": []
1184
+ },
1185
+ "outputs": [],
1186
+ "source": [
1187
+ "def format_predictions(preds) -> str:\n",
1188
+ " \"\"\"\n",
1189
+ " Prepare predictions and their scores for printing to the user\n",
1190
+ " \"\"\"\n",
1191
+ " out = \"\"\n",
1192
+ " for i, item in enumerate(preds):\n",
1193
+ " out += f\"{i+1}. {item['label']} (score {item['score']:.2f})\\n\"\n",
1194
+ " return out"
1195
+ ]
1196
+ },
1197
+ {
1198
+ "cell_type": "code",
1199
+ "execution_count": 9,
1200
+ "id": "870d593a-a298-4d55-87b0-cb2813cc1fad",
1201
+ "metadata": {
1202
+ "tags": []
1203
+ },
1204
+ "outputs": [
1205
+ {
1206
+ "name": "stdout",
1207
+ "output_type": "stream",
1208
+ "text": [
1209
+ "1. cs.LG (score 0.88)\n",
1210
+ "2. cs.AI (score 0.07)\n",
1211
+ "3. cs.NE (score 0.03)\n",
1212
+ "\n"
1213
+ ]
1214
+ }
1215
+ ],
1216
+ "source": [
1217
+ "print(\n",
1218
+ " format_predictions(\n",
1219
+ " top_pct(\n",
1220
+ " pipe(\"Attention Is All You Need\\nThe dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration.\")[0]\n",
1221
+ " )\n",
1222
+ " )\n",
1223
+ ")"
1224
+ ]
1225
+ },
1226
+ {
1227
+ "cell_type": "markdown",
1228
+ "id": "408f015e-be23-46a6-9e91-503fdccecf11",
1229
+ "metadata": {},
1230
+ "source": [
1231
+ " "
1232
+ ]
1233
+ }
1234
+ ],
1235
+ "metadata": {
1236
+ "kernelspec": {
1237
+ "display_name": "Python 3 (ipykernel)",
1238
+ "language": "python",
1239
+ "name": "python3"
1240
+ },
1241
+ "language_info": {
1242
+ "codemirror_mode": {
1243
+ "name": "ipython",
1244
+ "version": 3
1245
+ },
1246
+ "file_extension": ".py",
1247
+ "mimetype": "text/x-python",
1248
+ "name": "python",
1249
+ "nbconvert_exporter": "python",
1250
+ "pygments_lexer": "ipython3",
1251
+ "version": "3.10.8"
1252
+ }
1253
+ },
1254
+ "nbformat": 4,
1255
+ "nbformat_minor": 5
1256
+ }