darkproger commited on
Commit
b3e4e96
·
1 Parent(s): 920a8fe

use displacy to render tags

Browse files
Files changed (2) hide show
  1. app.py +36 -2
  2. requirements.txt +3 -0
app.py CHANGED
@@ -1,8 +1,13 @@
 
 
 
 
 
1
  import streamlit as st
2
  import torch
3
  from transformers import BertTokenizerFast
4
 
5
- from model import BertForTokenAndSequenceJointClassification
6
 
7
  @st.cache(allow_output_mutation=True)
8
  def load_model():
@@ -28,4 +33,33 @@ token_class_index = torch.argmax(outputs.token_logits, dim=-1)
28
  tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0][1:-1])
29
  tags = [model.token_tags[i] for i in token_class_index[0].tolist()[1:-1]]
30
 
31
- st.table(list(zip(tokens, tags)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from spacy import displacy
3
+ from spacy.tokens import Doc
4
+ from spacy.vocab import Vocab
5
+ from spacy_streamlit.util import get_html
6
  import streamlit as st
7
  import torch
8
  from transformers import BertTokenizerFast
9
 
10
+ from model import BertForTokenAndSequenceJointClassification, TOKEN_TAGS
11
 
12
  @st.cache(allow_output_mutation=True)
13
  def load_model():
 
33
  tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids[0][1:-1])
34
  tags = [model.token_tags[i] for i in token_class_index[0].tolist()[1:-1]]
35
 
36
+ spaces = [not tok.startswith('##') for tok in tokens][1:] + [False]
37
+
38
+ doc = Doc(Vocab(strings=set(tokens)),
39
+ words=tokens,
40
+ spaces=spaces,
41
+ ents=[tag if tag == "O" else f"I-{tag}" for tag in tags])
42
+
43
+ labels = TOKEN_TAGS[2:]
44
+
45
+ label_select = st.multiselect(
46
+ "Tags",
47
+ options=labels,
48
+ default=labels,
49
+ key=f"tags_ner_label_select",
50
+ )
51
+ html = displacy.render(
52
+ doc, style="ent", options={"ents": label_select, "colors": {}}
53
+ )
54
+ style = "<style>mark.entity { display: inline-block }</style>"
55
+ st.write(f"{style}{get_html(html)}", unsafe_allow_html=True)
56
+
57
+ attrs = ["text", "label_", "start", "end", "start_char", "end_char"]
58
+ data = [
59
+ [str(getattr(ent, attr)) for attr in attrs]
60
+ for ent in doc.ents
61
+ if ent.label_ in label_select
62
+ ]
63
+ if data:
64
+ df = pd.DataFrame(data, columns=attrs)
65
+ st.dataframe(df)
requirements.txt CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  streamlit
2
  transformers
3
  torch
 
1
+ pandas
2
+ spacy
3
+ spacy_streamlit
4
  streamlit
5
  transformers
6
  torch