Add wrapper for punctuation
Browse files- sbert-punc-case-ru/__init__.py +1 -0
- sbert-punc-case-ru/sbertpunccase.py +192 -0
- setup.py +19 -0
sbert-punc-case-ru/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .sbertpunccase import SbertPuncCase
|
sbert-punc-case-ru/sbertpunccase.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from transformers import AutoTokenizer, AutoModelForTokenClassification
|
9 |
+
import re
|
10 |
+
import string
|
11 |
+
from typing import List, Optional
|
12 |
+
|
13 |
+
|
14 |
+
TOKEN_RE = re.compile(r'-?\d*\.\d+|[a-zа-яё]+|-?[\d\+\(\)\-]+|\S', re.I)
|
15 |
+
"""
|
16 |
+
Регулярка, для того чтобы выделять в отдельные токены знаки препинания, числа и слова. А именно:
|
17 |
+
- Числа с плавающей точкой вида 123.23 выделяются в один токен. Десятичным разделителем рассматривается только точка
|
18 |
+
- Число может быть отрицательным: иметь знак -123.4
|
19 |
+
- Целой части числа может вовсе не быть: последовательности -0.15 и −.15 означают одно и то же число.
|
20 |
+
- При этом числа с нулевой дробной частью не допускаются: строка "12345." будет разделена на два токена "12345" и "."
|
21 |
+
- Идущие подряд знаки препинания выделяются каждый в отдельный токен.
|
22 |
+
- Телефонные номера выделяются в один токен +7(999)164-20-69
|
23 |
+
- Множество букв в словах ограничивается только кириллическим и англ алфавитом (33 буквы и 26 cоотв).
|
24 |
+
"""
|
25 |
+
|
26 |
+
# Прогнозируемые знаки препинания
|
27 |
+
PUNK_MAPPING = {'.': 'PERIOD', ',': 'COMMA', '?': 'QUESTION'}
|
28 |
+
|
29 |
+
# Прогнозируемый регистр LOWER - нижний регистр, UPPER - верхний регистр для первого символа, UPPER_TOTAL - верхний регистр для всех символов
|
30 |
+
LABELS_CASE = ['LOWER', 'UPPER', 'UPPER_TOTAL']
|
31 |
+
# Добавим в пунктуацию метку O означающий отсутсвие пунктуации
|
32 |
+
LABELS_PUNC = ['O'] + list(PUNK_MAPPING.values())
|
33 |
+
|
34 |
+
# Сформируем метки на основе комбинаций регистра и пунктуации
|
35 |
+
LABELS_list = []
|
36 |
+
for case in LABELS_CASE:
|
37 |
+
for punc in LABELS_PUNC:
|
38 |
+
LABELS_list.append(f'{case}_{punc}')
|
39 |
+
LABELS = {label: i+1 for i, label in enumerate(LABELS_list)}
|
40 |
+
LABELS['O'] = -100
|
41 |
+
INVERSE_LABELS = {i: label for label, i in LABELS.items()}
|
42 |
+
|
43 |
+
LABEL_TO_PUNC_LABEL = {label: label.split('_')[-1] for label in LABELS.keys() if label != 'O'}
|
44 |
+
LABEL_TO_CASE_LABEL = {label: '_'.join(label.split('_')[:-1]) for label in LABELS.keys() if label != 'O'}
|
45 |
+
|
46 |
+
|
47 |
+
def token_to_label(token, label):
|
48 |
+
if type(label) == int:
|
49 |
+
label = INVERSE_LABELS[label]
|
50 |
+
if label == 'LOWER_O':
|
51 |
+
return token
|
52 |
+
if label == 'LOWER_PERIOD':
|
53 |
+
return token + '.'
|
54 |
+
if label == 'LOWER_COMMA':
|
55 |
+
return token + ','
|
56 |
+
if label == 'LOWER_QUESTION':
|
57 |
+
return token + '?'
|
58 |
+
if label == 'UPPER_O':
|
59 |
+
return token.capitalize()
|
60 |
+
if label == 'UPPER_PERIOD':
|
61 |
+
return token.capitalize() + '.'
|
62 |
+
if label == 'UPPER_COMMA':
|
63 |
+
return token.capitalize() + ','
|
64 |
+
if label == 'UPPER_QUESTION':
|
65 |
+
return token.capitalize() + '?'
|
66 |
+
if label == 'UPPER_TOTAL_O':
|
67 |
+
return token.upper()
|
68 |
+
if label == 'UPPER_TOTAL_PERIOD':
|
69 |
+
return token.upper() + '.'
|
70 |
+
if label == 'UPPER_TOTAL_COMMA':
|
71 |
+
return token.upper() + ','
|
72 |
+
if label == 'UPPER_TOTAL_QUESTION':
|
73 |
+
return token.upper() + '?'
|
74 |
+
if label == 'O':
|
75 |
+
return token
|
76 |
+
|
77 |
+
|
78 |
+
def decode_label(label, classes='all'):
|
79 |
+
if classes == 'punc':
|
80 |
+
return LABEL_TO_PUNC_LABEL[INVERSE_LABELS[label]]
|
81 |
+
if classes == 'case':
|
82 |
+
return LABEL_TO_CASE_LABEL[INVERSE_LABELS[label]]
|
83 |
+
else:
|
84 |
+
return INVERSE_LABELS[label]
|
85 |
+
|
86 |
+
|
87 |
+
def make_labeling(text: str):
|
88 |
+
# Разобъем предложение на слова и знаки препинания
|
89 |
+
tokens = TOKEN_RE.findall(text)
|
90 |
+
# Предобработаем слова, удалим знаки препинания и зададим метки
|
91 |
+
|
92 |
+
preprocessed_tokens = []
|
93 |
+
token_labels: List[List[str]] = []
|
94 |
+
|
95 |
+
# Убираем всю пунктуацию в начале предложения
|
96 |
+
while tokens[0] in string.punctuation:
|
97 |
+
tokens.pop(0)
|
98 |
+
|
99 |
+
for token in tokens:
|
100 |
+
if token in string.punctuation:
|
101 |
+
# Если встретился знак препинания который мы прогнозируем изменим метку предыдущего слова, иначе проигнорируем его
|
102 |
+
if token in PUNK_MAPPING:
|
103 |
+
token_labels[-1][1] = PUNK_MAPPING[token]
|
104 |
+
else:
|
105 |
+
# Если встретилось слово, то укажем метку регистра и добавим в список предобработанных слов в ни��нем регистре
|
106 |
+
if sum(char.isupper() for char in token) > 1:
|
107 |
+
token_labels.append(['UPPER_TOTAL', 'O'])
|
108 |
+
elif token[0].isupper():
|
109 |
+
token_labels.append(['UPPER', 'O'])
|
110 |
+
else:
|
111 |
+
token_labels.append(['LOWER', 'O'])
|
112 |
+
preprocessed_tokens.append(token.lower())
|
113 |
+
token_labels_merged = ['_'.join(label) for label in token_labels]
|
114 |
+
token_labels_ids = [LABELS[label] for label in token_labels_merged]
|
115 |
+
return dict(words=preprocessed_tokens, labels=token_labels_merged, label_ids=token_labels_ids)
|
116 |
+
|
117 |
+
|
118 |
+
def align_labels(label_ids: list[int], word_ids: list[Optional[int]]):
|
119 |
+
aligned_label_ids = []
|
120 |
+
previous_id = None
|
121 |
+
for word_id in word_ids:
|
122 |
+
if word_id is None or word_id == previous_id:
|
123 |
+
aligned_label_ids.append(LABELS['O'])
|
124 |
+
else:
|
125 |
+
aligned_label_ids.append(label_ids.pop(0))
|
126 |
+
previous_id = word_id
|
127 |
+
return aligned_label_ids
|
128 |
+
|
129 |
+
|
130 |
+
MODEL_REPO = "kontur-ai/sbert-punc-case-ru"
|
131 |
+
|
132 |
+
|
133 |
+
class SbertPuncCase(nn.Module):
|
134 |
+
def __init__(self):
|
135 |
+
super().__init__()
|
136 |
+
|
137 |
+
self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO,
|
138 |
+
revision="sbert",
|
139 |
+
use_auth_token=True,
|
140 |
+
strip_accents=False)
|
141 |
+
self.model = AutoModelForTokenClassification.from_pretrained(MODEL_REPO,
|
142 |
+
revision="sbert",
|
143 |
+
use_auth_token=True
|
144 |
+
)
|
145 |
+
self.model.eval()
|
146 |
+
|
147 |
+
def forward(self, input_ids, attention_mask):
|
148 |
+
return self.model(input_ids=input_ids,
|
149 |
+
attention_mask=attention_mask)
|
150 |
+
|
151 |
+
def punctuate(self, text):
|
152 |
+
text = text.strip().lower()
|
153 |
+
|
154 |
+
# preprocess
|
155 |
+
words_with_labels = make_labeling(text)
|
156 |
+
words = words_with_labels['words']
|
157 |
+
label_ids = words_with_labels['label_ids']
|
158 |
+
|
159 |
+
tokenizer_output = self.tokenizer(words, is_split_into_words=True)
|
160 |
+
aligned_label_ids = [align_labels(label_ids, tokenizer_output.word_ids())]
|
161 |
+
|
162 |
+
result = dict(tokenizer_output)
|
163 |
+
result.update({'labels': aligned_label_ids})
|
164 |
+
|
165 |
+
if len(result['input_ids']) > 512:
|
166 |
+
return ' '.join([self.punctuate(' '.join(text_part)) for text_part in np.array_split(words, 2)])
|
167 |
+
|
168 |
+
predictions = self(torch.tensor([result['input_ids']], device=self.model.device),
|
169 |
+
torch.tensor([result['attention_mask']], device=self.model.device)).logits.cpu().data.numpy()
|
170 |
+
predictions = np.argmax(predictions, axis=2)
|
171 |
+
|
172 |
+
# decode punctuation and casing
|
173 |
+
splitted_text = []
|
174 |
+
word_ids = tokenizer_output.word_ids()
|
175 |
+
for i, word in enumerate(words):
|
176 |
+
label_pos = word_ids.index(i)
|
177 |
+
label_id = predictions[0][label_pos]
|
178 |
+
label = decode_label(label_id)
|
179 |
+
splitted_text.append(token_to_label(word, label))
|
180 |
+
capitalized_text = ' '.join(splitted_text)
|
181 |
+
return capitalized_text
|
182 |
+
|
183 |
+
|
184 |
+
if __name__ == '__main__':
|
185 |
+
parser = argparse.ArgumentParser("Punctuation and case restoration model sbert-punc-case-ru")
|
186 |
+
parser.add_argument("-i", "--input", type=str, help="text to restore", default='SbertPuncCase расставляет точки запятые и знаки вопроса вам нравится')
|
187 |
+
parser.add_argument("-d", "--device", type=str, help="run model on cpu or gpu", choices=['cpu', 'cuda'], default='cpu')
|
188 |
+
args = parser.parse_args()
|
189 |
+
print(f"Source text: {args.input}\n")
|
190 |
+
sbertpunc = SbertPuncCase().to(args.device)
|
191 |
+
punctuated_text = sbertpunc.punctuate(args.input)
|
192 |
+
print(f"Restored text: {punctuated_text}")
|
setup.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from distutils.core import setup
|
2 |
+
|
3 |
+
setup(name='sbert-punc-case-ru',
|
4 |
+
version='0.1',
|
5 |
+
description='Punctuation and Case Restoration model based on https://huggingface.co/sberbank-ai/sbert_large_nlu_ru',
|
6 |
+
author='Almira Murtazina',
|
7 |
+
author_email='[email protected]',
|
8 |
+
packages=['sbert-punc-case-ru'],
|
9 |
+
install_requires=['transformers>=4.18.3'],
|
10 |
+
classifiers=[
|
11 |
+
"Operating System :: OS Independent",
|
12 |
+
"Programming Language :: Python :: 3",
|
13 |
+
"Programming Language :: Python :: 3.6",
|
14 |
+
"Programming Language :: Python :: 3.7",
|
15 |
+
"Programming Language :: Python :: 3.8",
|
16 |
+
"Programming Language :: Python :: 3.9",
|
17 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
18 |
+
]
|
19 |
+
)
|