Spaces:
Running
Running
import gradio as gr | |
import torch | |
from peft import PeftModel, PeftConfig | |
from transformers import AutoModelForTokenClassification | |
def test_mask(model, sample):
    """
    Encode ``sample`` as byte-level input ids plus an all-ones attention mask.

    Every UTF-8 byte of ``sample`` is shifted up by 3 (presumably because ids
    0-2 are reserved for special tokens, as in ByT5 -- ``rewrite`` undoes the
    same shift) and id 0 is appended as the end-of-sequence marker. No padding
    is added, so every position is attended to.

    Args:
        model: Object exposing a ``device`` attribute (e.g. the loaded model);
            the returned tensors are created on that device.
        sample (str): Text to encode.

    Returns:
        dict: ``'input_ids'`` and ``'attention_mask'``, each an int64 tensor of
        shape ``(1, len(sample.encode('utf-8')) + 1)``.
    """
    # NOTE(review): in the stock ByT5 vocabulary id 0 is <pad> and id 1 is </s>.
    # Id 0 is kept here because the fine-tuned adapter was presumably trained
    # with it -- confirm against the training code before changing.
    byte_ids = [byte + 3 for byte in sample.encode('utf-8')]
    byte_ids.append(0)
    input_ids = torch.tensor([byte_ids], dtype=torch.int64, device=model.device)
    # ones_like inherits dtype (int64) and device from input_ids.
    return {'input_ids': input_ids, 'attention_mask': torch.ones_like(input_ids)}


# The helper's name starts with ``test_``, so pytest would collect it as a test
# (and error on the missing ``model``/``sample`` fixtures) whenever this module
# is star-imported into a test file; opt it out of collection.
test_mask.__test__ = False
def rewrite(model, data): | |
""" | |
Rewrites the input text with the model. | |
Args: | |
model (torch.nn.Module): Model. | |
data (dict): Dictionary containing 'input_ids' and 'attention_mask'. | |
Returns: | |
output (str): Rewritten text. | |
""" | |
with torch.no_grad(): | |
pred = torch.argmax(model(**data).logits, dim=2).squeeze(0) | |
output = list() # save the indices of the characters as list of integers | |
# Conversion table for Turkish characters {100: [300, 350], ...} | |
en2tr = {en: tr for tr, en in zip(list(map(list, map(str.encode, list('ÜİĞŞÇÖüığşçö')))), list(map(ord, list('UIGSCOuigsco'))))} | |
for inp, lab in zip((data['input_ids'].squeeze(0) - 3).tolist(), pred.tolist()): | |
if lab and inp in en2tr: | |
# if the model predicts a diacritic, replace it with the corresponding Turkish character | |
output.extend(en2tr[inp]) | |
elif inp >= 0: output.append(inp) | |
return bytes(output).decode() | |
def try_it(text):
    """
    Gradio callback: restore Turkish diacritics in ``text``.

    Encodes ``text`` with ``test_mask`` and decodes the prediction with
    ``rewrite``, both using the module-level ``model``. That global is only
    assigned in the ``__main__`` block, so this works when the file is run
    as a script.
    """
    return rewrite(model, test_mask(model, text))
if __name__ == '__main__':
    # Hugging Face Hub ids: the fine-tuned PEFT adapter and the base model it
    # is applied on top of.
    ADAPTER_REPO = "bite-the-byte/byt5-small-deASCIIfy-TR"
    BASE_MODEL = "google/byt5-small"

    # `model` must remain a module-level global: try_it() reads it at call time.
    model = AutoModelForTokenClassification.from_pretrained(BASE_MODEL)
    model = PeftModel.from_pretrained(model, ADAPTER_REPO)

    diacritize_app = gr.Interface(fn=try_it, inputs="text", outputs="text")
    diacritize_app.launch(share=True)