Spaces:
Runtime error
Runtime error
import random | |
import re | |
from poems import SAMPLE_POEMS | |
import langid | |
import numpy as np | |
import streamlit as st | |
import torch | |
from icu_tokenizer import Tokenizer | |
from transformers import pipeline | |
# Hugging Face model ids compared side by side in the UI, keyed by display name.
MODELS = dict(
    ALBERTI="flax-community/alberti-bert-base-multilingual-cased",
    mBERT="bert-base-multilingual-cased",
)

# How many top predictions are requested from the model for each masked token.
TOPK = 50

# Wide layout so the three poem columns fit next to each other.
st.set_page_config(layout="wide")
def mask_line(line, language="es", restrictive=True):
    """Replace one randomly chosen token of ``line`` with ``[MASK]``.

    The line is split with the ICU tokenizer for ``language`` and the tokens
    are re-joined with single spaces, so the result is space-separated even
    where the input was not.

    Args:
        line: A single verse of the poem.
        language: Language code handed to the ICU tokenizer.
        restrictive: Only honoured for Chinese (``"zh"``). For any other
            language masking is restricted to tokens longer than three
            characters whenever the line contains at least one.

    Returns:
        The space-joined tokens with exactly one token replaced by
        ``"[MASK]"``, or ``line`` unchanged when it yields no tokens.
    """
    tokenizer = Tokenizer(lang=language)
    token_list = list(tokenizer.tokenize(line))
    if not token_list:
        # Nothing to mask (e.g. a whitespace-only line); randint(0, -1)
        # would raise otherwise.
        return line

    # A token is a preferred mask target if it is longer than three
    # characters or, for Chinese, purely alphabetic (not punctuation).
    eligible = [
        idx
        for idx, token in enumerate(token_list)
        if len(token) > 3 or (language == "zh" and token.isalpha())
    ]
    if language != "zh":
        # Non-Chinese lines are always restricted when a long token exists;
        # the caller's ``restrictive`` value is ignored for them.
        restrictive = True

    if restrictive and eligible:
        # Uniform over eligible tokens: the same distribution the old
        # retry-until-eligible recursion produced, without the recursion.
        masked_idx = random.choice(eligible)
    else:
        # Unrestricted, or no eligible token at all (the old code recursed
        # forever on such Chinese lines).
        masked_idx = random.randrange(len(token_list))

    token_list[masked_idx] = "[MASK]"
    return " ".join(token_list)
def filter_candidates(candidates, get_any_candidate=False):
    """Pick one filled-in sequence from a list of fill-mask candidates.

    Candidates are scanned in the order given (``infer_candidates`` yields
    them best-first).

    Args:
        candidates: Iterable of dicts with at least ``"token_str"`` and
            ``"sequence"`` keys.
        get_any_candidate: If True, skip filtering and take the first
            candidate.

    Returns:
        The ``"sequence"`` of the first candidate whose predicted token is a
        whole, purely alphabetic word (not a ``##`` word piece). If none
        qualifies, the first candidate's ``"sequence"``.

    Raises:
        ValueError: If ``candidates`` is empty (previously this recursed
            until ``RecursionError``).
    """
    candidates = list(candidates)
    if not candidates:
        raise ValueError("filter_candidates() needs at least one candidate")
    if not get_any_candidate:
        for candidate in candidates:
            token = candidate["token_str"]
            if not token.startswith("##") and token.isalpha():
                return candidate["sequence"]
    # Nothing passed the filter (or filtering was disabled): best candidate.
    return candidates[0]["sequence"]
def infer_candidates(nlp, line):
    """Predict the top ``TOPK`` fillers for the single mask token in ``line``.

    Args:
        nlp: A ``transformers`` fill-mask pipeline; only its ``tokenizer``
            and ``model`` attributes are used.
        line: Text containing exactly one ``nlp.tokenizer.mask_token``.

    Returns:
        A list of dicts, best first, each with keys:
          * ``"sequence"``: the line re-rendered with the prediction wrapped
            in a red ``<b>`` tag (all other text HTML-escaped),
          * ``"score"``: softmax probability of the prediction,
          * ``"token"``: predicted vocabulary id,
          * ``"token_str"``: the decoded predicted token,
          * ``"masked_index"``: position of the mask in the input ids.

    Raises:
        ValueError: If the tokenized line does not contain exactly one mask.
    """
    import html

    # Normalise typographic punctuation to ASCII. The previous literals were
    # mojibake ("β"), so only the first substitution ever matched.
    line = line.translate(str.maketrans({
        "\u2013": "-",    # en dash
        "\u2014": "-",    # em dash
        "\u2019": "'",    # right single quotation mark
        "\u2026": "...",  # horizontal ellipsis
    }))

    tokenizer = nlp.tokenizer
    model = nlp.model
    # Public tokenizer/model API instead of the pipeline's private
    # _parse_and_tokenize/_forward, whose signatures changed across releases.
    inputs = tokenizer(line, return_tensors="pt")
    input_ids = inputs["input_ids"][0]
    mask_positions = torch.nonzero(input_ids == tokenizer.mask_token_id,
                                   as_tuple=False)
    if mask_positions.numel() != 1:
        raise ValueError(
            f"expected exactly one mask token, found {mask_positions.numel()}"
        )
    masked_index = mask_positions.item()

    with torch.no_grad():
        model_inputs = {name: tensor.to(model.device)
                        for name, tensor in inputs.items()}
        logits = model(**model_inputs).logits[0, masked_index, :]
    probs = logits.softmax(dim=0)
    values, predictions = probs.topk(min(TOPK, probs.shape[0]))

    # Plain Python ints: editing them never mutates the input tensor (the old
    # input_ids.numpy() view did).
    token_ids = input_ids.tolist()
    pad_id = tokenizer.pad_token_id
    # Every non-masked token decodes the same way for each prediction, so
    # decode them once up front.
    pieces = [tokenizer.decode([tid], skip_special_tokens=True)
              for tid in token_ids]

    result = []
    for score, token_id in zip(values.tolist(), predictions.tolist()):
        words = []
        for position, tid in enumerate(token_ids):
            if position == masked_index:
                tid = token_id
                piece = tokenizer.decode([tid], skip_special_tokens=True)
            else:
                piece = pieces[position]
            if tid == pad_id:
                continue  # drop padding tokens
            if piece.startswith("##") and words:
                # WordPiece continuation: glue onto the previous word.
                words[-1] += html.escape(piece[2:])
            elif position == masked_index:
                words += ['<b style="color: #ff0000;">', html.escape(piece),
                          "</b>"]
            else:
                words.append(html.escape(piece))
        result.append(
            {
                "sequence": " ".join(words).strip(),
                "score": score,
                "token": token_id,
                "token_str": tokenizer.decode(token_id),
                "masked_index": masked_index,
            }
        )
    return result
def rewrite_poem(poem, ml_model=MODELS["ALBERTI"], masking=True, language="es"):
    """Mask (optionally) and re-fill one word per verse of ``poem``.

    Args:
        poem: Iterable of verse strings.
        ml_model: Model id passed to ``pipeline("fill-mask", model=...)``.
        masking: If True, each non-blank verse is masked with ``mask_line``;
            if False, the verses are used as-is and must already contain
            exactly one mask token.
        language: Language code forwarded to ``mask_line``.

    Returns:
        ``(unmasked_poem, masked_lines)``:
          * ``unmasked_poem`` is the filled-in verses joined with ``"<br>"``.
          * ``masked_lines`` is the list of verses fed to the model.
        Blank or whitespace-only verses become ``""`` in both outputs.
    """
    # NOTE(review): the pipeline is built (and the model loaded) on every
    # call; consider caching it if reruns are slow.
    nlp = pipeline("fill-mask", model=ml_model)
    unmasked_lines = []
    masked_lines = []
    for line in poem:
        if not line.strip():
            # Stanza break. Whitespace-only verses previously reached
            # mask_line and crashed on an empty token list.
            unmasked_lines.append("")
            masked_lines.append("")
            continue
        masked_line = mask_line(line, language) if masking else line
        masked_lines.append(masked_line)
        candidates = infer_candidates(nlp, masked_line)
        unmasked_lines.append(filter_candidates(candidates))
    return "<br>".join(unmasked_lines), masked_lines
# --- Sidebar: project blurb, poem picker and usage notes -------------------
# The title emoji was mojibake ("π₯" is the UTF-8 bytes of U+1F94A decoded
# as ISO-8859-7); restored to the boxing glove.
instructions_text_0 = st.sidebar.markdown(
    """# ALBERTI vs BERT 🥊
We present ALBERTI, our BERT-based multilingual model for poetry.""")
instructions_text_1 = st.sidebar.markdown(
    """We have trained bert on a huge (for poetry, that is) corpus of
multilingual poetry to try to get a more 'poetic' model. This is the result
of our work.
You can find more information on the [project's site](https://huggingface.co/flax-community/alberti-bert-base-multilingual-cased)""")

# Sample poem to pre-fill the input box with.
sample_chooser = st.sidebar.selectbox(
    "Choose a poem",
    list(SAMPLE_POEMS),
)

instructions_text_2 = st.sidebar.markdown("""# How to use
You can choose from a list of example poems in Spanish, English, French, German,
Chinese and Arabic, but you can also paste a poem, or write it yourself!
Then click on 'Rewrite!' to do the masking and the fill-mask task on the chosen
poem, randomly masking one word per verse, and get the two new versions for each of the models.
The list of languages used on the training of ALBERTI are:
* Arabic
* Chinese
* Czech
* English
* Finnish
* French
* German
* Hungarian
* Italian
* Portuguese
* Russian
* Spanish""")
# Three panels: input poem, ALBERTI output, mBERT output.
(col1, col2, col3) = st.columns(3)

# Bold 1rem widget labels and tighter horizontal padding for the wide layout.
_custom_css = """
<style>
label {
font-size: 1rem !important;
font-weight: bold !important;
}
.block-container {
padding-left: 1rem !important;
padding-right: 1rem !important;
}
</style>
"""
st.markdown(_custom_css, unsafe_allow_html=True)
if sample_chooser:
    # NOTE(review): model_list is not used below; kept in case other code
    # reads it.
    model_list = set(MODELS.values())

    # Input column: editable poem pre-filled with the chosen sample.
    default_poem = "\n".join(SAMPLE_POEMS[sample_chooser])
    user_input = col1.text_area("Input poem", default_poem, height=600)
    poem = user_input.split("\n")
    rewrite_button = col1.button("Rewrite!")

    # Masking is done automatically, so warn users who pre-masked the text.
    already_masked = any(tag in user_input for tag in ("[MASK]", "<mask>"))
    if already_masked:
        col1.error("You don't have to mask the poem, we'll do it for you!")

    if rewrite_button:
        # NOTE(review): module-level name; keep it stable in case other code
        # reads it.
        lang = langid.classify(user_input)[0]

        # ALBERTI masks the poem itself...
        unmasked_poem, masked_poem = rewrite_poem(poem, language=lang)
        alberti_html = f"<b>Output poem from ALBERTI</b>\n{unmasked_poem}"
        user_input_2 = col2.write(alberti_html, unsafe_allow_html=True)

        # ...and mBERT fills exactly the same masks, for a fair comparison.
        unmasked_poem_2, _ = rewrite_poem(
            masked_poem,
            ml_model=MODELS["mBERT"],
            masking=False,
        )
        mbert_html = f"<b>Output poem from mBERT</b>\n{unmasked_poem_2}"
        user_input_3 = col3.write(mbert_html, unsafe_allow_html=True)