ElisonSherton committed
Commit 5124462 · 1 Parent(s): 27067f9

Added the notebook which created this finetuned model

Files changed (1)
  1. custom-ner.ipynb +784 -0
custom-ner.ipynb ADDED
@@ -0,0 +1,784 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
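+ "# Imports: transformers for the model, tokenizer and Trainer; datasets for data loading; evaluate for the seqeval NER metric\n",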
+ "from transformers import AutoModelForTokenClassification\n",
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "from datasets import load_dataset\n",
+ "from pprint import pprint\n",
+ "from collections import Counter\n",
+ "import random\n",
+ "import evaluate\n",
+ "import numpy as np\n",
+ "\n",
+ "import os\n",
+ "from huggingface_hub import login\n",
+ "from transformers import TrainingArguments, Trainer\n",
+ "from transformers import DataCollatorForTokenClassification"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define the checkpoint and read the Hugging Face token used to upload the model to the Hub\n",
+ "checkpoint = \"bert-base-cased\"\n",
+ "os.environ[\"HF_TOKEN\"] = open(\n",
+ " \"/home/hf/hf-course/chapter7/hf-token.txt\", \"r\").readlines()[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['text', 'entities', 'entities-suggestion', 'entities-suggestion-metadata', 'external_id', 'metadata'],\n",
+ " num_rows: 8528\n",
+ " })\n",
+ " validation: Dataset({\n",
+ " features: ['text', 'entities', 'entities-suggestion', 'entities-suggestion-metadata', 'external_id', 'metadata'],\n",
+ " num_rows: 8528\n",
+ " })\n",
+ "})"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Load the dataset\n",
+ "dataset = load_dataset(\"louisguitton/dev-ner-ontonotes\")\n",
+ "dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'entities': [],\n",
+ " 'entities-suggestion': {'end': [30],\n",
+ " 'label': ['PERSON'],\n",
+ " 'score': [1.0],\n",
+ " 'start': [23],\n",
+ " 'text': ['Camilla']},\n",
+ " 'entities-suggestion-metadata': {'agent': 'gold_labels',\n",
+ " 'score': None,\n",
+ " 'type': None},\n",
+ " 'external_id': None,\n",
+ " 'metadata': '{}',\n",
+ " 'text': 'The horse is basically Camilla /.'}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Have a look at one random sample from the dataset\n",
+ "pprint(dataset[\"train\"].shuffle().take(1)[0])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['O', 'B-CARDINAL', 'I-CARDINAL', 'B-DATE', 'I-DATE', 'B-EVENT', 'I-EVENT', 'B-FAC', 'I-FAC', 'B-GPE', 'I-GPE', 'B-LANGUAGE', 'I-LANGUAGE', 'B-LAW', 'I-LAW', 'B-LOC', 'I-LOC', 'B-MONEY', 'I-MONEY', 'B-NORP', 'I-NORP', 'B-ORDINAL', 'I-ORDINAL', 'B-ORG', 'I-ORG', 'B-PERCENT', 'I-PERCENT', 'B-PERSON', 'I-PERSON', 'B-PRODUCT', 'I-PRODUCT', 'B-QUANTITY', 'I-QUANTITY', 'B-TIME', 'I-TIME', 'B-WORK_OF_ART', 'I-WORK_OF_ART']\n",
+ "Counter({'GPE': 2268, 'PERSON': 2020, 'ORG': 1740, 'DATE': 1507, 'CARDINAL': 938, 'NORP': 847, 'MONEY': 274, 'ORDINAL': 232, 'TIME': 214, 'LOC': 204, 'PERCENT': 177, 'EVENT': 143, 'WORK_OF_ART': 142, 'FAC': 115, 'QUANTITY': 100, 'PRODUCT': 72, 'LAW': 40, 'LANGUAGE': 33})\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Have a look at the distribution of all the labels\n",
+ "entity_types = []\n",
+ "\n",
+ "for element in dataset[\"train\"]:\n",
+ " entity_types.extend(element[\"entities-suggestion\"][\"label\"])\n",
+ "\n",
+ "entities = sorted(set(entity_types))\n",
+ "final_entities = [\"O\"]\n",
+ "for entity in entities:\n",
+ " final_entities.extend([f\"B-{entity}\", f\"I-{entity}\"])\n",
+ "print(final_entities)\n",
+ "print(Counter(entity_types))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create a couple of dictionaries to map all the entity labels to integer ids and vice versa\n",
+ "id2label = {i: label for i, label in enumerate(final_entities)}\n",
+ "label2id = {v: k for k, v in id2label.items()}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Create the tokenizer\n",
+ "tokenizer = AutoTokenizer.from_pretrained(checkpoint)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "BertTokenizerFast(name_or_path='bert-base-cased', vocab_size=28996, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Have a look at the tokenizer\n",
+ "pprint(tokenizer)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Tokenize one sample and check what is returned\n",
+ "output = tokenizer(dataset[\"train\"][0][\"text\"], return_offsets_mapping=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping'])"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
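+ "# The tokenizer also returns an offset mapping: (start, end) character spans for each token, used below to align entity spans with tokens\n",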
+ "output.keys()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'start': [2, 40, 53, 108, 122],\n",
+ " 'end': [9, 45, 56, 113, 137],\n",
+ " 'label': ['NORP', 'CARDINAL', 'CARDINAL', 'PRODUCT', 'LOC'],\n",
+ " 'text': ['Russian', 'three', '118', 'Kursk', 'the Barents Sea'],\n",
+ " 'score': [1.0, 1.0, 1.0, 1.0, 1.0]}"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Have a look at the entities\n",
+ "dataset[\"train\"][\"entities-suggestion\"][0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def in_span(source_start, source_end, target_start, target_end):\n",
+ " \"\"\"\n",
+ " Function to check if the target span is contained within the source span\n",
+ " \"\"\"\n",
+ " if (target_start >= source_start) and (target_end <= source_end):\n",
+ " return True\n",
+ " return False\n",
+ "\n",
+ "\n",
+ "def tokenize_and_create_labels(example):\n",
+ " \"\"\"\n",
+ " Function to tokenize the examples and create token-level labels. The provided character-level entity spans are not aligned with the tokens produced by wordpiece tokenization; hence this step.\n",
+ " \"\"\"\n",
+ " outputs = tokenizer(\n",
+ " example[\"text\"], truncation=True, return_offsets_mapping=True)\n",
+ "\n",
+ " output_labels = []\n",
+ " n_samples = len(example[\"text\"])\n",
+ "\n",
+ " # Do for all the samples in the batch\n",
+ " for i in range(n_samples):\n",
+ " # Do not take the first and last offsets as they belong to a special token (CLS and SEP respectively)\n",
+ " offsets = outputs[\"offset_mapping\"][i][1:-1]\n",
+ " num_tokens = len(offsets)\n",
+ "\n",
+ " # Entity spans\n",
+ " entity_starts = example[\"entities-suggestion\"][i][\"start\"]\n",
+ " entity_ends = example[\"entities-suggestion\"][i][\"end\"]\n",
+ "\n",
+ " # Labels and their number\n",
+ " text_labels = example[\"entities-suggestion\"][i][\"label\"]\n",
+ " num_entities = len(text_labels)\n",
+ "\n",
+ " labels = []\n",
+ "\n",
+ " entities = example[\"entities-suggestion\"][i]\n",
+ "\n",
+ " # If there are no spans, it will all be a list of Os\n",
+ " if len(entities[\"start\"]) == 0:\n",
+ " labels = [label2id[\"O\"] for _ in range(num_tokens)]\n",
+ " # Otherwise check span by span\n",
+ " else:\n",
+ " idx = 0\n",
+ " source_start, source_end = entity_starts[idx], entity_ends[idx]\n",
+ " previous_label = \"O\"\n",
+ "\n",
+ " for loop_idx, (start, end) in enumerate(offsets):\n",
+ " # By default, the token is an O token\n",
+ " lab = \"O\"\n",
+ "\n",
+ " # While you have not exceeded the number of entities provided\n",
+ " if idx < num_entities:\n",
+ " # While you have not stepped ahead of the next entity span\n",
+ " if start > source_end:\n",
+ " # If you have reached the end of the entities annotated, simply fill in the remainder of the tokens as O\n",
+ " if idx == num_entities - 1:\n",
+ " lab = \"O\"\n",
+ " remainder = [\n",
+ " label2id[\"O\"] for _ in range(num_tokens - loop_idx)\n",
+ " ]\n",
+ " labels.extend(remainder)\n",
+ " break\n",
+ " else:\n",
+ " idx += 1\n",
+ "\n",
+ " # If idx has advanced, consider the new span\n",
+ " source_start, source_end = entity_starts[idx], entity_ends[idx]\n",
+ "\n",
+ " # Check if current token is within the source span\n",
+ " if in_span(source_start, source_end, start, end):\n",
+ " # If the previous label was an O, this token begins an entity (B-), else it continues one (I-)\n",
+ " lab = \"B-\" if previous_label == \"O\" else \"I-\"\n",
+ " lab = lab + text_labels[idx]\n",
+ " else:\n",
+ " lab = \"O\"\n",
+ "\n",
+ " labels.append(label2id[lab])\n",
+ " previous_label = lab\n",
+ " # The first and last tokens are the special tokens [CLS] and [SEP]; label them -100 so they are ignored by the loss\n",
+ " output_labels.append([-100] + labels + [-100])\n",
+ " outputs[\"labels\"] = output_labels\n",
+ "\n",
+ " return outputs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
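+ "# Apply tokenization and label alignment over the full dataset in batches, dropping the raw columns\n",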
+ "tokenized_dataset = dataset.map(tokenize_and_create_labels, batched=True,\n",
+ " remove_columns=dataset[\"train\"].column_names)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "14b7a117c7c4418aa3d0d08eb7563add",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Map: 0%| | 0/5 [00:00<?, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Create a sample of 5 items for the sake of visualization\n",
+ "samples = dataset[\"train\"].shuffle(seed=43).take(5).map(\n",
+ " tokenize_and_create_labels, batched=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[CLS] An easy but rare maneuver with extraordinary consequences / . [SEP] \n",
+ "SPECIAL O O O O O O O O O O SPECIAL \n",
+ "Number of tokens: 12, Number of Labels: 12\n",
+ "Entities Annotated: {'start': [], 'end': [], 'label': [], 'text': [], 'score': []}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Visualize one random sample with its tokens and the aligned labels\n",
+ "idx = random.randint(0, len(samples) - 1)\n",
+ "\n",
+ "ip_tokens = [tokenizer.decode([x]) for x in samples[idx][\"input_ids\"]]\n",
+ "labels = samples[idx][\"labels\"]\n",
+ "\n",
+ "token_op, lbl_op = \"\", \"\"\n",
+ "for token, lbl in zip(ip_tokens, labels):\n",
+ " lbl = id2label.get(lbl, \"SPECIAL\")\n",
+ " l = max(len(token), len(lbl)) + 2\n",
+ " token_op += f\"{token:<{l}}\"\n",
+ " lbl_op += f\"{lbl:<{l}}\"\n",
+ "\n",
+ "print(token_op)\n",
+ "print(lbl_op)\n",
+ "print(f\"Number of tokens: {len(ip_tokens)}, Number of Labels: {len(labels)}\")\n",
+ "print(\"Entities Annotated: \", samples[idx][\"entities-suggestion\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Remove the offset mappings, since the data collator cannot collate batches with this column present\n",
+ "tokenized_dataset = tokenized_dataset.remove_columns(\n",
+ " column_names=[\"offset_mapping\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[-100, 0, 19, 0, 0, 0, 0, 0, 0, 1, 0, 0,\n",
+ " 1, 0, 0, 0, 0, 0, 0, 0, 0, 29, 30, 0,\n",
+ " 0, 15, 16, 16, 16, 0, -100],\n",
+ " [-100, 0, 0, 0, 0, 0, 0, 0, 19, 0, 19, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, -100, -100, -100, -100, -100]])"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Create a data collator that pads inputs and labels per batch, and inspect a sample batch\n",
+ "data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)\n",
+ "batch = data_collator([tokenized_dataset[\"train\"][i] for i in range(2)])\n",
+ "batch[\"labels\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [],
+ "source": [
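+ "# seqeval scores NER at the entity level: a span counts as correct only if both its boundaries and its type match\n",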
+ "metric = evaluate.load(\"seqeval\")\n",
+ "\n",
+ "def compute_metrics(eval_preds):\n",
+ " logits, labels = eval_preds\n",
+ "\n",
+ " # Get the most probable token prediction\n",
+ " predictions = np.argmax(logits, axis=-1)\n",
+ "\n",
+ " # Remove ignored index (special tokens) and convert to labels\n",
+ " true_labels, true_predictions = [], []\n",
+ " for prediction, label in zip(predictions, labels):\n",
+ " current_prediction, current_label = [], []\n",
+ " for p, l in zip(prediction, label):\n",
+ " if l != -100:\n",
+ " current_label.append(id2label[l])\n",
+ " current_prediction.append(id2label[p])\n",
+ " true_labels.append(current_label)\n",
+ " true_predictions.append(current_prediction)\n",
+ "\n",
+ " # Compute the metrics using above predictions and labels\n",
+ " all_metrics = metric.compute(\n",
+ " predictions=true_predictions, references=true_labels)\n",
+ "\n",
+ " # Return only the overall metrics, not the per entity-type metrics\n",
+ " return {\n",
+ " \"precision\": all_metrics[\"overall_precision\"],\n",
+ " \"recall\": all_metrics[\"overall_recall\"],\n",
+ " \"f1\": all_metrics[\"overall_f1\"],\n",
+ " \"accuracy\": all_metrics[\"overall_accuracy\"],\n",
+ " }"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']\n",
+ "- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Create a token classification model on top of the pretrained BERT checkpoint\n",
+ "model = AutoModelForTokenClassification.from_pretrained(\n",
+ " checkpoint,\n",
+ " id2label=id2label,\n",
+ " label2id=label2id\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Linear(in_features=768, out_features=37, bias=True)"
+ ]
+ },
+ "execution_count": 22,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Check the classifier architecture\n",
+ "model.classifier"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(37, 37, 37)"
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Confirm that the model's number of labels matches both label mapping dictionaries (and the 37 output activations of the classifier above)\n",
+ "model.config.num_labels, len(label2id), len(id2label)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
+ "Token is valid (permission: write).\n",
+ "Your token has been saved to /home/.cache/huggingface/token\n",
+ "Login successful\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Log in to the Hugging Face Hub so the trained model can be uploaded\n",
+ "login(token=os.environ.get(\"HF_TOKEN\"))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [],
+ "source": [
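+ "# Training configuration: evaluate and checkpoint once per epoch, and push checkpoints to the Hub\n",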
+ "args = TrainingArguments(\n",
+ " \"dev-ner-ontonote-bert-finetuned\",\n",
+ " evaluation_strategy=\"epoch\",\n",
+ " save_strategy=\"epoch\",\n",
+ " learning_rate=2e-5,\n",
+ " num_train_epochs=5,\n",
+ " weight_decay=0.01,\n",
+ " push_to_hub=True,\n",
+ " per_device_train_batch_size=32,\n",
+ " per_device_eval_batch_size=32\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],\n",
+ " num_rows: 8528\n",
+ " })\n",
+ " validation: Dataset({\n",
+ " features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],\n",
+ " num_rows: 8528\n",
+ " })\n",
+ "})"
+ ]
+ },
+ "execution_count": 28,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
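+ "# Sanity check the processed splits before training\n",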
+ "tokenized_dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/huggingface/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py:131: FutureWarning: 'Repository' (from 'huggingface_hub.repository') is deprecated and will be removed from version '1.0'. Please prefer the http-based alternatives instead. Given its large adoption in legacy code, the complete removal is only planned on next major release.\n",
+ "For more details, please read https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http.\n",
+ " warnings.warn(warning_message, FutureWarning)\n",
+ "/home/hf/hf-course/chapter7/dev-ner-ontonote-bert-finetuned is already a clone of https://huggingface.co/ElisonSherton/dev-ner-ontonote-bert-finetuned. Make sure you pull the latest changes with `repo.git_pull()`.\n",
+ "/home/huggingface/lib/python3.10/site-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " <div>\n",
+ " \n",
+ " <progress value='1335' max='1335' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ " [1335/1335 09:17, Epoch 5/5]\n",
+ " </div>\n",
+ " <table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: left;\">\n",
+ " <th>Epoch</th>\n",
+ " <th>Training Loss</th>\n",
+ " <th>Validation Loss</th>\n",
+ " <th>Precision</th>\n",
+ " <th>Recall</th>\n",
+ " <th>F1</th>\n",
+ " <th>Accuracy</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <td>1</td>\n",
+ " <td>No log</td>\n",
+ " <td>0.111329</td>\n",
+ " <td>0.757552</td>\n",
+ " <td>0.797257</td>\n",
+ " <td>0.776898</td>\n",
+ " <td>0.968852</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>2</td>\n",
+ " <td>0.281100</td>\n",
+ " <td>0.055888</td>\n",
+ " <td>0.873178</td>\n",
+ " <td>0.908711</td>\n",
+ " <td>0.890590</td>\n",
+ " <td>0.984724</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>3</td>\n",
+ " <td>0.281100</td>\n",
+ " <td>0.035979</td>\n",
+ " <td>0.914701</td>\n",
+ " <td>0.947770</td>\n",
+ " <td>0.930942</td>\n",
+ " <td>0.990416</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>4</td>\n",
+ " <td>0.063000</td>\n",
+ " <td>0.027458</td>\n",
+ " <td>0.933327</td>\n",
+ " <td>0.960033</td>\n",
+ " <td>0.946492</td>\n",
+ " <td>0.992793</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <td>5</td>\n",
+ " <td>0.063000</td>\n",
+ " <td>0.024083</td>\n",
+ " <td>0.940449</td>\n",
+ " <td>0.966845</td>\n",
+ " <td>0.953464</td>\n",
+ " <td>0.993742</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table><p>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/home/huggingface/lib/python3.10/site-packages/seqeval/metrics/v1.py:57: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
+ " _warn_prf(average, modifier, msg_start, len(result))\n",
+ "/home/huggingface/lib/python3.10/site-packages/seqeval/metrics/v1.py:57: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
+ " _warn_prf(average, modifier, msg_start, len(result))\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "TrainOutput(global_step=1335, training_loss=0.1388676861252231, metrics={'train_runtime': 562.8544, 'train_samples_per_second': 75.757, 'train_steps_per_second': 2.372, 'total_flos': 1425922860395136.0, 'train_loss': 0.1388676861252231, 'epoch': 5.0})"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
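+ "# Put the model, data collator, datasets and metric function together and fine-tune\n",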
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " args=args,\n",
+ " data_collator=data_collator,\n",
+ " train_dataset=tokenized_dataset[\"train\"],\n",
+ " eval_dataset=tokenized_dataset[\"validation\"],\n",
+ " compute_metrics=compute_metrics,\n",
+ " tokenizer=tokenizer\n",
+ ")\n",
+ "\n",
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "To https://huggingface.co/ElisonSherton/dev-ner-ontonote-bert-finetuned\n",
+ " 41c8386..27067f9 main -> main\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
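+ "# Upload the final model and training artifacts to the Hub\n",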
+ "trainer.push_to_hub(\n",
+ " commit_message=\"🤗 Training of first BERT based NER task completed!!\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }