File size: 4,187 Bytes
27140ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa9cf26
27140ac
 
a35de04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27140ac
 
 
 
 
 
 
a35de04
27140ac
a35de04
 
 
 
 
27140ac
a35de04
27140ac
 
 
 
 
 
 
 
 
 
 
 
 
a35de04
27140ac
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# based on https://github.com/EleutherAI/gpt-neox/blob/main/megatron/tokenizer/tokenizer.py
from abc import ABC
import json
import pathlib

import torch
import tqdm
from tokenizers import Tokenizer
from abc import abstractmethod
from typing import Any, List, Union
import numpy as np


class HFAutoTokenizer:
    """Wrapper around a ``tokenizers.Tokenizer`` loaded from a tokenizer file.

    Exposes tokenize/detokenize helpers that return ``torch`` tensors and
    decoded strings, plus a JSON-lines file tokenization utility.
    """

    def __init__(self, vocab_file):
        """Load the tokenizer from ``vocab_file`` and cache the BOS/EOS ids.

        Args:
            vocab_file: Path passed to ``Tokenizer.from_file``.
        """
        self.tokenizer = Tokenizer.from_file(vocab_file)
        self.eos = "</s>"
        self.bos = "<s>"
        # These are tensors (the output of ``tokenize``), not plain ints.
        self.eos_id = self.tokenize(self.eos)
        self.bos_id = self.tokenize(self.bos)
        self.vsize = 32000

    def encode_to_list(self, text):
        """Encode ``text`` without adding special tokens.

        Returns the raw result of ``self.tokenizer.encode``.
        NOTE(review): despite the method name this is presumably a
        ``tokenizers.Encoding`` rather than a list -- confirm with callers
        before changing the return type.
        """
        return self.tokenizer.encode(text, add_special_tokens=False)

    def tokenize_file(self, input_file, output_file, verbose=False):
        """Tokenize a JSON-lines file and write one JSON object per line.

        Each input line is parsed as JSON and its ``"text"`` field is
        tokenized; the output line is ``{"tokens": [...]}``.

        Args:
            input_file: Path of the JSON-lines file to read.
            output_file: Path to write. If it already exists, nothing is done.
            verbose: When true, print progress information.
        """
        if verbose:
            print(f"Tokenizing file: {input_file}")

        if pathlib.Path(output_file).exists():
            print(f"Output file {output_file} already exists, skipping")
            return

        with open(input_file, "r", encoding="utf-8") as fin, open(
            output_file, "w", encoding="utf-8"
        ) as fout:
            for line in tqdm.tqdm(fin):
                if verbose:
                    print(f"Tokenizing line: {line[-200:]}")
                record = line.strip()
                if not record:
                    # Blank lines (e.g. a trailing newline) are not valid JSON.
                    continue
                data = json.loads(record)
                if "text" not in data:
                    # NOTE(review): processing stops at the first record
                    # without a "text" field; confirm this is intended
                    # rather than skipping the record.
                    break
                tokenized_data = self.tokenize(data["text"])
                # ``tokenize`` returns a tensor, which json cannot serialize;
                # convert it to a plain list of ints first.
                fout.write(json.dumps({"tokens": tokenized_data.tolist()}) + "\n")

    def tokenize(self, text: str, *args, **kwargs):
        """Encode ``text`` and return the token ids as a ``torch`` tensor.

        Extra positional and keyword arguments are accepted and ignored.
        """
        ids = self.tokenizer.encode(text)
        if isinstance(ids, list):
            return torch.tensor(ids)
        # Otherwise the encoder result carries the ids on its ``ids`` attribute.
        return torch.tensor(ids.ids)

    def tokenize_batch(self, text_batch):
        """Encode a batch of texts via ``self.tokenizer.encode_batch``."""
        return self.tokenizer.encode_batch(text_batch)

    def detokenize(self, token_ids, skip_special_tokens=False):
        """Decode ``token_ids`` back into a string."""
        return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)

    def detokenize_batch(self, token_ids_batch, skip_special_tokens=False):
        """Decode every sequence in ``token_ids_batch``.

        Each element of a sequence must provide ``.item()`` (e.g. tensor
        elements); the values are converted to Python scalars before decoding.

        Returns:
            A list with one decoded string per input sequence.
        """
        return [
            self.detokenize(
                [t.item() for t in token_ids],
                skip_special_tokens=skip_special_tokens,
            )
            for token_ids in token_ids_batch
        ]

    @property
    def eod(self):
        """End-of-document id; this class uses the EOS id for it.

        No separate ``eod_id`` attribute is ever assigned, so the EOS id
        computed in ``__init__`` is returned.
        """
        return self.eos_id

    @property
    def vocab_size(self):
        """Fixed vocabulary size reported by this tokenizer."""
        return 32000


class ByteTokenizer(PreTrainedTokenizer):
    """UTF-8 byte tokenizer: every byte of the encoded text is one token id.

    NOTE(review): ``PreTrainedTokenizer`` is not imported in the visible part
    of this file. It presumably comes from ``transformers`` -- confirm and add
    the import, otherwise defining this class raises ``NameError``.
    """

    def __init__(self):
        # Special tokens are given as the characters ``decode_token`` maps
        # their ids to. ``clamp`` raises every id below 32 to 32, so ids 0-3
        # all become chr(32).
        super().__init__(
            bos_token=self.decode_token(2),
            eos_token=self.decode_token(0),
            unk_token=self.decode_token(0),
            pad_token=self.decode_token(1),
            mask_token=self.decode_token(3),
        )

    @property
    def vocab_size(self) -> int:
        """Fixed vocabulary size."""
        return 512

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Return a fresh instance; all arguments are ignored."""
        return cls()

    def get_vocab(self):
        """Return a mapping from the decimal string of each id to the id."""
        return {str(i): i for i in range(self.vocab_size)}

    def clamp(self, n):
        """Limit ``n`` to the range ``[32, vocab_size]`` (both ends inclusive)."""
        return max(32, min(n, self.vocab_size))

    def decode_token(self, token: int):
        """Return the character for ``token`` after clamping it."""
        return chr(self.clamp(token))

    def __call__(self, text: str, return_tensors: bool = False, *args, **kwargs):
        """Tokenize ``text`` into a ``(1, n)`` long tensor.

        Returns ``{"input_ids": ids}`` when ``return_tensors`` equals
        ``False``; otherwise returns the tensor itself. Extra arguments are
        ignored.
        """
        ids = torch.tensor(self.tokenize(text), dtype=torch.long).unsqueeze(0)
        # The equality test is kept (instead of ``not return_tensors``) so
        # values such as None or "pt" still select the bare-tensor branch.
        if return_tensors == False:  # noqa: E712
            return {"input_ids": ids}
        return ids

    def _tokenize(self, text: str):
        """Return the UTF-8 bytes of ``text`` as a ``uint8`` NumPy array."""
        return np.frombuffer(text.encode('utf-8'), dtype=np.uint8)

    def tokenize(self, text: str):
        """Return the UTF-8 bytes of ``text`` as a list of ints."""
        return self._tokenize(text).tolist()

    def tokenize_batch(self, text_batch: Union[List[str], str]):
        """Tokenize a list of strings (list of id lists) or a single string."""
        if isinstance(text_batch, list):
            return [self.tokenize(s) for s in text_batch]
        return self.tokenize(text_batch)

    def decode(self, token_ids):
        """Decode an iterable of token ids into a string, one char per id."""
        return "".join(map(self.decode_token, token_ids))

    def decode_batch(self, token_ids: Union[List[List[int]], List[int], torch.Tensor]):
        """Decode a batch of id sequences, or a single flat sequence.

        Args:
            token_ids: A list of id sequences, a tensor of ids, or a single
                flat sequence of ids.

        Returns:
            A list of strings for a batch (nested list / multi-row tensor),
            or one string for a single flat sequence -- mirroring what
            ``tokenize_batch`` produces for a list vs. a single string.
        """
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        if isinstance(token_ids, list):
            # A flat list of ints is one sequence, not a batch; decoding it
            # element-wise would try to iterate over an int.
            if token_ids and isinstance(token_ids[0], (int, np.integer)):
                return self.decode(token_ids)
            return [self.decode(s) for s in token_ids]
        return self.decode(token_ids)