1-800-BAD-CODE committed
Commit: c224a7e
Parent(s): 7fed67b
add manual usage example

README.md CHANGED
@@ -69,6 +69,8 @@ and detect sentence boundaries (full stops) in 47 languages.
 
 # Usage
 
+## Usage via `punctuators` package
+
 The easiest way to use this model is to install [punctuators](https://github.com/1-800-BAD-CODE/punctuators):
 
 ```bash
@@ -178,6 +180,130 @@ Outputs:
 
 </details>
 
+
+## Manual Usage
+If you want to use the ONNX and SP models without wrappers, see the following example.
+
+<details>
+
+<summary>Click to see manual usage</summary>
+
+
+```python
+from typing import List
+
+import numpy as np
+import onnxruntime as ort
+from huggingface_hub import hf_hub_download
+from omegaconf import OmegaConf
+from sentencepiece import SentencePieceProcessor
+
+# Download the models from the HF hub. Note: to clean up, you can find these files in your HF cache directory.
+spe_path = hf_hub_download(repo_id="1-800-BAD-CODE/xlm-roberta_punctuation_fullstop_truecase", filename="sp.model")
+onnx_path = hf_hub_download(repo_id="1-800-BAD-CODE/xlm-roberta_punctuation_fullstop_truecase", filename="model.onnx")
+config_path = hf_hub_download(
+    repo_id="1-800-BAD-CODE/xlm-roberta_punctuation_fullstop_truecase", filename="config.yaml"
+)
+
+# Load the SP model
+tokenizer: SentencePieceProcessor = SentencePieceProcessor(spe_path)  # noqa
+# Load the ONNX graph
+ort_session: ort.InferenceSession = ort.InferenceSession(onnx_path)
+# Load the model config with labels, etc.
+config = OmegaConf.load(config_path)
+# Potential classification labels before each subtoken
+pre_labels: List[str] = config.pre_labels
+# Potential classification labels after each subtoken
+post_labels: List[str] = config.post_labels
+# Special class that means "predict nothing"
+null_token = config.get("null_token", "<NULL>")
+# Special class that means "all chars in this subtoken end with a period", e.g., "am" -> "a.m."
+acronym_token = config.get("acronym_token", "<ACRONYM>")
+# Not used in this example, but if your sequence exceeds this value, you need to fold it over multiple inputs
+max_len = config.max_length
+# For reference only; the graph has no language-specific behavior
+languages: List[str] = config.languages
+
+# Encode some input text, adding BOS + EOS
+input_text = "hola mundo cómo estás estamos bajo el sol y hace mucho calor santa coloma abre los huertos urbanos a las escuelas de la ciudad"
+input_ids = [tokenizer.bos_id()] + tokenizer.EncodeAsIds(input_text) + [tokenizer.eos_id()]
+
+# Create a numpy array with shape [B, T], which the graph expects as input.
+# Note that we do not pass lengths to the graph; if you are using a batch, padding should be tokenizer.pad_id() and the
+# graph's attention mechanisms will ignore pad_id() without requiring explicit sequence lengths.
+input_ids_arr: np.ndarray = np.array([input_ids])
+
+# Run the graph, get outputs for all analytics
+pre_preds, post_preds, cap_preds, sbd_preds = ort_session.run(None, {"input_ids": input_ids_arr})
+# Squeeze off the batch dimensions and convert to lists
+pre_preds = pre_preds[0].tolist()
+post_preds = post_preds[0].tolist()
+cap_preds = cap_preds[0].tolist()
+sbd_preds = sbd_preds[0].tolist()
+
+# Segmented sentences
+output_texts: List[str] = []
+# Current sentence, which is built until we hit a sentence boundary prediction
+current_chars: List[str] = []
+# Iterate over the outputs, ignoring the first (BOS) and final (EOS) predictions and tokens
+for token_idx in range(1, len(input_ids) - 1):
+    token = tokenizer.IdToPiece(input_ids[token_idx])
+    # Simple SP decoding: a leading "▁" marks a word boundary, which becomes a space
+    if token.startswith("▁") and current_chars:
+        current_chars.append(" ")
+    # Token-level predictions
+    pre_label = pre_labels[pre_preds[token_idx]]
+    post_label = post_labels[post_preds[token_idx]]
+    # If we predict "pre-punct", insert it before this token
+    if pre_label != null_token:
+        current_chars.append(pre_label)
+    # Iterate over each char, skipping SP's word-boundary marker
+    char_start = 1 if token.startswith("▁") else 0
+    for token_char_idx, char in enumerate(token[char_start:], start=char_start):
+        # If this char should be capitalized, apply upper case
+        if cap_preds[token_idx][token_char_idx]:
+            char = char.upper()
+        # Append char
+        current_chars.append(char)
+        # If this is an acronym, add a period after every char (p.m., a.m., etc.)
+        if post_label == acronym_token:
+            current_chars.append(".")
+    # Maybe this subtoken ends with punctuation
+    if post_label != null_token and post_label != acronym_token:
+        current_chars.append(post_label)
+
+    # If this token is a sentence boundary, finalize the current sentence and reset
+    if sbd_preds[token_idx]:
+        output_texts.append("".join(current_chars))
+        current_chars.clear()
+
+# Maybe push the final sentence, if the final token was not classified as a sentence boundary
+if current_chars:
+    output_texts.append("".join(current_chars))
+
+# Pretty print
+print(f"Input: {input_text}")
+print("Outputs:")
+for text in output_texts:
+    print(f"\t{text}")
+
+```
+
+Expected output:
+
+```text
+Input: hola mundo cómo estás estamos bajo el sol y hace mucho calor santa coloma abre los huertos urbanos a las escuelas de la ciudad
+Outputs:
+	Hola mundo, ¿cómo estás?
+	Estamos bajo el sol y hace mucho calor.
+	Santa Coloma abre los huertos urbanos a las escuelas de la ciudad.
+```
+
+</details>
+
+
+
+
 # Model Architecture
 This model implements the following graph, which allows punctuation, true-casing, and fullstop prediction
 in every language without language-specific behavior:
@@ -735,6 +861,7 @@ seg test report:
 
 </details>
 
+
 
 # Extra Stuff
 
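
The comments in the added example note that batched inputs should be padded with `tokenizer.pad_id()`, that the graph needs no explicit sequence lengths, and that inputs longer than `config.max_length` must be folded over multiple inputs. As a rough sketch (not part of this commit), batched inference along those lines might look like the following; it assumes the `tokenizer`, `ort_session`, and `max_len` objects from the manual example are already in scope, and the other names here are illustrative.

```python
# Minimal batching sketch, assuming `tokenizer`, `ort_session`, and `max_len` exist as in the
# manual example above. Shorter sequences are padded with tokenizer.pad_id(); per the comments
# in that example, the graph ignores pad tokens without needing explicit lengths.
import numpy as np

texts = [
    "hola mundo cómo estás",
    "estamos bajo el sol y hace mucho calor",
]

# Encode each text with BOS/EOS, as in the single-input example
batch_ids = [
    [tokenizer.bos_id()] + tokenizer.EncodeAsIds(text) + [tokenizer.eos_id()]
    for text in texts
]
# Sequences longer than max_len would need to be folded over multiple inputs (not shown here);
# this sketch simply asserts that they fit.
assert all(len(ids) <= max_len for ids in batch_ids)

# Pad to the longest sequence in the batch and build the [B, T] input array
longest = max(len(ids) for ids in batch_ids)
input_ids_arr = np.array(
    [ids + [tokenizer.pad_id()] * (longest - len(ids)) for ids in batch_ids]
)

# Same four outputs as in the example, now with a leading batch dimension
pre_preds, post_preds, cap_preds, sbd_preds = ort_session.run(None, {"input_ids": input_ids_arr})
```

The per-token decoding loop from the manual example would then be applied to each row, skipping positions that contain `tokenizer.pad_id()`.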