File size: 3,846 Bytes
70303d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import logging
from functools import lru_cache
from typing import List

import spacy
from pyvis.network import Network
from spacy import displacy
from transformers import pipeline

from app import generate_graph


# Node colours keyed by spaCy named-entity label, used when drawing the graph.
# spaCy v2+ English models emit "FAC" for facilities; the legacy "FACILITY"
# key is kept so any code still looking it up continues to work.
DEFAULT_LABEL_COLORS = {
    "ORG": "#7aecec",
    "PRODUCT": "#bfeeb7",
    "GPE": "#feca74",
    "LOC": "#ff9561",
    "PERSON": "#aa9cfc",
    "NORP": "#c887fb",
    "FAC": "#9cc9cc",
    "FACILITY": "#9cc9cc",
    "EVENT": "#ffeb80",
    "LAW": "#ff8197",
    "LANGUAGE": "#ff8197",
    "WORK_OF_ART": "#f0d0ff",
    "DATE": "#bfe1d9",
    "TIME": "#bfe1d9",
    "MONEY": "#e4e7d2",
    "QUANTITY": "#e4e7d2",
    "ORDINAL": "#e4e7d2",
    "CARDINAL": "#e4e7d2",
    "PERCENT": "#e4e7d2",
}

@lru_cache(maxsize=None)
def _load_spacy_model(name: str = "en_core_web_sm"):
    """Load the spaCy pipeline ``name`` once and reuse it on later calls."""
    return spacy.load(name)


def generate_knowledge_graph(texts: List[str], filename: str):
    """
    Build a directed pyvis knowledge graph from ``texts`` and render it to ``filename``.

    Named entities are found by running spaCy over all chunks joined with
    newlines. Relation triplets are extracted per chunk by
    ``generate_partial_graph``. Every triplet head and tail becomes a node.
    Nodes whose text matches an entity get that entity's label as a tooltip
    and, when the label has a palette entry, its colour. Each distinct
    (tail, head, type) triplet becomes one labelled edge.

    Args:
        texts: Text chunks. Each chunk goes separately to the triplet
            extractor; the joined text goes to spaCy NER.
        filename: Output path passed to ``Network.show``.

    Returns:
        The set of node labels, i.e. every head and tail seen in the triplets.
    """
    logger = logging.getLogger(__name__)
    nlp = _load_spacy_model()
    doc = nlp("\n".join(texts))

    # Entity text -> label of its FIRST occurrence, which is what the previous
    # list.index() lookup returned, but with O(1) membership tests.
    entity_labels = {}
    for ent in doc.ents:
        logger.debug("%s %s", ent.text, ent.label_)
        entity_labels.setdefault(ent.text, ent.label_)

    triplets = []
    for text in texts:
        triplets.extend(generate_partial_graph(text))
    logger.debug("%s", generate_partial_graph.cache_info())

    nodes = {t["head"] for t in triplets} | {t["tail"] for t in triplets}
    net = Network(directed=True)

    # Sorting makes node insertion order (and so the rendered output) reproducible.
    for node in sorted(nodes):
        label = entity_labels.get(node)
        if label is None:
            net.add_node(node, shape="circle")
            continue
        node_options = {"title": label, "shape": "circle"}
        # Labels without a palette entry keep pyvis's default colour instead
        # of raising KeyError.
        color = DEFAULT_LABEL_COLORS.get(label)
        if color is not None:
            node_options["color"] = color
        net.add_node(node, **node_options)

    # Tuple keys: plain string concatenation let distinct triplets collide
    # (e.g. "ab" + "c" == "a" + "bc") and silently dropped edges.
    seen_edges = set()
    for triplet in triplets:
        key = (triplet["tail"], triplet["head"], triplet["type"])
        if key in seen_edges:
            continue
        seen_edges.add(key)
        # NOTE(review): edges are drawn tail -> head. extract_triplets fills
        # 'head' from the <triplet> span and 'tail' from the <subj> span;
        # confirm this direction is intended.
        net.add_edge(triplet["tail"], triplet["head"], title=triplet["type"], label=triplet["type"])

    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09
    )
    net.set_edge_smooth('dynamic')
    net.show(filename)
    return nodes


@lru_cache(maxsize=1)
def _get_triplet_extractor():
    """Create the REBEL ``text2text-generation`` pipeline once and reuse it."""
    return pipeline(
        'text2text-generation',
        model='Babelscape/rebel-large',
        tokenizer='Babelscape/rebel-large',
    )


@lru_cache
def generate_partial_graph(text):
    """
    Extract relation triplets from one ``text`` using the REBEL model.

    Results are memoised per ``text`` by ``lru_cache`` (default maxsize 128),
    so ``generate_partial_graph.cache_info()`` remains available. The returned
    list is the cached object itself, so callers must not mutate it.

    Returns:
        A list of ``{'head', 'type', 'tail'}`` dicts produced by
        ``extract_triplets`` from the first decoded generation.
    """
    # The pipeline is built once instead of on every cache miss; building it
    # loads the full model, which dominated the cost of each new text.
    triplet_extractor = _get_triplet_extractor()
    generation = triplet_extractor(text, return_tensors=True, return_text=False)
    token_ids = generation[0]["generated_token_ids"]["output_ids"]
    # Decode keeping special tokens: extract_triplets relies on the
    # <triplet>/<subj>/<obj> markers being present.
    decoded = triplet_extractor.tokenizer.batch_decode(token_ids)
    return extract_triplets(decoded[0])


def extract_triplets(text):
    """
    Parse the linearised REBEL output in ``text`` into triplet dicts.

    The special tokens ``<s>``, ``<pad>`` and ``</s>`` are removed first. The
    remaining whitespace-separated tokens drive a small state machine:
    - words after ``<triplet>`` form the head;
    - words after ``<subj>`` form the tail;
    - words after ``<obj>`` form the relation type.

    Words before the first marker are ignored. A pending triplet is emitted
    whenever a new ``<triplet>`` or ``<subj>`` arrives while a relation has
    been collected. Several ``<subj>``/``<obj>`` pairs after one
    ``<triplet>`` therefore share the same head.

    Returns:
        A list of ``{'head': ..., 'type': ..., 'tail': ...}`` dicts in
        completion order. The trailing triplet is kept only when its head,
        type and tail are all non-empty.
    """
    cleaned = text.strip()
    for marker in ("<s>", "<pad>", "</s>"):
        cleaned = cleaned.replace(marker, "")

    found = []
    parts = {"head": [], "type": [], "tail": []}
    target = None  # part that plain words are appended to; None before any marker

    def emit():
        found.append({key: " ".join(parts[key]) for key in ("head", "type", "tail")})

    for token in cleaned.split():
        if token == "<triplet>":
            target = "head"
            if parts["type"]:
                emit()
                parts["type"] = []
            parts["head"] = []
        elif token == "<subj>":
            target = "tail"
            if parts["type"]:
                emit()
            parts["tail"] = []
        elif token == "<obj>":
            target = "type"
            parts["type"] = []
        elif target is not None:
            parts[target].append(token)

    if all(parts[key] for key in ("head", "type", "tail")):
        emit()
    return found