dfki-nlp commited on
Commit
c7db18a
·
1 Parent(s): 132fa59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -3
app.py CHANGED
@@ -21,12 +21,16 @@ class ExampleDocument(TextDocument):
21
 
22
  ner_model_name_or_path = "pie/example-ner-spanclf-conll03"
23
  re_model_name_or_path = "DFKI-SLT/relation_classification_tacred_revisited"
24
- #"pie/example-re-textclf-tacred"
25
- #"DFKI-SLT/relation_classification_tacred_revisited"
26
 
27
  ner_pipeline = AutoPipeline.from_pretrained(ner_model_name_or_path, device=-1, num_workers=0)
28
  re_pipeline = AutoPipeline.from_pretrained(re_model_name_or_path, device=-1, num_workers=0)
29
 
 
 
 
 
 
 
30
 
31
  def predict(text):
32
  document = ExampleDocument(text)
@@ -34,10 +38,16 @@ def predict(text):
34
  ner_pipeline(document)
35
 
36
  while len(document.entities.predictions) > 0:
37
- document.entities.append(document.entities.predictions.pop(0))
 
 
 
 
 
38
 
39
  re_pipeline(document)
40
 
 
41
  t = PrettyTable()
42
  t.field_names = ["head", "tail", "relation"]
43
  t.align = "l"
 
21
 
22
  ner_model_name_or_path = "pie/example-ner-spanclf-conll03"
23
  re_model_name_or_path = "DFKI-SLT/relation_classification_tacred_revisited"
 
 
24
 
25
  ner_pipeline = AutoPipeline.from_pretrained(ner_model_name_or_path, device=-1, num_workers=0)
26
  re_pipeline = AutoPipeline.from_pretrained(re_model_name_or_path, device=-1, num_workers=0)
27
 
28
+ ner_tag_mapping = {
29
+ 'ORG': 'ORGANIZATION',
30
+ 'PER': 'PERSON',
31
+ 'LOC': 'LOCATION'
32
+ }
33
+
34
 
35
  def predict(text):
36
  document = ExampleDocument(text)
 
38
  ner_pipeline(document)
39
 
40
  while len(document.entities.predictions) > 0:
41
+ entity = document.entities.predictions.pop(0)
42
+ if entity.label in ner_tag_mapping:
43
+ entity = LabeledSpan(start=entity.start, end=entity.end, label=ner_tag_mapping[entity.label],
44
+ score=entity.score)
45
+ if entity.label in re_pipeline.taskmodule.entity_labels:
46
+ document.entities.append(entity)
47
 
48
  re_pipeline(document)
49
 
50
+
51
  t = PrettyTable()
52
  t.field_names = ["head", "tail", "relation"]
53
  t.align = "l"