sayakpaul HF staff commited on
Commit
256477b
·
1 Parent(s): 2ea9793

Upload train_unigram.py

Browse files
Files changed (1) hide show
  1. train_unigram.py +119 -0
train_unigram.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ python train_unigram.py --export_to_hub
4
+
5
+ Note that you'd need to execute `huggingface-cli login` before if you passed export_to_hub.
6
+
7
+ Reference:
8
+ https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/tokenizer_training.ipynb
9
+ """
10
+
11
+ import argparse
12
+ import logging
13
+
14
+ import datasets
15
+ import torch
16
+ from datasets import Dataset
17
+ from tokenizers import (
18
+ Tokenizer,
19
+ decoders,
20
+ normalizers,
21
+ pre_tokenizers,
22
+ processors,
23
+ )
24
+ from tokenizers.models import Unigram
25
+ from tokenizers.trainers import UnigramTrainer
26
+ from transformers import AlbertTokenizerFast
27
+
28
+
29
+ def parse_args():
30
+ parser = argparse.ArgumentParser(
31
+ description="Train a unigram tokenizer on the wikitext dataset."
32
+ )
33
+ parser.add_argument(
34
+ "-bs",
35
+ "--batch-size",
36
+ type=int,
37
+ default=1000,
38
+ help="Batch size during training.",
39
+ )
40
+ parser.add_argument(
41
+ "-vs",
42
+ "--vocab-size",
43
+ type=int,
44
+ default=10000,
45
+ help="Size of the desired vocabulary.",
46
+ )
47
+ parser.add_argument(
48
+ "--limit",
49
+ default=None,
50
+ type=int,
51
+ help="Limit the number of shards (used for debugging).",
52
+ )
53
+ parser.add_argument(
54
+ "--export_to_hub",
55
+ action="store_true",
56
+ )
57
+
58
+ args = parser.parse_args()
59
+ return args
60
+
61
+
62
+ def get_unigram_tokenizer() -> Tokenizer:
63
+ tokenizer = Tokenizer(Unigram())
64
+ tokenizer.normalizer = normalizers.Sequence(
65
+ [normalizers.Replace("``", '"'), normalizers.Replace("''", '"')]
66
+ )
67
+ tokenizer.pre_tokenizer = pre_tokenizers.Metaspace()
68
+ return tokenizer
69
+
70
+
71
+ def get_unigram_trainer(vocab_size: int) -> UnigramTrainer:
72
+ trainer = UnigramTrainer(
73
+ unk_token="<unk>",
74
+ special_tokens=["[CLS]", "[SEP]", "<unk>", "<pad>", "[MASK]"],
75
+ vocab_size=vocab_size,
76
+ )
77
+ return trainer
78
+
79
+
80
+ def main(args):
81
+ wikitext = datasets.load_dataset(
82
+ "wikitext", "wikitext-103-raw-v1", split="train"
83
+ )
84
+
85
+ if args.limit is not None:
86
+ wikitext = wikitext[: args.limit]
87
+ wikitext = Dataset.from_dict(wikitext)
88
+ logging.info(f"Limiting the dataset to {args.limit} entries.")
89
+
90
+ dataloader = torch.utils.data.DataLoader(
91
+ wikitext, num_workers=0, batch_size=args.batch_size
92
+ )
93
+ logging.info("Training the tokenizer.")
94
+ tokenizer = get_unigram_tokenizer()
95
+ trainer = get_unigram_trainer(args.vocab_size)
96
+ tokenizer.train_from_iterator(dataloader, trainer=trainer)
97
+ logging.info("Tokenizer training complete!")
98
+
99
+ cls_token_id = tokenizer.token_to_id("[CLS]")
100
+ sep_token_id = tokenizer.token_to_id("[SEP]")
101
+ tokenizer.post_processor = processors.TemplateProcessing(
102
+ single="[CLS]:0 $A:0 [SEP]:0",
103
+ pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
104
+ special_tokens=[
105
+ ("[CLS]", cls_token_id),
106
+ ("[SEP]", sep_token_id),
107
+ ],
108
+ )
109
+ tokenizer.decoder = decoders.Metaspace()
110
+
111
+ if args.export_to_hub:
112
+ logging.info("Exporting the trained tokenzier to Hub.")
113
+ new_tokenizer = AlbertTokenizerFast(tokenizer_object=tokenizer)
114
+ new_tokenizer.push_to_hub("sayakpaul/unigram-tokenizer-wikitext")
115
+
116
+
117
+ if __name__ == "__main__":
118
+ args = parse_args()
119
+ main(args)