Text Classification Toxicity

This model is a fine-tuned version of MiniLMv2-L6-H384, trained on the dataset of the first Jigsaw Kaggle competition (Toxic Comment Classification Challenge) with unitary/toxic-bert as the teacher model. The original, unquantized model can be found here.

This model has only two labels (toxicity and severe toxicity). For the model with all labels, refer to this page.
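The card does not describe the distillation objective used with the teacher. As an illustration only, a common setup for multi-label teacher-student training treats the teacher's sigmoid probabilities as soft targets for a binary cross-entropy loss on the student's logits. A minimal pure-Python sketch under that assumption (the helper name `soft_target_bce` is hypothetical, not taken from the training code):

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    """Numerically stable binary cross-entropy between sigmoid(logit) and a soft target."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def soft_target_bce(student_logits: list[float], teacher_logits: list[float]) -> float:
    """Hypothetical distillation loss: mean BCE of student logits against teacher probabilities."""
    total = 0.0
    for s_logit, t_logit in zip(student_logits, teacher_logits):
        target = 1.0 / (1.0 + math.exp(-t_logit))  # teacher soft label in (0, 1)
        total += bce_with_logits(s_logit, target)
    return total / len(student_logits)


# One logit per label (toxicity, severe toxicity); the values are made up
print(soft_target_bce([1.5, -2.0], [2.0, -1.0]))
```

The loss is smallest when the student's per-label probabilities match the teacher's, which is what lets a 23M-parameter student imitate a 110M-parameter teacher label by label.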

Optimum

Installation

Install from source:

python -m pip install "optimum[onnxruntime] @ git+https://github.com/huggingface/optimum.git"

Run the Model

from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer, pipeline

model_id = 'minuva/MiniLMv2-toxic-jigsaw-lite-onnx'

# Load the quantized ONNX model and run it on CPU through ONNX Runtime
model = ORTModelForSequenceClassification.from_pretrained(model_id, provider="CPUExecutionProvider")
# Cap inputs at 256 tokens; truncation itself is requested when the pipeline is called
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, model_max_length=256)

pipe = pipeline(task='text-classification', model=model, tokenizer=tokenizer)
texts = ["This is pure trash"]
pipe(texts, truncation=True)
# [{'label': 'toxic', 'score': 0.6553249955177307}]

ONNX Runtime only

A lighter-weight option for deployment: it needs only tokenizers and onnxruntime, not transformers or optimum.

Installation

pip install tokenizers onnxruntime
# Git LFS is required so that the .onnx weights are downloaded instead of pointer files
git lfs install
git clone https://huggingface.co/minuva/MiniLMv2-toxic-jigsaw-lite-onnx

Load and Run the Model

import os
import json

import numpy as np
from tokenizers import Tokenizer
from onnxruntime import InferenceSession

model_name = "minuva/MiniLMv2-toxic-jigsaw-lite-onnx"
model_dir = "MiniLMv2-toxic-jigsaw-lite-onnx"  # local clone from the step above

# Pad each batch to its longest sequence and truncate inputs to 256 tokens
tokenizer = Tokenizer.from_pretrained(model_name)
tokenizer.enable_padding()
tokenizer.enable_truncation(max_length=256)
batch_size = 16

texts = ["This is pure trash"]
outputs = []
model = InferenceSession(
    os.path.join(model_dir, "model_optimized_quantized.onnx"),
    providers=["CPUExecutionProvider"],
)

# config.json holds the id2label mapping used to name the output columns
with open(os.path.join(model_dir, "config.json"), "r") as f:
    config = json.load(f)

output_names = [output.name for output in model.get_outputs()]
input_names = [input.name for input in model.get_inputs()]

# Process the texts in chunks of at most batch_size
for start in range(0, len(texts), batch_size):
    encodings = tokenizer.encode_batch(texts[start:start + batch_size])
    # Explicit int64 matches the integer inputs the exported graph typically expects
    inputs = {
        "input_ids": np.array([e.ids for e in encodings], dtype=np.int64),
        "attention_mask": np.array([e.attention_mask for e in encodings], dtype=np.int64),
        "token_type_ids": np.array([e.type_ids for e in encodings], dtype=np.int64),
    }

    for input_name in input_names:
        if input_name not in inputs:
            raise ValueError(f"Input name {input_name} not found in inputs")

    # Feed only the inputs the graph declares
    inputs = {input_name: inputs[input_name] for input_name in input_names}
    # The model has a single output: logits of shape (batch, num_labels)
    logits = model.run(output_names, inputs)[0]
    outputs.append(logits)

outputs = np.concatenate(outputs, axis=0)
# A sigmoid turns each logit into an independent per-label probability
probs = 1 / (1 + np.exp(-outputs))

results = []
for item in probs:
    labels = [config["id2label"][str(idx)] for idx in range(len(item))]
    scores = [float(s) for s in item]
    results.append({"labels": labels, "scores": scores})

# Keep the highest-scoring label for each text
res = []
for result in results:
    joined = list(zip(result["labels"], result["scores"]))
    res.append(max(joined, key=lambda x: x[1]))

print(res)
# [('toxic', 0.6553249955177307)]
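Because each label gets its own sigmoid, the scores do not sum to 1, and taking the maximum always returns some label, even for benign text. If a hard yes/no decision is needed, one option is to keep every label whose probability reaches a threshold. A minimal sketch, assuming a hypothetical threshold of 0.5 (the card does not recommend one) and a hypothetical `id2label` mapping shaped like the one in config.json:

```python
import math


def labels_above_threshold(logits, id2label, threshold=0.5):
    """Return (label, probability) pairs whose sigmoid score is >= threshold.

    `id2label` mirrors the string-keyed mapping in config.json; the 0.5
    default is an assumption, not a value from the model card.
    """
    hits = []
    for idx, logit in enumerate(logits):
        prob = 1.0 / (1.0 + math.exp(-logit))  # same sigmoid as the code above
        if prob >= threshold:
            hits.append((id2label[str(idx)], prob))
    return hits


# Hypothetical mapping and logits, for illustration only
id2label = {"0": "toxic", "1": "severe_toxic"}
print(labels_above_threshold([0.64, -3.0], id2label))
```

With these made-up logits only the first label clears 0.5; a text whose logits are all negative yields an empty list instead of a forced label.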

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 6e-05
  • train_batch_size: 48
  • eval_batch_size: 48
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • num_epochs: 10
  • warmup_ratio: 0.1
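The card lists a linear scheduler with warmup_ratio 0.1 but not the total number of optimizer steps. Assuming the schedule behaves like transformers' `get_linear_schedule_with_warmup` (linear warmup from 0 over the first 10% of steps, then linear decay to 0), the learning rate at a given step can be sketched as follows; `total_steps` is a hypothetical input that depends on the dataset size and the batch size of 48.

```python
import math


def linear_warmup_lr(step: int, total_steps: int,
                     peak_lr: float = 6e-05, warmup_ratio: float = 0.1) -> float:
    """Learning rate at `step` under linear warmup followed by linear decay to zero."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Ramp from 0 up to peak_lr
        return peak_lr * step / max(1, warmup_steps)
    # Decay from peak_lr down to 0 at the final step
    remaining = total_steps - step
    return peak_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


# Hypothetical run of 1,000 optimizer steps
for step in (0, 50, 100, 550, 1000):
    print(step, linear_warmup_lr(step, total_steps=1000))
```

In this hypothetical run the rate peaks at 6e-05 at step 100, is back to half of the peak at step 550, and reaches 0 at the last step.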

Metrics (comparison with teacher model)

Teacher (params) | Student (params) | Set (metric) | Score (teacher) | Score (student)
unitary/toxic-bert (110M) | MiniLMv2-toxic-jigsaw-lite (23M) | Test (ROC AUC) | 0.982677 | 0.9806
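On these numbers the student gives up about 0.002 ROC AUC (0.982677 vs 0.9806) with roughly 4.8× fewer parameters (23M vs 110M). ROC AUC can be read as the probability that a randomly chosen positive (toxic) example scores higher than a randomly chosen negative one, with ties counting half. The card does not say how per-label AUCs were aggregated; a minimal pure-Python sketch of the single-label computation, using made-up scores:

```python
def roc_auc(labels: list[int], scores: list[float]) -> float:
    """Pairwise ROC AUC: fraction of (positive, negative) pairs ranked correctly.

    O(n_pos * n_neg), so suitable for illustration, not for large evaluation sets.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count half
    return wins / (len(pos) * len(neg))


# Made-up labels (1 = toxic) and model scores
print(roc_auc([1, 0, 1, 0, 0], [0.9, 0.2, 0.6, 0.7, 0.1]))
# 0.8333333333333334
```

Five of the six positive/negative pairs are ranked correctly here; the one mistake is the non-toxic example scored 0.7 outranking the toxic one scored 0.6.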

Deployment

Check our fast-nlp-text-toxicity repository for a FastAPI- and ONNX-based server that deploys this model on CPU devices.
