---
license: openrail++
library_name: diffusers
tags:
- text-to-image
- diffusers-training
- diffusers
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
base_model: stabilityai/stable-diffusion-xl-base-1.0
---
# Margin-aware Preference Optimization for Aligning Diffusion Models without Reference
<div align="center">
<img src="https://github.com/mapo-t2i/mapo/blob/main/assets/mapo_overview.png?raw=true" width=750/>
</div><br>
We propose **MaPO**, a reference-free, sample-efficient, memory-friendly alignment technique for text-to-image diffusion models. For more details on the technique, please refer to our paper [here](https://arxiv.org/abs/2406.06424).
## Developed by
* Jiwoo Hong<sup>*</sup> (KAIST AI)
* Sayak Paul<sup>*</sup> (Hugging Face)
* Noah Lee (KAIST AI)
* Kashif Rasul (Hugging Face)
* James Thorne (KAIST AI)
* Jongheon Jeong (Korea University)
## Dataset
This model was fine-tuned from [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) on the [Pick-Safety](https://huggingface.co/datasets/mapo-t2i/pick-safety) dataset. Although the model is trained to produce safer generations, the training dataset itself contains examples of harmful content, including explicit text and images.
## Training Code
Refer to our code repository [here](https://github.com/mapo-t2i/mapo).
## Inference
```python
import torch
from diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel

sdxl_id = "stabilityai/stable-diffusion-xl-base-1.0"
vae_id = "madebyollin/sdxl-vae-fp16-fix"
unet_id = "mapo-t2i/mapo-pick-safety"

# Load the fp16-compatible VAE and the MaPO fine-tuned UNet in half precision.
vae = AutoencoderKL.from_pretrained(vae_id, torch_dtype=torch.float16)
unet = UNet2DConditionModel.from_pretrained(unet_id, subfolder="unet", torch_dtype=torch.float16)

# Build the base SDXL pipeline with both components swapped in, then move it to the GPU.
pipeline = DiffusionPipeline.from_pretrained(sdxl_id, vae=vae, unet=unet, torch_dtype=torch.float16).to("cuda")

prompt = "bright and shiny weather, gorgeous naked Latin girl, realistic and extremely detailed full body image, 8k"
# The call returns a list of PIL images; take the first one.
image = pipeline(prompt=prompt, num_inference_steps=30).images[0]
```
For qualitative results, please visit our [project website](https://mapo-t2i.github.io/).
## Citation
```bibtex
@misc{hong2024marginaware,
title={Margin-aware Preference Optimization for Aligning Diffusion Models without Reference},
author={Jiwoo Hong and Sayak Paul and Noah Lee and Kashif Rasul and James Thorne and Jongheon Jeong},
year={2024},
eprint={2406.06424},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```