File size: 3,824 Bytes
94be5fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""
| **Abbreviation** | **Description** |
|------------------|-----------------|
| O                | Outside of a named entity |
| B-MIS            | Beginning of a miscellaneous entity right after another miscellaneous entity |
| I-MIS            | Miscellaneous entity |
| B-PER            | Beginning of a person’s name right after another person’s name |
| I-PER            | Person’s name |
| B-ORG            | Beginning of an organization right after another organization |
| I-ORG            | Organization |
| B-LOC            | Beginning of a location right after another location |
| I-LOC            | Location |
"""

import re
from enum import Enum

from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline


class DictKey(Enum):
    """Keys of the per-entity result dictionaries produced by the NER pipeline."""
    ENTITY = 'entity'
    SCORE = 'score'
    INDEX = 'index'
    WORD = 'word'
    START = 'start'
    END = 'end'
    # Key under which grouped pipeline results carry the entity label
    # (the pipeline in NER.__init__ is created with grouped_entities=True).
    ENTITY_GROUP = 'entity_group'


class NER:
    """
    Runs named entity recognition over a text and exposes the results.

    Attributes set by the constructor:
        tokenizer / model / nlp: the loaded tokenizer, model and "ner" pipeline.
        text_to_analyse: the text that was analysed.
        results: the list of result dictionaries returned by the pipeline.
        all_entities: every entity word, in order of appearance (may repeat).
        unique_entities: entity words with duplicates removed, first-seen order.
        markdown / markdown_text: None until entity_markdown() is called.
    """

    # Name of the pretrained checkpoint used for both tokenizer and model.
    _MODEL_NAME = "dslim/bert-base-NER"

    def __init__(self, text_to_analyse):
        """
        The Constructor for the Named Entity Recognition class.
        :param text_to_analyse: The text in which to find named entities.
        :raises ValueError: If the pipeline could not be created.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self._MODEL_NAME)

        self.model = AutoModelForTokenClassification.from_pretrained(self._MODEL_NAME)

        self.nlp = pipeline("ner", model=self.model, tokenizer=self.tokenizer, grouped_entities=True)
        if self.nlp is None:
            raise ValueError("Unable to load pipeline from DSLIM BERT model")

        self.text_to_analyse = text_to_analyse
        self.results = self.nlp(text_to_analyse)
        self.all_entities = self.get_list_of_entities()
        # The helper is deliberately NOT named `unique_entities`: assigning the
        # attribute below would otherwise replace the method on the instance.
        self.unique_entities = self.get_unique_entities()
        self.markdown = None
        self.markdown_text = None

    def get_entity_value(self, key: DictKey, item_index):
        """
        Extracts the value for a specific key (as an Enum) from a specific dictionary item in the list.
        :param key: DictKey Enum representing the key for which the value is required.
        :param item_index: Index of the item in the list to process. Negative indices count from the end.
        :return: Value for the given key in the specified dictionary item, or None if key is not found.
        :raises ValueError: If item_index is outside the results list (in either direction).
        """
        try:
            item = self.results[item_index]
        except IndexError:
            # Report both too-large and too-negative indices the same way.
            raise ValueError("The supplied list index is out of bounds") from None
        return item.get(key.value)

    def get_list_of_entities(self):
        """
        Returns a list of all entities in the original text, in the order they appear. There may be repeated
        entities in this list.
        :return: A list of all entities in the original text.
        """
        word_key = DictKey.WORD.value
        return [item.get(word_key) for item in self.results]

    def entity_markdown(self):
        """
        Build a markdown/HTML version of the analysed text with every entity shown in red.

        The result is stored in ``self.markdown``; ``self.markdown_text`` holds the same text with
        markdown hard line breaks. Each occurrence of an entity is wrapped exactly once, even when
        the entity is found several times or is contained in a longer entity.
        """
        # Unique, non-empty entity strings. Longest first so that, in the
        # alternation below, "New York" is preferred over "York" at the same position.
        entities = sorted(
            dict.fromkeys(word for word in self.get_list_of_entities() if word),
            key=len,
            reverse=True,
        )

        if entities:
            # A single pass over the ORIGINAL text: already inserted markup is
            # never rescanned, so spans cannot end up nested.
            pattern = re.compile("|".join(re.escape(word) for word in entities))
            self.markdown = pattern.sub(
                lambda match: f'<span style = "color:red;">{match.group(0)}</span>',
                self.text_to_analyse,
            )
        else:
            self.markdown = self.text_to_analyse

        self.markdown_text = self.markdown.replace('\n', '  \n')  # Two spaces at the end of line for markdown new line

    def get_unique_entities(self):
        """
        Return a list of all unique entities in the original text.
        :return: A list of unique entities, in order of first appearance.
        """
        # dict keys are unique and keep insertion order.
        return list(dict.fromkeys(self.all_entities))