File size: 4,369 Bytes
158e9d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# based on https://github.com/EleutherAI/gpt-neox/blob/main/megatron/tokenizer/tokenizer.py
from __future__ import annotations
import torch
import numpy as np
from os import PathLike
from typing import List, Tuple
from tokenizers import Tokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_base import BatchEncoding, TruncationStrategy
from transformers.utils.generic import TensorType, PaddingStrategy
# Separator placed between character tokens when they are joined back into
# text (see ByteTokenizer.convert_tokens_to_string).
EMPTY: str = ''
class ByteTokenizer(PreTrainedTokenizer):
    """Byte-level UTF-8 tokenizer exposed through the ``PreTrainedTokenizer`` API.

    The encode paths (``_encode_plus`` / ``_batch_encode_plus``) turn text into
    its raw UTF-8 bytes, so every id they produce lies in ``0..255``. The
    declared vocabulary is larger (``vocab_size == 512``); ids ``256..511`` are
    never produced by encoding.

    The per-token helpers (``_tokenize``, ``_convert_token_to_id``,
    ``_convert_id_to_token``) operate on Unicode characters / code points, not
    bytes, so for non-ASCII text ``tokenize()`` + ``convert_tokens_to_ids()``
    does not yield the same ids as the encode path.
    NOTE(review): confirm which of the two paths downstream code relies on.

    Decoding lifts ids below 32 up to 32 (a space), mirroring ``clamp``; control
    bytes such as newline therefore do not survive an encode/decode round trip.
    """

    @classmethod
    def from_pretrained(cls, model_id: str | PathLike, **kwargs) -> ByteTokenizer:
        """Build a tokenizer; ``model_id`` is accepted for API compatibility only.

        Nothing is read from ``model_id``. ``byte_level`` is always forced to
        True, overriding any value supplied in ``kwargs``.
        """
        # Merge rather than passing ``byte_level`` twice: a caller-supplied
        # ``byte_level`` would otherwise raise "got multiple values" TypeError.
        return cls(**{**kwargs, 'byte_level': True})

    @property
    def vocab_size(self) -> int:
        """Size of the id space (only ids 0..255 correspond to bytes)."""
        return 512

    @property
    def byte_level(self) -> bool:
        """The ``byte_level`` init kwarg, defaulting to True when absent."""
        return self.init_kwargs.get('byte_level', True)

    def get_vocab(self) -> dict[str, int]:
        """Map ``chr(i)`` to ``i`` for every ``i`` in ``range(vocab_size)``."""
        return {chr(i): i for i in range(self.vocab_size)}

    def __len__(self) -> int:
        """Return ``vocab_size``; any added tokens are not counted."""
        return self.vocab_size

    def clamp(self, n: int) -> int:
        """Clamp ``n`` into the id range ``[32, vocab_size - 1]``.

        Values below 32 (the ASCII control range) are lifted to 32.
        """
        # ``vocab_size`` itself is one past the last valid id, hence the -1.
        return max(32, min(n, self.vocab_size - 1))

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        """Split ``text`` into Unicode characters (not UTF-8 bytes)."""
        return list(text)

    def byte_tokenize(self, text: str) -> np.ndarray:
        """Return the UTF-8 encoding of ``text`` as a ``uint8`` array."""
        return np.frombuffer(text.encode('utf-8'), dtype=np.uint8)

    def _convert_token_to_id(self, token: str) -> int:
        """Map a single-character token to its clamped code point."""
        return self.clamp(ord(token))

    def _convert_id_to_token(self, index: int) -> str:
        """Map an id to the character at its clamped code point."""
        return chr(self.clamp(index))

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Concatenate character tokens back into a single string."""
        return EMPTY.join(tokens)

    def _decode(self, token_ids: List[int], **kwargs) -> str:
        """Decode byte ids back to text.

        ``token_ids`` is a sequence of ints (a single int also works). Ids are
        clipped to ``[32, 255]`` and reinterpreted as UTF-8 bytes: ids below 32
        become spaces, and ids above 255 (possible, since the vocabulary has
        512 entries) become byte 0xFF. Invalid UTF-8, e.g. a multi-byte
        character cut off by truncation, is handled per the ``errors`` kwarg
        (default ``'replace'``, yielding U+FFFD). Pass ``errors='strict'`` to
        get ``UnicodeDecodeError`` instead.
        """
        errors = kwargs.get('errors', 'replace')
        # Widen before clipping: casting ids > 255 straight to uint8 wraps
        # around or raises OverflowError, depending on the NumPy version.
        # np.clip without ``out=`` also leaves a caller-supplied array intact.
        indices = np.asarray(token_ids, dtype=np.int64).reshape(-1)
        byte_values = np.clip(indices, 32, 255).astype(np.uint8)
        return byte_values.tobytes().decode('utf-8', errors=errors)

    def _encode_plus(
        self, text: str, text_pair: str | None = None, **kwargs
    ) -> BatchEncoding:
        """Encode ``text`` (and optional ``text_pair``) as UTF-8 byte ids.

        The remaining keyword arguments are the padding/truncation strategies
        and ``return_*`` flags forwarded by the base tokenizer; they are passed
        on to ``prepare_for_model``.
        """
        first_ids = self.byte_tokenize(text).tolist()
        # ``text_pair`` used to disappear into **kwargs and be silently dropped.
        pair_ids = (
            None if text_pair is None else self.byte_tokenize(text_pair).tolist()
        )
        return self.prepare_for_model(
            first_ids,
            pair_ids=pair_ids,
            add_special_tokens=kwargs.get('add_special_tokens', False),
            padding=kwargs.get('padding_strategy', PaddingStrategy.DO_NOT_PAD).value,
            truncation=kwargs.get('truncation_strategy', TruncationStrategy.DO_NOT_TRUNCATE).value,
            max_length=kwargs.get('max_length'),
            stride=kwargs.get('stride', 0),
            pad_to_multiple_of=kwargs.get('pad_to_multiple_of'),
            return_tensors=kwargs.get('return_tensors'),
            prepend_batch_axis=True,
            return_attention_mask=kwargs.get('return_attention_mask'),
            return_token_type_ids=kwargs.get('return_token_type_ids'),
            return_overflowing_tokens=kwargs.get('return_overflowing_tokens', False),
            return_special_tokens_mask=kwargs.get('return_special_tokens_mask', False),
            return_length=kwargs.get('return_length', False),
            verbose=kwargs.get('verbose', True),
        )

    def _batch_encode_plus(
        self,
        batch_text: List[str | Tuple[str, str]] | None = None,
        *,
        batch_text_or_text_pairs: List[str | Tuple[str, str]] | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """Encode a batch of texts, or ``(text, text_pair)`` items, as byte ids.

        The base tokenizer's ``batch_encode_plus`` passes the batch under the
        keyword ``batch_text_or_text_pairs``. ``batch_text`` is kept so existing
        positional callers continue to work.

        Raises:
            TypeError: if no batch is supplied under either name.
        """
        if batch_text is None:
            batch_text = batch_text_or_text_pairs
        if batch_text is None:
            raise TypeError("_batch_encode_plus() missing required argument: 'batch_text'")
        # Mirror the base class: with pre-split words a list item is one
        # sequence, not a (text, pair) tuple.
        split_into_words = kwargs.get('is_split_into_words', False)
        input_ids = []
        for item in batch_text:
            if isinstance(item, (list, tuple)) and not split_into_words:
                text, text_pair = item
                pair_ids = (
                    None if text_pair is None
                    else self.byte_tokenize(text_pair).tolist()
                )
                input_ids.append((self.byte_tokenize(text).tolist(), pair_ids))
            else:
                input_ids.append((self.byte_tokenize(item).tolist(), None))
        return self._batch_prepare_for_model(
            input_ids,
            add_special_tokens=kwargs.get('add_special_tokens', False),
            padding_strategy=kwargs.get('padding_strategy', PaddingStrategy.DO_NOT_PAD),
            truncation_strategy=kwargs.get('truncation_strategy', TruncationStrategy.DO_NOT_TRUNCATE),
            max_length=kwargs.get('max_length'),
            stride=kwargs.get('stride', 0),
            pad_to_multiple_of=kwargs.get('pad_to_multiple_of'),
            return_attention_mask=kwargs.get('return_attention_mask'),
            return_token_type_ids=kwargs.get('return_token_type_ids'),
            return_overflowing_tokens=kwargs.get('return_overflowing_tokens', False),
            return_special_tokens_mask=kwargs.get('return_special_tokens_mask', False),
            return_length=kwargs.get('return_length', False),
            return_tensors=kwargs.get('return_tensors'),
            verbose=kwargs.get('verbose', True),
        )

    def _save_pretrained(
        self, save_directory: str | PathLike, file_names: Tuple[str], **kwargs
    ) -> Tuple[str]:
        """Write no files (the vocabulary is computed) and return ``file_names``."""
        return file_names
|