Spaces:
Build error
Build error
import pandas as pd | |
from spacy import displacy | |
from spacy.tokens import Doc | |
from spacy.vocab import Vocab | |
from spacy_streamlit.util import get_html | |
import streamlit as st | |
import torch | |
from transformers import BertTokenizerFast | |
from model import BertForTokenAndSequenceJointClassification | |
def _cache_model_loader(loader):
    """Wrap *loader* in Streamlit's resource cache so it runs only once.

    Streamlit re-executes the whole script on every widget interaction;
    without caching, the tokenizer and model would be rebuilt each time.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(loader)
    # Older Streamlit releases only provide the legacy ``st.cache`` API.
    return st.cache(allow_output_mutation=True)(loader)


@_cache_model_loader
def load_model():
    """Load the tokenizer and the joint token/sequence classification model.

    Returns:
        A ``(tokenizer, model)`` tuple: the ``bert-base-cased`` fast
        tokenizer and the ``QCRI/PropagandaTechniquesAnalysis-en-BERT``
        checkpoint pinned to revision ``v0.1.0``.
    """
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
    model = BertForTokenAndSequenceJointClassification.from_pretrained(
        "QCRI/PropagandaTechniquesAnalysis-en-BERT",
        revision="v0.1.0",
    )
    return tokenizer, model
# Streamlit page: tag the input text with propaganda techniques and render
# the per-token predictions as highlighted entities plus a summary table.
with torch.inference_mode(True):
    tokenizer, model = load_model()

    st.write("[Propaganda Techniques Analysis BERT](https://huggingface.co/QCRI/PropagandaTechniquesAnalysis-en-BERT) Tagger")

    # Named ``text`` rather than ``input`` so the builtin is not shadowed.
    text = st.text_area('Input', """\
In some instances, it can be highly dangerous to use a medicine for the prevention or treatment of COVID-19 that has not been approved by or has not received emergency use authorization from the FDA.
""")

    # Truncate to the tokenizer's maximum length so that very long input
    # does not overflow the model's position embeddings.
    inputs = tokenizer.encode_plus(text, return_tensors="pt", truncation=True)

    # Drop the leading/trailing special tokens the tokenizer adds
    # ([CLS] ... [SEP]) so that only the user's wordpieces remain.
    tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0][1:-1])
    if not tokens:
        # Empty or whitespace-only input yields no wordpieces; building a
        # Doc from it would fail, so stop this run with a hint instead.
        st.info("Enter some text to analyze.")
        st.stop()

    outputs = model(**inputs)

    # Sentence-level prediction (kept for parity with the original script;
    # it is not rendered in the portion of the file visible here).
    sequence_class_index = torch.argmax(outputs.sequence_logits, dim=-1)
    sequence_class = model.sequence_tags[sequence_class_index[0]]

    # Token-level predictions, aligned with ``tokens`` by dropping the
    # positions of the two special tokens.
    token_class_index = torch.argmax(outputs.token_logits, dim=-1)
    tags = [model.token_tags[i] for i in token_class_index[0].tolist()[1:-1]]

    # One metric widget per sequence-level logit.
    sequence_logits = outputs.sequence_logits.flatten().tolist()
    columns = st.columns(len(sequence_logits))
    for col, sequence_tag, logit in zip(columns, model.sequence_tags, sequence_logits):
        col.metric(sequence_tag, f"{logit:.2f}")

    # A token is followed by a space unless the next token is a ``##``
    # continuation piece; the final token never gets trailing whitespace.
    spaces = [not nxt.startswith('##') for nxt in tokens[1:]] + [False]

    # Every non-"O" token becomes its own single-token entity (B- prefix).
    doc = Doc(
        Vocab(strings=set(tokens)),
        words=tokens,
        spaces=spaces,
        ents=[tag if tag == "O" else f"B-{tag}" for tag in tags],
    )

    # NOTE(review): presumably the first two entries of ``token_tags`` are
    # non-technique labels (e.g. padding and "O") - confirm against model.
    labels = model.token_tags[2:]
    label_select = st.multiselect(
        "Tags",
        options=labels,
        default=labels,
        key="tags_ner_label_select",
    )

    html = displacy.render(
        doc, style="ent", options={"ents": label_select, "colors": {}}
    )
    style = "<style>mark.entity { display: inline-block }</style>"
    st.write(f"{style}{get_html(html)}", unsafe_allow_html=True)

    # Tabular view of the entities whose label is currently selected.
    attrs = ["text", "label_", "start", "end", "start_char", "end_char"]
    data = [
        [str(getattr(ent, attr)) for attr in attrs]
        for ent in doc.ents
        if ent.label_ in label_select
    ]
    if data:
        df = pd.DataFrame(data, columns=attrs)
        st.dataframe(df)