Xrenya commited on
Commit
aa7dcb7
·
1 Parent(s): 3c8fa95

Upload 8 files

Browse files
app.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import re
3
+ import string
4
+
5
+ import docx2txt
6
+ import fitz
7
+ import gradio as gr
8
+ import joblib
9
+ import matplotlib.pyplot as plt
10
+ import nltk
11
+ import seaborn as sns
12
+ import shap
13
+ import textract
14
+ import torch
15
+ from lime.lime_text import LimeTextExplainer
16
+ from striprtf.striprtf import rtf_to_text
17
+ from transformers import BertForSequenceClassification, BertTokenizer, pipeline
18
+
19
+ from preprocessing import TextCleaner
20
+
21
+ cleaner = TextCleaner()
22
+ pipe = joblib.load('pipe_v1_natasha.joblib')
23
+
24
+ model_path = "finetunebert"
25
+ tokenizer = BertTokenizer.from_pretrained(model_path,
26
+ padding='max_length',
27
+ truncation=True)
28
+ # tokenizer.init_kwargs["model_max_length"] = 512
29
+ model = BertForSequenceClassification.from_pretrained(model_path)
30
+ document_classifier = pipeline("text-classification",
31
+ model=model,
32
+ tokenizer=tokenizer,
33
+ return_all_scores=True)
34
+
35
+ classes = [
36
+ "Договоры поставки", "Договоры оказания услуг", "Договоры подряда",
37
+ "Договоры аренды", "Договоры купли-продажи"
38
+ ]
39
+
40
+
41
+ def old__pipeline(text):
42
+ clean_text = text_preprocessing(text)
43
+ tokens = tokenizer.batch_encode_plus([clean_text],
44
+ max_length=512,
45
+ padding=True,
46
+ truncation=True)
47
+ item = {k: torch.tensor(v) for k, v in tokens.items()}
48
+ preds = model(**item).logits.detach()
49
+ preds = torch.softmax(preds, dim=1)[0]
50
+ output = [{
51
+ 'label': cls,
52
+ 'score': score
53
+ } for cls, score in zip(classes, preds)]
54
+
55
+ return output
56
+
57
+
58
+ def read_doc(file_obj):
59
+ """Read file
60
+ :param file_obj: file object
61
+ :return: string
62
+ """
63
+ text = read_file(file_obj)
64
+ return text
65
+
66
+
67
+ def read_docv2(file_obj):
68
+ """Read file and collect neighbour for visual output
69
+ :param file_obj: file object
70
+ :return: string
71
+ """
72
+ text = read_file(file_obj)
73
+ explainer = LimeTextExplainer(class_names=classes)
74
+ text = cleaner.execute(text)
75
+ exp = explainer.explain_instance(text,
76
+ pipe.predict_proba,
77
+ num_features=10,
78
+ labels=[0, 1, 2, 3, 4])
79
+ scores = exp.as_list()
80
+ scores_desc = sorted(scores, key=lambda t: t[1])[::-1]
81
+ selected_words = [word[0] for word in scores_desc]
82
+ sent = text.split()
83
+ indices = [i for i, word in enumerate(sent) if word in selected_words]
84
+ neighbors = []
85
+ for ind in indices:
86
+ neighbors.append(" ".join(sent[max(0, ind - 3):min(ind +
87
+ 3, len(sent))]))
88
+ return "\n\n".join(neighbors)
89
+
90
+
91
+ def classifier(file_obj):
92
+ """Classify
93
+ :param file_obj: file object
94
+ :return: Dict[str, int]
95
+ """
96
+ text = read_file(file_obj)
97
+ clean_text = text_preprocessing(text)
98
+ tokens = tokenizer.batch_encode_plus([clean_text],
99
+ max_length=512,
100
+ padding=True,
101
+ truncation=True)
102
+ item = {k: torch.tensor(v) for k, v in tokens.items()}
103
+ preds = model(**item).logits.detach()
104
+ preds = torch.softmax(preds, dim=1)[0]
105
+ return {cls: p.item() for cls, p in zip(classes, preds)}
106
+
107
+
108
+ def clean_text(text):
109
+ """Make text lowercase, remove text in square brackets,remove links,remove punctuation
110
+ and remove words containing numbers."""
111
+ text = text.lower()
112
+ text = re.sub('\[.*?\]', '', text)
113
+ text = re.sub('https?://\S+|www\.\S+', '', text)
114
+ text = re.sub('<.*?>+', '', text)
115
+ text = re.sub('[%s]' % re.escape(string.punctuation), '', text)
116
+ text = re.sub('\n', '', text)
117
+ text = re.sub('\w*\d\w*', '', text)
118
+ return text
119
+
120
+
121
+ def text_preprocessing(text):
122
+ """Cleaning and parsing the text."""
123
+ tokenizer = nltk.tokenize.RegexpTokenizer(r'\w+')
124
+ nopunc = clean_text(text)
125
+ tokenized_text = tokenizer.tokenize(nopunc)
126
+ #remove_stopwords = [w for w in tokenized_text if w not in stopwords.words('english')]
127
+ combined_text = ' '.join(tokenized_text)
128
+ return combined_text
129
+
130
+
131
+ def read_file(file_obj):
132
+ """Read file and fixing encoding
133
+ :param file_obj: file object
134
+ :return: string
135
+ """
136
+ if isinstance(file_obj, list):
137
+ file_obj = file_obj[0]
138
+ filename = file_obj.name
139
+ if filename.endswith("docx"):
140
+ text = docx2txt.process(filename)
141
+ elif filename.endswith("pdf"):
142
+ doc = fitz.open(filename)
143
+ text = []
144
+ for page in doc:
145
+ text.append(page.get_text())
146
+ text = " ".join(text)
147
+ elif filename.endswith("doc"):
148
+ text = reinterpret(textract.process(filename))
149
+ text = remove_convert_info(text)
150
+ elif filename.endswith("rtf"):
151
+ with open(filename) as f:
152
+ content = f.read()
153
+ text = rtf_to_text(content)
154
+ else:
155
+ return {"text": []}
156
+ return text
157
+
158
+
159
+ def reinterpret(text: str):
160
+ return text.decode('utf8')
161
+
162
+
163
+ def remove_convert_info(text: str):
164
+ for i, s in enumerate(text):
165
+ if s == ":":
166
+ break
167
+ return text[i + 6:]
168
+
169
+
170
+ def plot_weights(file_obj):
171
+ text = read_file(file_obj)
172
+ explainer = LimeTextExplainer(class_names=classes)
173
+ text = cleaner.execute(text)
174
+ exp = explainer.explain_instance(text,
175
+ pipe.predict_proba,
176
+ num_features=10,
177
+ labels=[0, 1, 2, 3, 4])
178
+ scores = exp.as_list()
179
+ scores_desc = sorted(scores, key=lambda t: t[1])[::-1]
180
+ plt.rcParams.update({'font.size': 35})
181
+ fig = plt.figure(figsize=(20, 20))
182
+ sns.barplot(x=[s[0] for s in scores_desc[:10]],
183
+ y=[s[1] for s in scores_desc[:10]])
184
+ plt.title("Top words contributing to positive sentiment")
185
+ plt.ylabel("Weight")
186
+ plt.xlabel("Word")
187
+ plt.title("Interpreting text predictions with LIME")
188
+ plt.xticks(rotation=20)
189
+ plt.tight_layout()
190
+ return fig
191
+
192
+
193
+ def interpretation_function(file_obj):
194
+ text = read_file(file_obj)
195
+ clean_text = text_preprocessing(text)
196
+ explainer = shap.Explainer(document_classifier)
197
+ shap_values = explainer([clean_text])
198
+
199
+ # Dimensions are (batch size, text size, number of classes)
200
+ # Since we care about positive sentiment, use index 1
201
+ scores = list(zip(shap_values.data[0], shap_values.values[0, :, 1]))
202
+ # Scores contains (word, score) pairs
203
+ # Format expected by gr.components.Interpretation
204
+ return {"original": clean_text, "interpretation": scores}
205
+
206
+
207
+ def as_pyplot_figure(file_obj):
208
+ text = read_file(file_obj)
209
+ explainer = LimeTextExplainer(class_names=classes)
210
+ text = cleaner.execute(text)
211
+ exp = explainer.explain_instance(text,
212
+ pipe.predict_proba,
213
+ num_features=10,
214
+ labels=[0, 1, 2, 3, 4])
215
+ buf = io.BytesIO()
216
+ fig = exp.as_pyplot_figure()
217
+ fig.tight_layout()
218
+ plt.rcParams.update({'font.size': 10})
219
+ plt.savefig(buf)
220
+ buf.seek(0)
221
+ return fig
222
+
223
+
224
+ with gr.Blocks() as demo:
225
+ gr.Markdown("""**Document classification**""")
226
+ with gr.Row():
227
+ with gr.Column():
228
+ file = gr.File(label="Input File")
229
+ with gr.Row():
230
+ classify = gr.Button("Classify document")
231
+ read = gr.Button("Get text")
232
+ interpret_lime = gr.Button("Interpret LIME")
233
+ interpret_shap = gr.Button("Interpret SHAP")
234
+ with gr.Column():
235
+ label = gr.Label(label="Predicted Document Class")
236
+ plot = gr.Plot()
237
+ with gr.Column():
238
+ text = gr.Text(label="Selected keywords")
239
+ with gr.Column():
240
+ interpretation = gr.components.Interpretation(text)
241
+ classify.click(classifier, file, label)
242
+ read.click(read_docv2, file, [text])
243
+ interpret_shap.click(interpretation_function, file, interpretation)
244
+ interpret_lime.click(as_pyplot_figure, file, plot)
245
+
246
+ if __name__ == "__main__":
247
+ demo.launch(share=True)
finetunebert/config.json ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "DeepPavlov/rubert-base-cased-sentence",
3
+ "architectures": [
4
+ "BertForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "classifier_dropout": null,
8
+ "directionality": "bidi",
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "id2label": {
13
+ "0": "Договоры_поставки",
14
+ "1": "Договоры_оказания_услуг",
15
+ "2": "Договоры_подряда",
16
+ "3": "Договоры_аренды",
17
+ "4": "Договоры_купли_продажи"
18
+ },
19
+ "initializer_range": 0.02,
20
+ "intermediate_size": 3072,
21
+ "label2id": {
22
+ "Договоры_поставки": 0,
23
+ "Договоры_оказания_услуг": 1,
24
+ "Договоры_подряда": 2,
25
+ "Договоры_аренды": 3,
26
+ "Договоры_купли_продажи": 4
27
+ },
28
+ "layer_norm_eps": 1e-12,
29
+ "max_position_embeddings": 512,
30
+ "model_type": "bert",
31
+ "num_attention_heads": 12,
32
+ "num_hidden_layers": 12,
33
+ "output_past": true,
34
+ "pad_token_id": 0,
35
+ "pooler_fc_size": 768,
36
+ "pooler_num_attention_heads": 12,
37
+ "pooler_num_fc_layers": 3,
38
+ "pooler_size_per_head": 128,
39
+ "pooler_type": "first_token_transform",
40
+ "position_embedding_type": "absolute",
41
+ "problem_type": "single_label_classification",
42
+ "torch_dtype": "float32",
43
+ "transformers_version": "4.25.1",
44
+ "type_vocab_size": 2,
45
+ "use_cache": true,
46
+ "vocab_size": 119547
47
+ }
finetunebert/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62cab11907d1a4647b4060ea92dc23a4637efb0de1ee5f787e0f87263a6d0e25
3
+ size 711501941
finetunebert/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
finetunebert/tokenizer_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "do_basic_tokenize": true,
4
+ "do_lower_case": false,
5
+ "mask_token": "[MASK]",
6
+ "use_fast": true,
7
+ "model_max_length": 512,
8
+ "padding_side": "right",
9
+ "truncation_side": "left",
10
+ "name_or_path": "DeepPavlov/rubert-base-cased-sentence",
11
+ "never_split": null,
12
+ "pad_token": "[PAD]",
13
+ "sep_token": "[SEP]",
14
+ "special_tokens_map_file": "/home/xrenya/.cache/huggingface/hub/models--DeepPavlov--rubert-base-cased-sentence/snapshots/78b5122d6365337dd4114281b0d08cd1edbb3bc8/special_tokens_map.json",
15
+ "strip_accents": null,
16
+ "tokenize_chinese_chars": true,
17
+ "tokenizer_class": "BertTokenizer",
18
+ "unk_token": "[UNK]"
19
+ }
finetunebert/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
pipe_v1_natasha.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5bd8acea272a9df1c44a2a0d8e9f50d315691b6bf11a7c14e83fb5d35f1d94ba
3
+ size 265645
preprocessing.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import nltk
4
+ from natasha import (Doc, MorphVocab, NamesExtractor, NewsEmbedding,
5
+ NewsMorphTagger, NewsNERTagger, NewsSyntaxParser,
6
+ Segmenter)
7
+ from nltk.corpus import stopwords
8
+
9
+ nltk.download('stopwords')
10
+
11
+
12
+ class TextCleaner:
13
+
14
+ def __init__(self, lemma: bool = True):
15
+ self.lemma = lemma
16
+ self.segmenter = Segmenter()
17
+ self.morph_vocab = MorphVocab()
18
+ emb = NewsEmbedding()
19
+ self.morph_tagger = NewsMorphTagger(emb)
20
+ syntax_parser = NewsSyntaxParser(emb)
21
+ ner_tagger = NewsNERTagger(emb)
22
+ names_extractor = NamesExtractor(self.morph_vocab)
23
+ self.en_stops = stopwords.words('english')
24
+ self.ru_stops = stopwords.words('russian')
25
+ self.punc = '''!()-[]{};:'"\,<>./?@#$%^&*_~'''
26
+ self.words_pattern = '[а-я]+'
27
+
28
+ def execute(self, text):
29
+ text = self.text_preprocessing(text)
30
+ if self.lemma:
31
+ text = self.lemmatize(text)
32
+ return text
33
+
34
+ def text_preprocessing(self, data):
35
+ data = " ".join(x.lower() for x in data.split())
36
+ data = data.replace('[^\w\s]', '')
37
+ data = " ".join(x for x in data.split()
38
+ if x not in self.ru_stops and x not in self.en_stops)
39
+ for punc in self.punc:
40
+ if punc in data:
41
+ data = data.replace(punc, "")
42
+ data = re.sub(' +', ' ', data)
43
+ return " ".join(
44
+ re.findall(self.words_pattern, data, flags=re.IGNORECASE))
45
+
46
+ def lemmatize(self, text):
47
+ doc = Doc(text)
48
+ doc.segment(self.segmenter)
49
+ doc.tag_morph(self.morph_tagger)
50
+ for token in doc.tokens:
51
+ token.lemmatize(self.morph_vocab)
52
+ tokens = []
53
+ for token in doc.tokens:
54
+ tokens.append(token.lemma)
55
+ return " ".join(tokens)