# shakespeare-demo / word_data.py
import re
from typing import Optional, Union
import requests
from torch.utils.data import Dataset
import torch as t
class WordsDataset(Dataset):
    '''Thin Dataset wrapper that pairs each input sequence with its target sequence.'''
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels[idx]
        text = self.texts[idx]
        sample = (text, label)
        return sample
#%%
def tokenize(text):
    # Split on word boundaries, keeping whitespace and punctuation as tokens.
    return re.split(r"\b", text)

def _remove_duplicates(text, string=" "):
    # Recursively collapse repeated occurrences of `string` into a single one.
    if string + string in text:
        text = text.replace(string + string, string)
        return _remove_duplicates(text, string)
    return text

def remove_duplicates(text):
    text = _remove_duplicates(text, ' ')
    text = _remove_duplicates(text, '\n')
    return text
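# %%
# Quick sanity check on a made-up line (any string works): splitting on word
# boundaries keeps the separators as tokens, so WordData.decode() can later
# rebuild the exact text, while remove_duplicates() collapses repeated spaces
# and newlines first.
if __name__ == "__main__":
    print(tokenize("To be,  or not to be"))
    # ['', 'To', ' ', 'be', ',  ', 'or', ' ', 'not', ' ', 'to', ' ', 'be', '']
    print(remove_duplicates("To be,  or not to be"))
    # To be, or not to be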
# %%
class WordData():
    def __init__(self, text, start, end, device):
        self.complete_text = remove_duplicates(text)
        if start is not None and end is not None:
            self.complete_text = self.get_excerpt(start, end)
        self.complete_tokens = tokenize(self.complete_text)
        # The vocabulary is the sorted set of unique tokens; build both lookup tables.
        self.vocab = sorted(set(self.complete_tokens))
        self.token_to_id = dict(zip(self.vocab, range(len(self.vocab))))
        self.id_to_token = dict(zip(range(len(self.vocab)), self.vocab))
        self.model_max_length = None
        self.device = device
    @staticmethod
    def from_link(link, device, start=None, end=None):
        return WordData(
            requests.get(link).content.decode('utf-8'),
            start,
            end,
            device=device
        )

    @staticmethod
    def from_file(filename, device, start=None, end=None):
        with open(filename, encoding='utf-8') as f:
            text = f.read()
        return WordData(text, start, end, device=device)
    def get_excerpt(self, start="THE SONNETS", end="THE END", text=None):
        # Keep only the text between the first occurrence of `start` and the next `end`.
        if text is None:
            text = self.complete_text
        assert start in text, f'get_excerpt: cannot find {start} in text'
        l_stripped = text.split(start, maxsplit=1)[1]
        assert end in l_stripped, f'get_excerpt: cannot find {end} in text'
        r_stripped = l_stripped.split(end, maxsplit=1)[0]
        return r_stripped
    def generate_autoregressive_dataset(self, sequence_length, text=None):
        self.model_max_length = sequence_length
        if text is None:
            text = self.complete_text
        token_ids = self.encode(text, return_tensors="pt")
        # Sliding windows of `sequence_length` tokens; each label window is shifted right by one.
        inputs = [token_ids[i:i + sequence_length] for i in range(len(token_ids) - sequence_length)]
        labels = [token_ids[i + 1:i + 1 + sequence_length] for i in range(len(token_ids) - sequence_length)]
        return WordsDataset(inputs, labels)
    def encode(self, initial_text: str, return_tensors: Optional[str] = None) -> Union[list, t.Tensor]:
        '''
        Tokenizes initial_text, then returns the token ids.
        Return type is list by default, but if return_tensors="pt" then it is returned as a tensor.
        Every token must already be in the vocabulary; there is no unknown-token handling.
        '''
        tokens = tokenize(initial_text)
        # Use a name other than `t` so the torch alias is not shadowed inside the comprehension.
        token_ids = [self.token_to_id[tok] for tok in tokens]
        if return_tensors == "pt":
            return t.tensor(token_ids, device=self.device)
        return token_ids
    def decode(self, list_of_ids: Union[t.Tensor, list]) -> str:
        '''
        Converts ids to a list of tokens, then joins them into a single string.
        '''
        tokens = [self.id_to_token[int(i)] for i in list_of_ids]
        return "".join(tokens)