# alberti / app.py
# Last change: "Update app.py" (commit a1028e2) by alvp.
import random
import re
from poems import SAMPLE_POEMS
import langid
import numpy as np
import streamlit as st
import torch
from icu_tokenizer import Tokenizer
from transformers import pipeline
# Hugging Face model ids whose outputs are shown side by side, keyed by the
# display name used in the UI.
MODELS = dict(
    ALBERTI="flax-community/alberti-bert-base-multilingual-cased",
    mBERT="bert-base-multilingual-cased",
)
# How many highest-probability predictions are kept for the masked position.
TOPK = 50
# Use the full browser width so the three poem columns sit side by side.
st.set_page_config(
    layout="wide",
)
def mask_line(line, language="es", restrictive=True):
    """Replace one randomly chosen token of *line* with "[MASK]".

    The line is tokenized with an ICU tokenizer for *language* and the tokens
    are re-joined with single spaces, so spacing in the result may differ from
    *line*.

    Args:
        line: The verse to mask.
        language: Language code for the ICU tokenizer. "zh" additionally makes
            purely alphabetic tokens eligible regardless of length.
        restrictive: Only honoured for "zh". For every other language it is
            recomputed: masking is restricted to tokens longer than three
            characters whenever the line contains at least one such token.

    Returns:
        The space-joined tokens with exactly one token replaced by "[MASK]".

    Raises:
        ValueError: If *line* yields no tokens, so there is nothing to mask.
    """
    tokens = list(Tokenizer(lang=language).tokenize(line))
    if not tokens:
        raise ValueError(f"cannot mask a line with no tokens: {line!r}")

    if language != "zh":
        # Prefer longer "content" tokens; if every token is short, allow any.
        restrictive = any(len(token) > 3 for token in tokens)

    choices = list(range(len(tokens)))
    if restrictive:
        eligible = [
            idx
            for idx, token in enumerate(tokens)
            if len(token) > 3 or (language == "zh" and token.isalpha())
        ]
        # A Chinese line may have no eligible token (e.g. only punctuation);
        # fall back to any token instead of retrying forever.
        if eligible:
            choices = eligible

    # Uniform over the eligible positions, the same distribution the
    # original retry-until-eligible loop produced.
    tokens[random.choice(choices)] = "[MASK]"
    return " ".join(tokens)
def filter_candidates(candidates, get_any_candidate=False):
    """Pick the best-ranked completed verse from fill-mask *candidates*.

    Args:
        candidates: Ranked prediction dicts, best first, each with at least
            "token_str" and "sequence" keys (as built by infer_candidates).
        get_any_candidate: If True, skip filtering and take the top candidate.

    Returns:
        The "sequence" of the first candidate whose predicted token is made
        only of letters and is not a "##" WordPiece continuation. If no
        candidate qualifies, or *get_any_candidate* is True, the "sequence"
        of the top candidate.

    Raises:
        ValueError: If *candidates* is empty.
    """
    candidates = list(candidates)
    if not candidates:
        raise ValueError("no fill-mask candidates to choose from")

    if not get_any_candidate:
        for candidate in candidates:
            token = candidate["token_str"]
            if not token.startswith("##") and token.isalpha():
                return candidate["sequence"]

    return candidates[0]["sequence"]
def infer_candidates(nlp, line):
    """Return the TOPK fill-mask predictions for the single mask token in *line*.

    Uses the private ``_parse_and_tokenize`` / ``_forward`` methods of the
    transformers pipeline, so it is tied to the transformers version the app
    was written against.

    Args:
        nlp: A transformers "fill-mask" pipeline.
        line: Verse containing exactly one of the tokenizer's mask tokens.

    Returns:
        A list of TOPK dicts, best first, with keys:
        "sequence" (the verse re-decoded with the prediction inserted and
        wrapped in a red ``<b>`` tag, HTML-escaped otherwise), "score"
        (softmax probability), "token" (vocabulary id), "token_str" (decoded
        prediction) and "masked_index" (mask position in the input ids).

    Raises:
        ValueError: If *line* does not contain exactly one mask token.
    """
    import html

    # Map typographic punctuation to ASCII before tokenizing.
    for fancy, plain in (("–", "-"), ("—", "-"), ("’", "'"), ("…", "...")):
        line = line.replace(fancy, plain)

    inputs = nlp._parse_and_tokenize(line)
    outputs = nlp._forward(inputs, return_tensors=True)
    input_ids = inputs["input_ids"][0]
    mask_hits = torch.nonzero(input_ids == nlp.tokenizer.mask_token_id,
                              as_tuple=False)
    if mask_hits.numel() != 1:
        raise ValueError(
            f"expected exactly one mask token, found {mask_hits.numel()}: {line!r}"
        )
    masked_pos = mask_hits.item()

    # assumes outputs holds logits shaped (batch, seq, vocab) — TODO confirm
    probs = outputs[0, masked_pos, :].softmax(dim=0)
    values, predictions = probs.topk(TOPK)

    pad_id = nlp.tokenizer.pad_token_id
    result = []
    for score, token_id in zip(values.tolist(), predictions.tolist()):
        # .numpy() shares memory with the tensor; copy so the pipeline
        # inputs are not overwritten with each prediction.
        tokens = input_ids.numpy().copy()
        tokens[masked_pos] = token_id
        # Drop padding. NOTE(review): assumes padding only trails the
        # sequence, so masked_pos still indexes the predicted token.
        tokens = tokens[tokens != pad_id]
        pieces = [nlp.tokenizer.decode([tok], skip_special_tokens=True)
                  for tok in tokens]

        words = []
        for idx, piece in enumerate(pieces):
            if piece.startswith("##") and words:
                # WordPiece continuation: glue onto the previous word.
                words[-1] += html.escape(piece[2:], quote=False)
            elif idx == masked_pos:
                words += ['<b style="color: #ff0000;">',
                          html.escape(piece, quote=False), "</b>"]
            else:
                # Verse text is user input rendered as HTML; escape it.
                words.append(html.escape(piece, quote=False))

        result.append(
            {
                "sequence": " ".join(words).strip(),
                "score": score,
                "token": token_id,
                "token_str": nlp.tokenizer.decode(token_id),
                "masked_index": masked_pos,
            }
        )
    return result
def rewrite_poem(poem, ml_model=MODELS["ALBERTI"], masking=True, language="es"):
    """Refill one masked word per verse of *poem* using *ml_model*.

    Args:
        poem: Sequence of verse strings. Empty or whitespace-only entries are
            kept as empty lines (stanza breaks) and are not sent to the model.
        ml_model: Hugging Face model id for the "fill-mask" pipeline.
        masking: If True, mask one random token per verse with mask_line.
            If False, each non-blank verse must already contain a mask token.
        language: Language code forwarded to mask_line.

    Returns:
        A tuple ``(unmasked_poem, masked_lines)``: the refilled verses joined
        with "<br>" (HTML), and the list of masked verses, one per input line.
    """
    # NOTE: the pipeline (and its model weights) is built on every call.
    nlp = pipeline("fill-mask", model=ml_model)
    unmasked_lines = []
    masked_lines = []
    for line in poem:
        if not line.strip():
            # Whitespace-only lines have nothing to mask; treat them like
            # empty stanza breaks instead of sending them to the model.
            unmasked_lines.append("")
            masked_lines.append("")
            continue
        masked_line = mask_line(line, language) if masking else line
        masked_lines.append(masked_line)
        candidates = infer_candidates(nlp, masked_line)
        unmasked_lines.append(filter_candidates(candidates))
    return "<br>".join(unmasked_lines), masked_lines
# Sidebar title and one-line pitch. The title emoji was mojibake ("πŸ₯Š"),
# restored to the boxing glove.
instructions_text_0 = st.sidebar.markdown(
    """# ALBERTI vs BERT 🥊
We present ALBERTI, our BERT-based multilingual model for poetry.""")
# Sidebar blurb about how ALBERTI was trained, with a link to the project page.
instructions_text_1 = st.sidebar.markdown(
    """We have trained BERT on a huge (for poetry, that is) corpus of
multilingual poetry to try to get a more 'poetic' model. This is the result
of our work.
You can find more information on the [project's site](https://huggingface.co/flax-community/alberti-bert-base-multilingual-cased)""")
# Sidebar picker over the bundled example poems; the selected value is a key
# of SAMPLE_POEMS.
sample_chooser = st.sidebar.selectbox("Choose a poem", list(SAMPLE_POEMS))
# Sidebar usage instructions and the list of ALBERTI's training languages.
instructions_text_2 = st.sidebar.markdown("""# How to use
You can choose from a list of example poems in Spanish, English, French, German,
Chinese and Arabic, but you can also paste a poem, or write it yourself!
Then click 'Rewrite!' to run the masking and fill-mask task on the chosen
poem: one word per verse is masked at random, and you get one new version of
the poem from each of the two models.
The languages used to train ALBERTI are:
* Arabic
* Chinese
* Czech
* English
* Finnish
* French
* German
* Hungarian
* Italian
* Portuguese
* Russian
* Spanish""")
# Three columns: input poem, ALBERTI output, mBERT output.
col1, col2, col3 = st.columns(3)

# Page-wide CSS: bold, 1rem form labels and narrower side padding.
_PAGE_CSS = """
<style>
label {
font-size: 1rem !important;
font-weight: bold !important;
}
.block-container {
padding-left: 1rem !important;
padding-right: 1rem !important;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# Main interaction: edit or paste a poem, then rewrite it with both models.
if sample_chooser:
    user_input = col1.text_area("Input poem",
                                "\n".join(SAMPLE_POEMS[sample_chooser]),
                                height=600)
    poem = user_input.split("\n")
    rewrite_button = col1.button("Rewrite!")
    # The app inserts its own mask tokens. Stop after reporting the error
    # instead of also rendering output for input we just rejected.
    already_masked = "[MASK]" in user_input or "<mask>" in user_input
    if already_masked:
        col1.error("You don't have to mask the poem, we'll do it for you!")
    if rewrite_button and not already_masked:
        lang = langid.classify(user_input)[0]
        unmasked_poem, masked_poem = rewrite_poem(poem, language=lang)
        col2.write(f"<b>Output poem from ALBERTI</b>\n{unmasked_poem}",
                   unsafe_allow_html=True)
        # Reuse ALBERTI's masked verses so both models fill the same gaps.
        unmasked_poem_2, _ = rewrite_poem(masked_poem, ml_model=MODELS["mBERT"],
                                          masking=False)
        col3.write(f"<b>Output poem from mBERT</b>\n{unmasked_poem_2}",
                   unsafe_allow_html=True)