1-800-BAD-CODE committed on
Commit c224a7e · 1 Parent(s): 7fed67b

add manual usage example

Files changed (1): README.md (+127 -0)

README.md CHANGED
@@ -69,6 +69,8 @@ and detect sentence boundaries (full stops) in 47 languages.
 
  # Usage
 
+ ## Usage via `punctuators` package
+
  The easiest way to use this model is to install [punctuators](https://github.com/1-800-BAD-CODE/punctuators):
 
  ```bash
 
@@ -178,6 +180,130 @@ Outputs:
 
  </details>
 
+
+ ## Manual Usage
+ If you want to use the ONNX and SP models without wrappers, see the following example.
+
+ <details>
+
+ <summary>Click to see manual usage</summary>
+
+ ```python
+ from typing import List
+
+ import numpy as np
+ import onnxruntime as ort
+ from huggingface_hub import hf_hub_download
+ from omegaconf import OmegaConf
+ from sentencepiece import SentencePieceProcessor
+
+ # Download the models from the HF hub. Note: to clean up, you can find these files in your HF cache directory.
+ spe_path = hf_hub_download(repo_id="1-800-BAD-CODE/xlm-roberta_punctuation_fullstop_truecase", filename="sp.model")
+ onnx_path = hf_hub_download(repo_id="1-800-BAD-CODE/xlm-roberta_punctuation_fullstop_truecase", filename="model.onnx")
+ config_path = hf_hub_download(
+     repo_id="1-800-BAD-CODE/xlm-roberta_punctuation_fullstop_truecase", filename="config.yaml"
+ )
+
+ # Load the SP model
+ tokenizer: SentencePieceProcessor = SentencePieceProcessor(spe_path)  # noqa
+ # Load the ONNX graph
+ ort_session: ort.InferenceSession = ort.InferenceSession(onnx_path)
+ # Load the model config with labels, etc.
+ config = OmegaConf.load(config_path)
+ # Potential classification labels before each subtoken
+ pre_labels: List[str] = config.pre_labels
+ # Potential classification labels after each subtoken
+ post_labels: List[str] = config.post_labels
+ # Special class that means "predict nothing"
+ null_token = config.get("null_token", "<NULL>")
+ # Special class that means "all chars in this subtoken end with a period", e.g., "am" -> "a.m."
+ acronym_token = config.get("acronym_token", "<ACRONYM>")
+ # Not used in this example, but if your sequence exceeds this value, you need to fold it over multiple inputs
+ max_len = config.max_length
+ # For reference only; the graph has no language-specific behavior
+ languages: List[str] = config.languages
+
+ # Encode some input text, adding BOS + EOS
+ input_text = "hola mundo cómo estás estamos bajo el sol y hace mucho calor santa coloma abre los huertos urbanos a las escuelas de la ciudad"
+ input_ids = [tokenizer.bos_id()] + tokenizer.EncodeAsIds(input_text) + [tokenizer.eos_id()]
+
+ # Create a numpy array with shape [B, T], which the graph expects as input.
+ # Note that we do not pass lengths to the graph; if you are using a batch, padding should be tokenizer.pad_id() and
+ # the graph's attention mechanisms will ignore pad_id() without requiring explicit sequence lengths.
+ input_ids_arr: np.ndarray = np.array([input_ids])
+
+ # Run the graph, get outputs for all analytics
+ pre_preds, post_preds, cap_preds, sbd_preds = ort_session.run(None, {"input_ids": input_ids_arr})
+ # Squeeze off the batch dimensions and convert to lists
+ pre_preds = pre_preds[0].tolist()
+ post_preds = post_preds[0].tolist()
+ cap_preds = cap_preds[0].tolist()
+ sbd_preds = sbd_preds[0].tolist()
+
+ # Segmented sentences
+ output_texts: List[str] = []
+ # Current sentence, which is built until we hit a sentence boundary prediction
+ current_chars: List[str] = []
+ # Iterate over the outputs, ignoring the first (BOS) and final (EOS) predictions and tokens
+ for token_idx in range(1, len(input_ids) - 1):
+     token = tokenizer.IdToPiece(input_ids[token_idx])
+     # Simple SP decoding
+     if token.startswith("▁") and current_chars:
+         current_chars.append(" ")
+     # Token-level predictions
+     pre_label = pre_labels[pre_preds[token_idx]]
+     post_label = post_labels[post_preds[token_idx]]
+     # If we predict "pre-punct", insert it before this token
+     if pre_label != null_token:
+         current_chars.append(pre_label)
+     # Iterate over each char, skipping SP's space token if present
+     char_start = 1 if token.startswith("▁") else 0
+     for token_char_idx, char in enumerate(token[char_start:], start=char_start):
+         # If this char should be capitalized, apply upper case
+         if cap_preds[token_idx][token_char_idx]:
+             char = char.upper()
+         # Append char
+         current_chars.append(char)
+         # If this is an acronym, add a period after every char (p.m., a.m., etc.)
+         if post_label == acronym_token:
+             current_chars.append(".")
+     # Maybe this subtoken ends with punctuation
+     if post_label != null_token and post_label != acronym_token:
+         current_chars.append(post_label)
+
+     # If this token is a sentence boundary, finalize the current sentence and reset
+     if sbd_preds[token_idx]:
+         output_texts.append("".join(current_chars))
+         current_chars.clear()
+
+ # Maybe push the final sentence, if the final token was not classified as a sentence boundary
+ if current_chars:
+     output_texts.append("".join(current_chars))
+
+ # Pretty print
+ print(f"Input: {input_text}")
+ print("Outputs:")
+ for text in output_texts:
+     print(f"\t{text}")
+ ```
+
+ Expected output:
+
+ ```text
+ Input: hola mundo cómo estás estamos bajo el sol y hace mucho calor santa coloma abre los huertos urbanos a las escuelas de la ciudad
+ Outputs:
+     Hola mundo, ¿cómo estás?
+     Estamos bajo el sol y hace mucho calor.
+     Santa Coloma abre los huertos urbanos a las escuelas de la ciudad.
+ ```
+
+ </details>
+
+ &nbsp;
+
  # Model Architecture
  This model implements the following graph, which allows punctuation, true-casing, and fullstop prediction
  in every language without language-specific behavior:
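The manual example above notes that batched inputs should be padded with `tokenizer.pad_id()` to form a `[B, T]` array. A minimal, self-contained sketch of building such a batch with plain numpy (the helper name and the pad id value `0` are illustrative assumptions, not part of the model's API):

```python
import numpy as np

def batch_input_ids(sequences, pad_id):
    """Right-pad variable-length ID sequences into a [B, T] int64 array."""
    max_len = max(len(seq) for seq in sequences)
    batch = np.full((len(sequences), max_len), pad_id, dtype=np.int64)
    for row, seq in enumerate(sequences):
        batch[row, : len(seq)] = seq
    return batch

# Two toy sequences of different lengths; pad_id=0 is for illustration only.
batch = batch_input_ids([[1, 5, 7, 2], [1, 9, 2]], pad_id=0)
print(batch.shape)           # (2, 4)
print(batch[1].tolist())     # [1, 9, 2, 0]
```

The padded array can then be passed as `{"input_ids": batch}` to `ort_session.run`, since the graph masks out `pad_id()` internally.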
 
@@ -735,6 +861,7 @@ seg test report:
 
  </details>
 
+ &nbsp;
 
  # Extra Stuff
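The manual example's comment on `config.max_length` says that longer sequences must be folded over multiple inputs. One way such folding could look, as a sketch under stated assumptions (non-overlapping windows, illustrative BOS/EOS ids; a real implementation would likely overlap windows and merge predictions at the seams):

```python
def fold_ids(ids, bos_id, eos_id, max_len):
    """Split a long ID sequence into windows of at most max_len,
    re-adding BOS/EOS to each window (no overlap, for illustration)."""
    body = max_len - 2  # reserve room for BOS + EOS in every window
    windows = []
    for start in range(0, len(ids), body):
        windows.append([bos_id] + ids[start : start + body] + [eos_id])
    return windows

# Fold 10 toy ids into windows of at most 6; ids 98/99 stand in for BOS/EOS.
windows = fold_ids(list(range(10)), bos_id=98, eos_id=99, max_len=6)
print(len(windows))   # 3
print(windows[0])     # [98, 0, 1, 2, 3, 99]
```

Each window is then run through the graph separately, skipping the BOS/EOS predictions of every window as the main example does for its single input.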