Knowledge-graphs / rebel.py
khaerens's picture
space
70303d6
raw
history blame
3.85 kB
from typing import List
from transformers import pipeline
from pyvis.network import Network
from functools import lru_cache
from app import generate_graph
import spacy
from spacy import displacy
# Fill colour for graph nodes, keyed by spaCy entity label.
# generate_knowledge_graph looks a node's entity label up here to colour it.
DEFAULT_LABEL_COLORS = {
    "ORG": "#7aecec",
    "PRODUCT": "#bfeeb7",
    "GPE": "#feca74",
    "LOC": "#ff9561",
    "PERSON": "#aa9cfc",
    "NORP": "#c887fb",
    "FACILITY": "#9cc9cc",
    # Current spaCy English pipelines label facilities "FAC" rather than
    # "FACILITY"; keep both spellings so either one resolves to a colour.
    "FAC": "#9cc9cc",
    "EVENT": "#ffeb80",
    "LAW": "#ff8197",
    "LANGUAGE": "#ff8197",
    "WORK_OF_ART": "#f0d0ff",
    "DATE": "#bfe1d9",
    "TIME": "#bfe1d9",
    "MONEY": "#e4e7d2",
    "QUANTITY": "#e4e7d2",
    "ORDINAL": "#e4e7d2",
    "CARDINAL": "#e4e7d2",
    "PERCENT": "#e4e7d2",
}
def generate_knowledge_graph(texts: List[str], filename: str):
    """Build a directed relation graph from *texts* and render it to *filename*.

    spaCy NER is run over the newline-joined texts so that graph nodes whose
    label exactly matches an entity's text can be coloured by entity type.
    Relation triplets come from ``generate_partial_graph``, called once per
    element of *texts*.

    Args:
        texts: Passages to mine for relations; each one is processed on its own.
        filename: Output path handed to ``Network.show``.

    Returns:
        The set of node labels (every triplet head and tail).
    """
    nlp = spacy.load("en_core_web_sm")
    doc = nlp("\n".join(texts))

    # Map entity text -> label. setdefault keeps the FIRST label seen for a
    # repeated entity text, which is what the earlier list.index() lookup did.
    entity_types = {}
    for ent in doc.ents:
        print(ent.text, ent.label_)
        entity_types.setdefault(ent.text, ent.label_)

    triplets = []
    for text in texts:
        triplets.extend(generate_partial_graph(text))
    print(generate_partial_graph.cache_info())

    nodes = {t["head"] for t in triplets} | {t["tail"] for t in triplets}

    net = Network(directed=True)
    for node in nodes:
        label = entity_types.get(node)
        if label is None:
            # Not recognised as a named entity: plain node.
            net.add_node(node, shape="circle")
            continue
        color = DEFAULT_LABEL_COLORS.get(label)
        if color is None:
            # Entity label without a configured colour: keep the tooltip but
            # let pyvis pick the colour instead of raising KeyError.
            net.add_node(node, title=label, shape="circle")
        else:
            net.add_node(node, title=label, shape="circle", color=color)

    # De-duplicate on a tuple; plain string concatenation would treat
    # ("ab", "c", ...) and ("a", "bc", ...) as the same edge.
    seen_edges = set()
    for t in triplets:
        edge_key = (t["tail"], t["head"], t["type"])
        if edge_key in seen_edges:
            continue
        seen_edges.add(edge_key)
        # NOTE(review): edges are drawn tail -> head, i.e. from the triplet's
        # object to its subject. Kept as-is; confirm this direction is intended.
        net.add_edge(t["tail"], t["head"], title=t["type"], label=t["type"])

    net.repulsion(
        node_distance=200,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
        damping=0.09,
    )
    net.set_edge_smooth('dynamic')
    net.show(filename)
    return nodes
@lru_cache(maxsize=1)
def _load_triplet_extractor():
    """Create the REBEL text2text pipeline once and reuse it for every text."""
    return pipeline(
        'text2text-generation',
        model='Babelscape/rebel-large',
        tokenizer='Babelscape/rebel-large',
    )


@lru_cache
def generate_partial_graph(text):
    """Extract relation triplets from one text with the REBEL model.

    Results are memoised per input string, so repeated texts skip inference.
    The pipeline itself is built by ``_load_triplet_extractor`` and shared
    across calls instead of being re-created on every cache miss.

    Args:
        text: The passage to run relation extraction on (must be hashable,
            as required by ``lru_cache``).

    Returns:
        The list produced by ``extract_triplets`` for the decoded model
        output. The same list object is returned on cache hits, so callers
        should treat it as read-only.
    """
    triplet_extractor = _load_triplet_extractor()
    # NOTE(review): the ["generated_token_ids"]["output_ids"] indexing looks
    # tied to the installed transformers version - confirm before upgrading.
    token_ids = triplet_extractor(text, return_tensors=True, return_text=False)[0]["generated_token_ids"]["output_ids"]
    decoded = triplet_extractor.tokenizer.batch_decode(token_ids)
    # Only the first decoded sequence is parsed.
    return extract_triplets(decoded[0])
def extract_triplets(text):
    """
    Parse REBEL-style generated text into a list of relation triplets.

    The text is scanned token by token. ``<triplet>`` starts collecting the
    head, ``<subj>`` starts collecting the tail and ``<obj>`` starts
    collecting the relation type. A triplet is emitted whenever a new
    ``<triplet>`` or ``<subj>`` marker arrives while a relation is pending,
    and once more at the end if head, type and tail are all non-empty.

    Returns a list of dicts with the keys 'head', 'type' and 'tail'.
    """
    words = {"head": [], "type": [], "tail": []}
    found = []
    active = None  # slot currently receiving tokens; None until a marker is seen

    def flush():
        # Snapshot the three slots as one triplet.
        found.append({
            'head': " ".join(words["head"]),
            'type': " ".join(words["type"]),
            'tail': " ".join(words["tail"]),
        })

    cleaned = text.strip()
    # Drop the tokenizer's special tokens, in the same order as before.
    for special in ("<s>", "<pad>", "</s>"):
        cleaned = cleaned.replace(special, "")

    for piece in cleaned.split():
        if piece == "<triplet>":
            active = "head"
            if words["type"]:
                flush()
                words["type"] = []
            words["head"] = []
        elif piece == "<subj>":
            active = "tail"
            # The relation is NOT cleared here; the following <obj> resets it.
            if words["type"]:
                flush()
            words["tail"] = []
        elif piece == "<obj>":
            active = "type"
            words["type"] = []
        elif active is not None:
            words[active].append(piece)

    if words["head"] and words["type"] and words["tail"]:
        flush()
    return found