seanius commited on
Commit
5cb5319
·
verified ·
1 Parent(s): c8d2dcc

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +44 -2
README.md CHANGED
@@ -4,6 +4,7 @@ datasets:
4
  language:
5
  - en
6
  library_name: transformers
 
7
  ---
8
  ONNX model - a fine tuned version of DistilBERT which can be used to classify text as one of:
9
  - neutral, offensive_language, harmful_behaviour, hate_speech
@@ -12,10 +13,51 @@ The model was trained using the [csfy tool](https://github.com/mrseanryan/csfy)
12
 
13
  The base model is required (distilbert-base-uncased)
14
 
15
- For an example of how to run the model, see the [csfy tool](https://github.com/mrseanryan/csfy).
16
 
17
- The output is a number indicating the class - you can decode that via the label_mapping.json file.
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  ---
21
  license: mit
 
4
  language:
5
  - en
6
  library_name: transformers
7
+ base_model: distilbert/distilbert-base-uncased
8
  ---
9
  ONNX model - a fine tuned version of DistilBERT which can be used to classify text as one of:
10
  - neutral, offensive_language, harmful_behaviour, hate_speech
 
13
 
14
  The base model is required (distilbert-base-uncased)
15
 
16
+ For an example of how to run the model, see below - or see the [csfy tool](https://github.com/mrseanryan/csfy).
17
 
18
+ The output is a number indicating the class - it is decoded via the label_mapping.json file.
19
 
20
+ # Usage
21
+
22
+ ```python
23
+ # Loading the label mappings
24
+ import json
25
+ def load_label_mappings():
26
+ with open("./label_mapping.json", encoding="utf-8") as f:
27
+ data = json.load(f)
28
+ return data['labels']
29
+
30
+ label_mappings = load_label_mappings()
31
+
32
+ # Loading the model
33
+ import onnxruntime as ort
34
+ from transformers import DistilBertTokenizer
35
+ tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
36
+ ort_session = ort.InferenceSession("./toxic-or-neutral-text-labelled.onnx")
37
+
38
+ # Predicting label for given text
39
+ def predict_via_onnx(text, ort_session, tokenizer, label_mappings):
40
+ model_expected_input_shape = ort_session.get_inputs()[0].shape
41
+ print("Model expects input shape:", model_expected_input_shape)
42
+ inputs = tokenizer(text, return_tensors="np", padding="max_length", truncation=True, max_length=model_expected_input_shape[1])
43
+ print("input shape", inputs['input_ids'].shape)
44
+
45
+ input_ids = inputs['input_ids']
46
+ if input_ids.ndim == 1:
47
+ input_ids = input_ids[np.newaxis, :]
48
+ ort_inputs = {ort_session.get_inputs()[0].name: input_ids}
49
+
50
+ ort_inputs['input_ids'] = ort_inputs['input_ids'].astype(np.int64)
51
+
52
+ ort_outputs = ort_session.run(None, ort_inputs)
53
+ predictions = np.argmax(ort_outputs, axis=-1)
54
+
55
+ predicted_label = label_mappings[predictions.item()]
56
+ return predicted_label
57
+
58
+ predicted_label = predict_via_onnx("How do I get to the beach?", ort_session, tokenizer, label_mappings)
59
+ print(predicted_label)
60
+ ```
61
 
62
  ---
63
  license: mit