File size: 1,222 Bytes
4b75840
 
 
 
e99a699
4b75840
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline


class NamedEntityRecognition:
    """Run a token-classification ("ner") pipeline and build an annotated view of the text.

    ``classify`` returns the raw pipeline predictions together with the input
    text split into plain-string segments and ``(span_text, entity_group, "")``
    tuples for the recognised entities.
    """

    # Checkpoint loaded when the constructor is called without ``model_name``.
    DEFAULT_MODEL = "xlm-roberta-large-finetuned-conll03-english"

    # Predicted spans whose text is one of these are treated as bad predictions
    # and are left as plain text instead of being labelled.
    EXCLUDED_SPANS = frozenset({"", ".", ". ", " "})

    def __init__(self, model_name=DEFAULT_MODEL):
        """Load the tokenizer and model for ``model_name`` and build the pipeline.

        Args:
            model_name: Checkpoint identifier passed to ``from_pretrained`` for
                both the tokenizer and the model. Defaults to ``DEFAULT_MODEL``.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForTokenClassification.from_pretrained(model_name)
        # NOTE(review): recent Transformers releases expose entity grouping via
        # ``aggregation_strategy``; ``grouped_entities`` is kept here because the
        # installed version is not visible from this file -- confirm before changing.
        self.nlp = pipeline("ner", model=model, tokenizer=tokenizer, grouped_entities=True)

    def get_annotation(self, preds, text):
        """Split ``text`` into plain segments and labelled entity tuples.

        Each prediction is located by its own ``start``/``end`` offsets rather
        than by matching its ``word`` against the text, so repeated words keep
        their individual labels and a span is labelled even when the ``word``
        string differs from the raw text slice.

        Args:
            preds: Iterable of dicts with ``start``, ``end`` and
                ``entity_group`` keys (the items ``self.nlp`` is called to
                produce in ``classify``).
            text: The string the predictions refer to.

        Returns:
            A list whose items are either a plain ``str`` (unlabelled text) or
            a tuple ``(span_text, entity_group, "")``. Empty segments are
            omitted; concatenating the segment texts in order reproduces
            ``text`` when the predictions lie within it.
        """
        annotation = []
        cursor = 0  # index of the first character not yet emitted
        for pred in preds:
            start, end = pred.get("start"), pred.get("end")
            # Skip predictions without usable offsets, empty spans, and spans
            # that overlap text that has already been emitted.
            if start is None or end is None or end <= start or start < cursor:
                continue
            span = text[start:end]
            if span in self.EXCLUDED_SPANS:
                # Bad prediction: leave it to be emitted as plain text later.
                continue
            if start > cursor:
                annotation.append(text[cursor:start])
            annotation.append((span, pred["entity_group"], ""))
            cursor = end
        if cursor < len(text):
            annotation.append(text[cursor:])
        return annotation

    def classify(self, text):
        """Run the pipeline on ``text``.

        Returns:
            A tuple ``(preds, ner_annotation)``: the raw output of ``self.nlp``
            and the result of ``get_annotation`` for those predictions.
        """
        preds = self.nlp(text)
        ner_annotation = self.get_annotation(preds, text)
        return preds, ner_annotation