---
metrics:
- accuracy
- f1
- roc_auc
- precision
- recall
pipeline_tag: image-classification
base_model: google/vit-base-patch16-384
model-index:
- name: AdamCodd/vit-nsfw-stable-diffusion
results:
- task:
type: image-classification
name: Image Classification
metrics:
- type: accuracy
value: 0.9349
name: Accuracy
- type: F1
value: 0.935
name: F1
- type: ROC_AUC
value: 0.9847
name: AUC
- type: precision
value: 0.9335
name: Precision
- type: recall
value: 0.9366
name: Recall
- type: loss
value: 0.1592
name: Loss
tags:
- transformers.js
- transformers
- nlp
license: cc-by-nc-nd-4.0
datasets:
- AdamCodd/Civitai-8m-prompts
---
# vit-nsfw-stable-diffusion
This model is a fine-tuned version of [vit-base-patch16-384](https://huggingface.co/google/vit-base-patch16-384) on ~1.7M generated images from the [AdamCodd/Civitai-8m-prompts](https://huggingface.co/datasets/AdamCodd/Civitai-8m-prompts) dataset, balanced between NSFW/SFW labels.
It achieves the following results on the evaluation set:
- Loss: 0.1592
- Accuracy: 0.9349
Unlike the [AdamCodd/vit-base-nsfw-detector](https://huggingface.co/AdamCodd/vit-base-nsfw-detector) model, this one was trained exclusively on images generated with Stable Diffusion.
The license for this model is [**cc-by-nc-nd-4.0**](https://creativecommons.org/licenses/by-nc-nd/4.0/legalcode.en). For commercial use rights, please contact me ([email protected]).
## Model description
The Vision Transformer (ViT) is a transformer encoder model (BERT-like) pretrained on a large collection of images in a supervised fashion, namely ImageNet-21k, at a resolution of 224x224 pixels. Next, the model was fine-tuned on ImageNet (also referred to as ILSVRC2012), a dataset comprising 1 million images and 1,000 classes, at a higher resolution of 384x384.
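To illustrate what the 384x384 resolution means for this architecture: ViT cuts the image into non-overlapping square patches, turns each patch into one token, and prepends a `[CLS]` token whose output is used for classification. The sketch below only does the arithmetic; the patch size of 16 is taken from the base model name (`patch16`), and the single `[CLS]` token reflects the standard ViT design.

```python
def vit_sequence_length(image_size: int, patch_size: int) -> int:
    """Number of tokens a ViT encoder sees: one per patch plus the [CLS] token."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + 1


# Pre-training resolution of the base model vs. the fine-tuning resolution
print(vit_sequence_length(224, 16))  # 14 * 14 + 1 = 197
print(vit_sequence_length(384, 16))  # 24 * 24 + 1 = 577
```

Moving from 224 to 384 pixels nearly triples the sequence length; in the original ViT recipe, the pre-trained position embeddings are 2D-interpolated to cover the longer sequence when fine-tuning at the higher resolution.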
## Intended uses & limitations
Usage for a local image:
```python
from transformers import pipeline
from PIL import Image
img = Image.open("<path_to_image_file>")
predict = pipeline('image-classification', model='AdamCodd/vit-nsfw-stable-diffusion')
predict(img)
```
Usage for a remote image:
```python
from transformers import ViTImageProcessor, AutoModelForImageClassification
from PIL import Image
import requests
import torch
url = 'https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/ba55f276-5aa5-446f-c59a-8fff4d209100/width=512/ba55f276-5aa5-446f-c59a-8fff4d209100.jpeg'
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")  # ensure a 3-channel input
processor = ViTImageProcessor.from_pretrained('AdamCodd/vit-nsfw-stable-diffusion')
model = AutoModelForImageClassification.from_pretrained('AdamCodd/vit-nsfw-stable-diffusion')
inputs = processor(images=image, return_tensors="pt")
# Inference only: disable gradient tracking to save memory and time
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
# Predicted class: sfw
```
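For a two-class model, taking the `argmax` is the same as flagging an image whenever its NSFW probability exceeds 0.5. If you need a stricter or looser filter, you can convert the logits to probabilities and apply your own threshold. The sketch below is plain Python so it can be fed `logits[0].tolist()` and `model.config.id2label` from the example above; the label name `"nsfw"`, the 0.5 default and the example `id2label` mapping are assumptions, so check `model.config.id2label` for the actual labels.

```python
import math
from typing import Dict, List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classify_with_threshold(logits: List[float], id2label: Dict[int, str],
                            nsfw_label: str = "nsfw", threshold: float = 0.5) -> str:
    """Return nsfw_label if its probability reaches `threshold`, else the other label.

    Assumes a binary classifier whose id2label contains `nsfw_label` (assumed name).
    """
    probs = softmax(logits)
    label_to_prob = {id2label[i]: p for i, p in enumerate(probs)}
    if label_to_prob[nsfw_label] >= threshold:
        return nsfw_label
    # Fall back to the remaining (non-NSFW) label
    return next(label for label in label_to_prob if label != nsfw_label)


# Made-up logits and a hypothetical label mapping, for illustration only
id2label = {0: "nsfw", 1: "sfw"}
print(classify_with_threshold([0.2, 1.0], id2label))                  # sfw
print(classify_with_threshold([0.2, 1.0], id2label, threshold=0.25))  # nsfw
```

Lowering the threshold flags more images (higher NSFW recall, lower precision), which may suit a safety filter where false negatives are costlier than false positives.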
Usage with Transformers.js (Vanilla JS):
```js
/* Instructions:
 * - Place this script in an HTML file using the <script type="module"> tag.
 * - Serve the HTML file over HTTP(S) (e.g., with Python's http.server or a Node.js server); browsers generally block module scripts loaded from file:// URLs.
 * - Replace 'https://example.com/path/to/image.jpg' in the classifyImage call with the URL of the image you want to classify.
 *
 * Example of how to include this script in HTML:
 * <script type="module" src="path/to/this_script.js"></script>
 *
 * Note: the server hosting the image must allow cross-origin requests (CORS), otherwise fetch() will fail.
 */
import { pipeline, env } from 'https://cdn.jsdelivr.net/npm/@xenova/[email protected]';
// Since we download the model from the Hugging Face Hub, skip the local model check
env.allowLocalModels = false;
// Load the image classification model
const classifier = await pipeline('image-classification', 'AdamCodd/vit-nsfw-stable-diffusion');
// Function to fetch and classify an image from a URL
async function classifyImage(url) {
  try {
    const response = await fetch(url);
    if (!response.ok) throw new Error(`Failed to load image (HTTP ${response.status})`);
    const blob = await response.blob();
    const image = new Image();
    const imagePromise = new Promise((resolve, reject) => {
      image.onload = () => resolve(image);
      image.onerror = reject;
      image.src = URL.createObjectURL(blob);
    });
    const img = await imagePromise; // Ensure the image is loaded
    const classificationResults = await classifier([img.src]); // Classify the image
    URL.revokeObjectURL(img.src); // Release the object URL once it is no longer needed
    console.log('Predicted class:', classificationResults[0].label);
  } catch (error) {
    console.error('Error classifying image:', error);
  }
}
// Example usage
classifyImage('https://example.com/path/to/image.jpg');
// Predicted class: sfw
```
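Since this model was trained only on images generated with Stable Diffusion, it is likely to perform worse on real photographs; for those, use my other ViT model ([AdamCodd/vit-base-nsfw-detector](https://huggingface.co/AdamCodd/vit-base-nsfw-detector)) instead.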
It performs well on generated images and pairs well with the [AdamCodd/distilroberta-nsfw-prompt-stable-diffusion](https://huggingface.co/AdamCodd/distilroberta-nsfw-prompt-stable-diffusion) model: filtering both the prompts and the resulting images reduces the chance of unsafe outputs.
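One way to combine the two models is to screen the prompt first and only check the generated image if the prompt passes. The sketch below takes the classifiers as plain callables returning pipeline-style `{"label", "score"}` dicts, so it can be tested with stubs. The function names are hypothetical, the 0.5 threshold is arbitrary, and it assumes both models expose a label matching `"nsfw"` case-insensitively; check each model's `id2label` before relying on it.

```python
from typing import Any, Callable, Dict, List, Union

Prediction = Dict[str, Any]
PipelineOutput = Union[List[Prediction], List[List[Prediction]]]


def nsfw_score(predictions: PipelineOutput, nsfw_label: str = "nsfw") -> float:
    """Extract the NSFW score from a pipeline-style output (0.0 if the label is absent).

    Accepts a flat list of {"label", "score"} dicts or a list wrapping one such list,
    the shape a text-classification pipeline can return for a single input.
    """
    if predictions and isinstance(predictions[0], list):
        predictions = predictions[0]
    for pred in predictions:
        if str(pred["label"]).lower() == nsfw_label.lower():
            return float(pred["score"])
    return 0.0


def is_generation_safe(prompt: str, image: Any,
                       prompt_classifier: Callable[[str], PipelineOutput],
                       image_classifier: Callable[[Any], PipelineOutput],
                       threshold: float = 0.5) -> bool:
    """Return False if either the prompt or the generated image is flagged as NSFW."""
    # Screen the prompt first: a flagged prompt makes the image check unnecessary
    if nsfw_score(prompt_classifier(prompt)) >= threshold:
        return False
    return nsfw_score(image_classifier(image)) < threshold


# Stub classifiers with made-up scores, standing in for the two pipelines
def safe_prompt(text):
    return [[{"label": "SFW", "score": 0.97}, {"label": "NSFW", "score": 0.03}]]


def flagged_image(img):
    return [{"label": "nsfw", "score": 0.88}, {"label": "sfw", "score": 0.12}]


print(is_generation_safe("a castle at dusk", None, safe_prompt, flagged_image))  # False

# Possible wiring with the real models (downloads weights from the Hub):
# from transformers import pipeline
# prompt_clf = pipeline("text-classification",
#                       model="AdamCodd/distilroberta-nsfw-prompt-stable-diffusion", top_k=None)
# image_clf = pipeline("image-classification", model="AdamCodd/vit-nsfw-stable-diffusion")
```

Passing `top_k=None` asks the text pipeline for the scores of all labels, so the NSFW score is available even when it is not the top prediction.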
## Training and evaluation data
More information needed
## Training procedure
### Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 3e-05
- train_batch_size: 32
- eval_batch_size: 64
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- num_epochs: 2
### Training results
- Validation Loss: 0.1592
- Accuracy: 0.9349
- F1: 0.9350
- AUC: 0.9847
- Precision: 0.9335
- Recall: 0.9366
[Confusion matrix](https://huggingface.co/AdamCodd/vit-nsfw-stable-diffusion/resolve/main/confusion_matrix_epoch_2.png) (eval):
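
| | Predicted class 0 | Predicted class 1 |
|---|---|---|
| **True class 0** | 78666 | 5644 |
| **True class 1** | 5355 | 79173 |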
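As a sanity check, the reported metrics can be recomputed from the confusion matrix. The sketch below assumes the usual scikit-learn layout (rows are true classes, columns are predicted classes) and treats class index 1 as the positive class; under those assumptions it reproduces the accuracy, precision, recall and F1 listed above.

```python
# Confusion matrix from the evaluation set: rows = true class, columns = predicted class
cm = [[78666, 5644],
      [5355, 79173]]

tn, fp = cm[0]
fn, tp = cm[1]

accuracy = (tp + tn) / (tp + tn + fp + fn)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * tp / (2 * tp + fp + fn)  # equivalent to the harmonic mean of precision and recall

for name, value in [("Accuracy", accuracy), ("Precision", precision),
                    ("Recall", recall), ("F1", f1)]:
    print(f"{name}: {value:.4f}")
# Accuracy: 0.9349
# Precision: 0.9335
# Recall: 0.9366
# F1: 0.9350
```

The four cells sum to 168,838 evaluation images, split almost evenly between the two classes (84,310 vs. 84,528).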
### Framework versions
- Transformers 4.36.2
- Evaluate 0.4.1
If you want to support me, you can do so [here](https://ko-fi.com/adamcodd).
## Citation and Acknowledgments
I would like to express my sincere gratitude to Prodia.com for generously providing the GPU resources, specifically the RTX 4090, that made the training of this model possible.