Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -21,12 +21,16 @@ class ExampleDocument(TextDocument):
|
|
21 |
|
22 |
ner_model_name_or_path = "pie/example-ner-spanclf-conll03"
|
23 |
re_model_name_or_path = "DFKI-SLT/relation_classification_tacred_revisited"
|
24 |
-
#"pie/example-re-textclf-tacred"
|
25 |
-
#"DFKI-SLT/relation_classification_tacred_revisited"
|
26 |
|
27 |
ner_pipeline = AutoPipeline.from_pretrained(ner_model_name_or_path, device=-1, num_workers=0)
|
28 |
re_pipeline = AutoPipeline.from_pretrained(re_model_name_or_path, device=-1, num_workers=0)
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
def predict(text):
|
32 |
document = ExampleDocument(text)
|
@@ -34,10 +38,16 @@ def predict(text):
|
|
34 |
ner_pipeline(document)
|
35 |
|
36 |
while len(document.entities.predictions) > 0:
|
37 |
-
document.entities.
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
re_pipeline(document)
|
40 |
|
|
|
41 |
t = PrettyTable()
|
42 |
t.field_names = ["head", "tail", "relation"]
|
43 |
t.align = "l"
|
|
|
21 |
|
22 |
ner_model_name_or_path = "pie/example-ner-spanclf-conll03"
|
23 |
re_model_name_or_path = "DFKI-SLT/relation_classification_tacred_revisited"
|
|
|
|
|
24 |
|
25 |
ner_pipeline = AutoPipeline.from_pretrained(ner_model_name_or_path, device=-1, num_workers=0)
|
26 |
re_pipeline = AutoPipeline.from_pretrained(re_model_name_or_path, device=-1, num_workers=0)
|
27 |
|
28 |
+
ner_tag_mapping = {
|
29 |
+
'ORG': 'ORGANIZATION',
|
30 |
+
'PER': 'PERSON',
|
31 |
+
'LOC': 'LOCATION'
|
32 |
+
}
|
33 |
+
|
34 |
|
35 |
def predict(text):
|
36 |
document = ExampleDocument(text)
|
|
|
38 |
ner_pipeline(document)
|
39 |
|
40 |
while len(document.entities.predictions) > 0:
|
41 |
+
entity = document.entities.predictions.pop(0)
|
42 |
+
if entity.label in ner_tag_mapping:
|
43 |
+
entity = LabeledSpan(start=entity.start, end=entity.end, label=ner_tag_mapping[entity.label],
|
44 |
+
score=entity.score)
|
45 |
+
if entity.label in re_pipeline.taskmodule.entity_labels:
|
46 |
+
document.entities.append(entity)
|
47 |
|
48 |
re_pipeline(document)
|
49 |
|
50 |
+
|
51 |
t = PrettyTable()
|
52 |
t.field_names = ["head", "tail", "relation"]
|
53 |
t.align = "l"
|