# propaganda / app.py
# Last change: commit efb23c9 by darkproger — "use st.metric for sequence logits"
import pandas as pd
from spacy import displacy
from spacy.tokens import Doc
from spacy.vocab import Vocab
from spacy_streamlit.util import get_html
import streamlit as st
import torch
from transformers import BertTokenizerFast
from model import BertForTokenAndSequenceJointClassification
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the tokenizer and the joint token/sequence classifier.

    Wrapped in ``st.cache`` so Streamlit reruns of this script reuse the
    loaded objects instead of reloading them; ``allow_output_mutation=True``
    tells the cache not to check the returned objects for mutation.

    Returns:
        tuple: ``(tokenizer, model)`` — the ``bert-base-cased`` fast
        tokenizer and the ``QCRI/PropagandaTechniquesAnalysis-en-BERT``
        model pinned to revision ``v0.1.0``.
    """
    model_repo = "QCRI/PropagandaTechniquesAnalysis-en-BERT"
    pinned_revision = "v0.1.0"
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        BertTokenizerFast.from_pretrained('bert-base-cased'),
        BertForTokenAndSequenceJointClassification.from_pretrained(
            model_repo, revision=pinned_revision
        ),
    )
# Everything that touches the model runs without autograd bookkeeping.
with torch.inference_mode(True):
    tokenizer, model = load_model()

    st.write("[Propaganda Techniques Analysis BERT](https://huggingface.co/QCRI/PropagandaTechniquesAnalysis-en-BERT) Tagger")

    # Named `text` (not `input`) so the builtin is not shadowed.
    text = st.text_area('Input', """\
In some instances, it can be highly dangerous to use a medicine for the prevention or treatment of COVID-19 that has not been approved by or has not received emergency use authorization from the FDA.
""")

    # truncation=True caps the sequence at tokenizer.model_max_length so an
    # over-long paste does not reach the model as a sequence it cannot handle.
    inputs = tokenizer.encode_plus(text, return_tensors="pt", truncation=True)
    # Word pieces kept alongside [CLS] and [SEP]; if the full tokenisation is
    # longer, the tail of the text was dropped and the user should know.
    n_kept = inputs.input_ids.shape[1] - 2
    if len(tokenizer.tokenize(text)) > n_kept:
        st.warning(
            f"Input exceeds {tokenizer.model_max_length} tokens; "
            f"only the first {n_kept} word pieces are analysed."
        )
    outputs = model(**inputs)

    # Highest-scoring sequence class for the first (only) batch item.
    # NOTE(review): sequence_class is not displayed anywhere in this script.
    sequence_class_index = torch.argmax(outputs.sequence_logits, dim=-1)
    sequence_class = model.sequence_tags[sequence_class_index[0].item()]

    # Per-token argmax tag; [1:-1] drops the [CLS]/[SEP] positions so that
    # tokens and tags line up one-to-one.
    token_class_index = torch.argmax(outputs.token_logits, dim=-1)
    tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0][1:-1].tolist())
    tags = [model.token_tags[i] for i in token_class_index[0].tolist()[1:-1]]

    # One st.metric per sequence class, showing that class's logit.
    sequence_logits = outputs.sequence_logits.flatten()
    columns = st.columns(len(sequence_logits))
    for col, sequence_tag, logit in zip(columns, model.sequence_tags, sequence_logits):
        col.metric(sequence_tag, f"{logit.item():.2f}")
# Build a spaCy Doc over the word pieces so displacy can render the tags.
# spaces[i] says whether a space follows token i: yes unless the next piece
# is a "##" continuation; the last token gets no trailing space. Deriving it
# from tokens[1:] and appending False only when tokens exist keeps
# len(spaces) == len(tokens) even for an empty input, which spaCy's Doc
# constructor requires (it raises ValueError otherwise).
spaces = [not nxt.startswith('##') for nxt in tokens[1:]]
if tokens:
    spaces.append(False)
doc = Doc(Vocab(strings=set(tokens)),
          words=tokens,
          spaces=spaces,
          # Each non-"O" token is tagged "B-<tag>", i.e. it starts its own
          # single-token entity; adjacent same-tag tokens stay separate.
          ents=[tag if tag == "O" else f"B-{tag}" for tag in tags])
# Tag names offered in the filter widget (all selected by default).
# NOTE(review): [2:] skips the first two entries of model.token_tags —
# presumably "O" plus a padding/special tag; confirm against the model config.
labels = model.token_tags[2:]
label_select = st.multiselect(
    "Tags",
    options=labels,
    default=labels,
    key="tags_ner_label_select",
)

# Render only entities whose label is selected; no custom colours are given.
html = displacy.render(
    doc, style="ent", options={"ents": label_select, "colors": {}}
)
# Make each highlighted span an inline-block (presumably so a span and its
# label do not break across lines).
style = "<style>mark.entity { display: inline-block }</style>"
st.write(f"{style}{get_html(html)}", unsafe_allow_html=True)
# Columns of the entity table, each read off a spaCy entity Span by name.
attrs = ["text", "label_", "start", "end", "start_char", "end_char"]

# Keep only entities whose label is currently ticked in the "Tags" filter,
# rendering every attribute as a string.
data = []
for ent in doc.ents:
    if ent.label_ not in label_select:
        continue
    row = [str(getattr(ent, attr)) for attr in attrs]
    data.append(row)

# No table at all when nothing survives the filter.
if data:
    df = pd.DataFrame(data, columns=attrs)
    st.dataframe(df)