# Uploaded to the Hugging Face Hub by yairschiff.
# Commit ebc7d8f (verified): "Upload tokenizer" (file size 4.97 kB).
"""Character tokenizer for Hugging Face.
"""
from typing import List, Optional, Dict, Sequence, Tuple
from transformers import PreTrainedTokenizer
class CaduceusTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer for Hugging Face transformers.

    Each input character becomes one token. Seven special tokens occupy the
    fixed ids 0-6, and the user-supplied ``characters`` are assigned ids from
    7 upward in the order given. The tokenizer also exposes a token-id level
    complement map (see ``complement_map``).
    """

    model_input_names = ["input_ids"]

    def __init__(
        self,
        model_max_length: int,
        characters: Sequence[str] = ("A", "C", "G", "T", "N"),
        complement_map: Optional[Dict[str, str]] = None,
        bos_token="[BOS]",
        eos_token="[SEP]",
        sep_token="[SEP]",
        cls_token="[CLS]",
        pad_token="[PAD]",
        mask_token="[MASK]",
        unk_token="[UNK]",
        **kwargs,
    ):
        """Character tokenizer for Hugging Face transformers.

        Adapted from https://huggingface.co/LongSafari/hyenadna-tiny-1k-seqlen-hf/blob/main/tokenization_hyena.py

        Args:
            model_max_length (int): Model maximum sequence length.
            characters (Sequence[str]): Characters that make up the vocabulary.
                Any input character not in this list is mapped to the special
                token ``[UNK]`` (id 6). The special tokens and their ids are:

                    "[CLS]": 0
                    "[SEP]": 1
                    "[BOS]": 2
                    "[MASK]": 3
                    "[PAD]": 4
                    "[RESERVED]": 5
                    "[UNK]": 6

                Each character is then assigned an id, starting at 7, in the
                order given.
            complement_map (Optional[Dict[str, str]]): Maps each character to
                its complement character. Defaults to
                ``{"A": "T", "C": "G", "G": "C", "T": "A", "N": "N"}``.
                Vocabulary tokens without an entry (e.g. the special tokens
                under the default map) are their own complement.
            bos_token, eos_token, sep_token, cls_token, pad_token, mask_token,
            unk_token: Special-token strings forwarded to
                ``PreTrainedTokenizer``. The id table above is hard-coded and
                keyed on the default strings.
            **kwargs: Forwarded to ``PreTrainedTokenizer``.
                ``add_prefix_space`` defaults to False and ``padding_side``
                defaults to "left".

        Raises:
            KeyError: if ``complement_map`` maps a vocabulary token to a
                string that is not itself in the vocabulary.
        """
        if complement_map is None:
            complement_map = {"A": "T", "C": "G", "G": "C", "T": "A", "N": "N"}
        self.characters = characters
        self.model_max_length = model_max_length

        # NOTE: the vocabulary is built before super().__init__() below;
        # presumably the base class may query get_vocab()/vocab_size during
        # its own initialization, so keep this ordering.
        self._vocab_str_to_int = {
            "[CLS]": 0,
            "[SEP]": 1,
            "[BOS]": 2,
            "[MASK]": 3,
            "[PAD]": 4,
            "[RESERVED]": 5,
            "[UNK]": 6,
            **{ch: i + 7 for i, ch in enumerate(self.characters)},
        }
        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}

        add_prefix_space = kwargs.pop("add_prefix_space", False)
        padding_side = kwargs.pop("padding_side", "left")

        # Token id -> id of its complement. Tokens absent from
        # `complement_map` complement to themselves.
        self._complement_map = {
            token_id: (
                self._vocab_str_to_int[complement_map[token]]
                if token in complement_map
                else token_id
            )
            for token, token_id in self._vocab_str_to_int.items()
        }

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            unk_token=unk_token,
            add_prefix_space=add_prefix_space,
            model_max_length=model_max_length,
            padding_side=padding_side,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Size of the built-in vocabulary (special tokens + characters).

        Tokens added later via ``add_tokens()`` are not counted here.
        """
        return len(self._vocab_str_to_int)

    @property
    def complement_map(self) -> Dict[int, int]:
        """Mapping from every built-in token id to its complement's token id."""
        return self._complement_map

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        """Split ``text`` into single-character tokens, upper-cased."""
        return list(text.upper())  # Convert all base pairs to uppercase

    def _convert_token_to_id(self, token: str) -> int:
        """Return the id of ``token``, or the ``[UNK]`` id if it is unknown."""
        return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])

    def _convert_id_to_token(self, index: int) -> str:
        """Return the token for ``index``, or ``"[UNK]"`` if it is unknown.

        Mirrors ``_convert_token_to_id`` (unknown strings -> ``[UNK]`` id)
        instead of raising ``KeyError`` on an out-of-vocabulary id.
        """
        return self._vocab_int_to_str.get(index, "[UNK]")

    def convert_tokens_to_string(self, tokens):
        """Concatenate ``tokens`` back into a single string."""
        # Note: this operation has lost info about which base pairs were
        # originally lowercase, because `_tokenize` upper-cases its input.
        return "".join(tokens)

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """Return a 0/1 mask marking special-token positions.

        When the ids do not yet contain special tokens, the mask matches the
        layout produced by ``build_inputs_with_special_tokens``: 0 for every
        sequence token and 1 for the ``[SEP]`` appended after each sequence.
        Otherwise the base-class implementation is used.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        result = ([0] * len(token_ids_0)) + [1]
        if token_ids_1 is not None:
            result += ([0] * len(token_ids_1)) + [1]
        return result

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """Append ``[SEP]`` after each sequence.

        Produces ``A [SEP]`` for one sequence or ``A [SEP] B [SEP]`` for a
        pair. No leading ``[CLS]``/``[BOS]`` token is added. The input lists
        are not modified.
        """
        sep = [self.sep_token_id]
        result = token_ids_0 + sep
        if token_ids_1 is not None:
            result += token_ids_1 + sep
        return result

    def get_vocab(self) -> Dict[str, int]:
        """Return a new token -> id dict, including tokens added at runtime.

        A fresh dict is returned so callers cannot mutate the internal
        vocabulary, which would silently desynchronise ``_vocab_int_to_str``
        and the complement map that are derived from it in ``__init__``.
        """
        # NOTE(review): the transformers base class (e.g. `_add_tokens` and
        # `__len__` in recent versions) appears to derive new token ids and
        # the total length from this mapping, so tokens added via
        # `add_tokens()` must be included; confirm against the installed
        # transformers version.
        return {**self._vocab_str_to_int, **self.added_tokens_encoder}

    # Fixed vocabulary with no vocab file
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple:
        """Write nothing and return an empty tuple.

        The vocabulary is defined in code, so there is no vocab file to save.
        """
        return ()