Spaces:
Build error
Build error
File size: 2,576 Bytes
b3e4e96 26dff99 efb23c9 26dff99 efb23c9 920a8fe efb23c9 26dff99 b3e4e96 766dac7 b3e4e96 efb23c9 b3e4e96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import pandas as pd
from spacy import displacy
from spacy.tokens import Doc
from spacy.vocab import Vocab
from spacy_streamlit.util import get_html
import streamlit as st
import torch
from transformers import BertTokenizerFast
from model import BertForTokenAndSequenceJointClassification
def _cache_resource():
    """Return Streamlit's decorator for caching a shared, unhashed resource.

    ``st.cache`` is deprecated in favour of ``st.cache_resource`` (added in
    Streamlit 1.18). Prefer the new API when it exists and fall back to the
    legacy ``st.cache(allow_output_mutation=True)`` on older releases, so the
    app keeps working on both.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource
    return st.cache(allow_output_mutation=True)


@_cache_resource()
def load_model():
    """Load the tokenizer and the joint token/sequence classification model.

    The result is cached by Streamlit so the weights are not reloaded on
    every script rerun.

    Returns:
        tuple: ``(tokenizer, model)``. The tokenizer is a
        ``BertTokenizerFast`` for ``bert-base-cased``. The model is
        ``QCRI/PropagandaTechniquesAnalysis-en-BERT``, pinned to revision
        ``v0.1.0``.
    """
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
    model = BertForTokenAndSequenceJointClassification.from_pretrained(
        "QCRI/PropagandaTechniquesAnalysis-en-BERT",
        revision="v0.1.0",
    )
    return tokenizer, model
# Example text pre-filled in the input box.
DEFAULT_TEXT = (
    "In some instances, it can be highly dangerous to use a medicine for the "
    "prevention or treatment of COVID-19 that has not been approved by or has "
    "not received emergency use authorization from the FDA.\n"
)

# Everything below runs without autograd tracking: the app only does inference.
with torch.inference_mode(True):
    tokenizer, model = load_model()

    st.write("[Propaganda Techniques Analysis BERT](https://huggingface.co/QCRI/PropagandaTechniquesAnalysis-en-BERT) Tagger")

    # Named ``text`` (not ``input``) so the builtin is not shadowed.
    text = st.text_area('Input', DEFAULT_TEXT)

    # truncation=True caps the sequence at the tokenizer's model_max_length,
    # so over-long inputs do not overflow BERT's position embeddings.
    # The final [SEP] token is kept.
    inputs = tokenizer.encode_plus(text, return_tensors="pt", truncation=True)

    # Drop the [CLS] / [SEP] special tokens at either end.
    tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0][1:-1])

    # Empty or whitespace-only input yields no tokens. Building a spaCy Doc
    # from that would fail on the words/spaces length mismatch, so stop here.
    if not tokens:
        st.info("Enter some text to analyse.")
        st.stop()

    outputs = model(**inputs)

    # One metric column per sequence-level class, showing its raw logit.
    sequence_logits = outputs.sequence_logits.flatten()
    columns = st.columns(len(sequence_logits))
    for col, sequence_tag, logit in zip(columns, model.sequence_tags, sequence_logits):
        col.metric(sequence_tag, f"{logit.item():.2f}")

    # Per-token predicted tag, with the special-token positions dropped to
    # stay aligned with ``tokens``.
    token_class_index = torch.argmax(outputs.token_logits, dim=-1)
    tags = [model.token_tags[i] for i in token_class_index[0].tolist()[1:-1]]

    # Each token is followed by a space unless the next token is a "##"
    # WordPiece continuation. The last token is never followed by a space.
    spaces = [not nxt.startswith('##') for nxt in tokens[1:]] + [False]

    doc = Doc(
        Vocab(strings=set(tokens)),
        words=tokens,
        spaces=spaces,
        # Every tagged token becomes its own single-token entity ("B-" prefix).
        # "O" marks tokens outside any entity.
        ents=[tag if tag == "O" else f"B-{tag}" for tag in tags],
    )

    # NOTE(review): the first two entries of model.token_tags are excluded
    # from the selectable labels. They are presumably "O" and a padding/ignore
    # tag; confirm against the model config.
    labels = model.token_tags[2:]
    label_select = st.multiselect(
        "Tags",
        options=labels,
        default=labels,
        key="tags_ner_label_select",
    )

    html = displacy.render(
        doc, style="ent", options={"ents": label_select, "colors": {}}
    )
    # Render entity marks inline-block so they wrap cleanly in Streamlit.
    style = "<style>mark.entity { display: inline-block }</style>"
    st.write(f"{style}{get_html(html)}", unsafe_allow_html=True)

    # Tabular view of the entities whose label is currently selected.
    attrs = ["text", "label_", "start", "end", "start_char", "end_char"]
    data = [
        [str(getattr(ent, attr)) for attr in attrs]
        for ent in doc.ents
        if ent.label_ in label_select
    ]
    if data:
        st.dataframe(pd.DataFrame(data, columns=attrs))
|