darkproger committed on
Commit
efb23c9
·
1 Parent(s): 766dac7

use st.metric for sequence logits

Browse files
Files changed (1) hide show
  1. app.py +22 -15
app.py CHANGED
@@ -7,7 +7,8 @@ import streamlit as st
7
  import torch
8
  from transformers import BertTokenizerFast
9
 
10
- from model import BertForTokenAndSequenceJointClassification, TOKEN_TAGS
 
11
 
12
  @st.cache(allow_output_mutation=True)
13
  def load_model():
@@ -16,22 +17,28 @@ def load_model():
16
  "QCRI/PropagandaTechniquesAnalysis-en-BERT",
17
  revision="v0.1.0")
18
  return tokenizer, model
19
-
20
- tokenizer, model = load_model()
21
 
22
- st.write("[Propaganda Techniques Analysis BERT](https://huggingface.co/QCRI/PropagandaTechniquesAnalysis-en-BERT) Tagger")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- input = st.text_area('Input', """\
25
- In some instances, it can be highly dangerous to use a medicine for the prevention or treatment of COVID-19 that has not been approved by or has not received emergency use authorization from the FDA.
26
- """)
27
 
28
- inputs = tokenizer.encode_plus(input, return_tensors="pt")
29
- outputs = model(**inputs)
30
- sequence_class_index = torch.argmax(outputs.sequence_logits, dim=-1)
31
- sequence_class = model.sequence_tags[sequence_class_index[0]]
32
- token_class_index = torch.argmax(outputs.token_logits, dim=-1)
33
- tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0][1:-1])
34
- tags = [model.token_tags[i] for i in token_class_index[0].tolist()[1:-1]]
35
 
36
  spaces = [not tok.startswith('##') for tok in tokens][1:] + [False]
37
 
@@ -40,7 +47,7 @@ doc = Doc(Vocab(strings=set(tokens)),
40
  spaces=spaces,
41
  ents=[tag if tag == "O" else f"B-{tag}" for tag in tags])
42
 
43
- labels = TOKEN_TAGS[2:]
44
 
45
  label_select = st.multiselect(
46
  "Tags",
 
7
  import torch
8
  from transformers import BertTokenizerFast
9
 
10
+ from model import BertForTokenAndSequenceJointClassification
11
+
12
 
13
  @st.cache(allow_output_mutation=True)
14
  def load_model():
 
17
  "QCRI/PropagandaTechniquesAnalysis-en-BERT",
18
  revision="v0.1.0")
19
  return tokenizer, model
 
 
20
 
21
+ with torch.inference_mode(True):
22
+ tokenizer, model = load_model()
23
+
24
+ st.write("[Propaganda Techniques Analysis BERT](https://huggingface.co/QCRI/PropagandaTechniquesAnalysis-en-BERT) Tagger")
25
+
26
+ input = st.text_area('Input', """\
27
+ In some instances, it can be highly dangerous to use a medicine for the prevention or treatment of COVID-19 that has not been approved by or has not received emergency use authorization from the FDA.
28
+ """)
29
+
30
+ inputs = tokenizer.encode_plus(input, return_tensors="pt")
31
+ outputs = model(**inputs)
32
+ sequence_class_index = torch.argmax(outputs.sequence_logits, dim=-1)
33
+ sequence_class = model.sequence_tags[sequence_class_index[0]]
34
+ token_class_index = torch.argmax(outputs.token_logits, dim=-1)
35
+ tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0][1:-1])
36
+ tags = [model.token_tags[i] for i in token_class_index[0].tolist()[1:-1]]
37
 
38
+ columns = st.columns(len(outputs.sequence_logits.flatten()))
39
+ for col, sequence_tag, logit in zip(columns, model.sequence_tags, outputs.sequence_logits.flatten()):
40
+ col.metric(sequence_tag, '%.2f' % logit.item())
41
 
 
 
 
 
 
 
 
42
 
43
  spaces = [not tok.startswith('##') for tok in tokens][1:] + [False]
44
 
 
47
  spaces=spaces,
48
  ents=[tag if tag == "O" else f"B-{tag}" for tag in tags])
49
 
50
+ labels = model.token_tags[2:]
51
 
52
  label_select = st.multiselect(
53
  "Tags",