import base64
import io
from typing import List, Tuple

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from transformers import AutoModel, AutoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
# Load the example dataset (DALL-E 3 vs. Stable Diffusion 1.5 preference pairs).
dataset = load_dataset("xzuyn/dalle-3_vs_sd-v1-5_dpo", num_proc=4)

# PickScore_v1: a CLIP-ViT-H-14 model fine-tuned on Pick-a-Pic user preferences.
processor_name = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
model_name = "yuvalkirstain/PickScore_v1"
processor = AutoProcessor.from_pretrained(processor_name)
model = AutoModel.from_pretrained(model_name, torch_dtype=dtype).to(device)
def decode_image(image: str) -> Image.Image:
    """
    Decodes a base64 string into a PIL image.

    Args:
        image: base64-encoded image string.

    Returns:
        The decoded PIL image.
    """
    img_bytes = base64.b64decode(image)
    return Image.open(io.BytesIO(img_bytes))
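
# A hedged round-trip sketch for decode_image (values are illustrative only):
#
#   buf = io.BytesIO()
#   Image.new("RGB", (64, 64), "white").save(buf, format="PNG")
#   encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
#   assert decode_image(encoded).size == (64, 64)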

def get_preference(img_1: Image.Image, img_2: Image.Image, caption: str) -> Image.Image:
    """
    Returns the image the model prefers for the given caption.

    Args:
        img_1: first PIL image.
        img_2: second PIL image.
        caption: the caption used to generate both images.

    Returns:
        The preferred image (the one with the higher logit).
    """
    imgs = [img_1, img_2]
    logits = get_logits(caption, imgs)
    preference = logits.argmax().item()
    return imgs[preference]
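
# A hedged usage sketch (the file paths are placeholders, not part of this app):
#
#   winner = get_preference(Image.open("a.png"), Image.open("b.png"), "a red fox")
#   winner.save("preferred.png")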

def sample_example() -> Tuple[Image.Image, Image.Image, Image.Image, str]:
    """
    Samples a random example from the dataset and scores it.

    Returns:
        img_1: first PIL image.
        img_2: second PIL image.
        preference: the preferred PIL image.
        caption: the caption for the pair.
    """
    example = dataset["train"][np.random.randint(0, len(dataset["train"]))]
    img_1 = decode_image(example["jpg_0"])
    img_2 = decode_image(example["jpg_1"])
    caption = example["caption"]
    imgs = [img_1, img_2]
    logits = get_logits(caption, imgs)
    preference = logits.argmax().item()
    return img_1, img_2, imgs[preference], caption

def get_logits(caption: str, imgs: List[Image.Image]) -> torch.Tensor:
    """
    Returns the model's logits for the caption against each image.

    Args:
        caption: the shared caption.
        imgs: list of PIL images.

    Returns:
        logits_per_image: tensor of shape (num_images, 1).
    """
    inputs = processor(
        text=caption,
        images=imgs,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77,
    ).to(device)
    # Match the pixel dtype to the model's (float16 on GPU, float32 on CPU).
    inputs["pixel_values"] = inputs["pixel_values"].to(dtype)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.logits_per_image
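
# A hedged sketch: the raw logits can be turned into preference probabilities
# with a softmax over the image axis, mirroring the scoring recipe on the
# PickScore_v1 model card; probs[i] is the probability that image i wins.
#
#   probs = torch.softmax(get_logits(caption, imgs).squeeze(-1), dim=0)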

### Description
title = r"""
<h1 align="center">Aesthetic Scorer: CLIP fine-tuned for DPO scoring</h1>
"""
description = r"""
<b>This is a demo for the paper <a href="https://arxiv.org/abs/2305.01569">Pick-a-Pic: An Open Dataset of User Preferences for Text-to-Image Generation</a>.</b> <br>
How to use this demo: <br>
1. Upload two images generated from the same caption.
2. Enter the caption used to generate the images.
3. Click the "Get Preference" button to get the image that scores higher on user preference according to the model. <br>
<b>OR</b> <br>
1. Click the "Random Example" button to get a random example from the <a href="https://huggingface.co/datasets/xzuyn/dalle-3_vs_sd-v1-5_dpo">DALL-E 3 vs. SD 1.5 DPO dataset</a>. <br>
This demo shows how this CLIP variant can be used for DPO scoring. The scores can then be used for DPO fine-tuning with these <a href="https://github.com/huggingface/diffusers/tree/main/examples/research_projects/diffusion_dpo">scripts</a>. <br>
Accuracy on the <a href="https://huggingface.co/datasets/xzuyn/dalle-3_vs_sd-v1-5_dpo">DALL-E 3 vs. SD 1.5 DPO dataset</a>: <br>
<a href="https://huggingface.co/yuvalkirstain/PickScore_v1">PickScore_v1</a> - 97.3 <br>
<a href="https://huggingface.co/CIDAS/clipseg-rd64-refined">CLIPSeg</a> - 70.9 <br>
<a href="https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K">CLIP-ViT-H-14-laion2B-s32B-b79K</a> - 82.3 <br>
"""

citation = r"""
📑 **Citation**
```bibtex
@inproceedings{Kirstain2023PickaPicAO,
    title={Pick-a-Pic: An Open Dataset of User Preferences for Text-to-Image Generation},
    author={Yuval Kirstain and Adam Polyak and Uriel Singer and Shahbuland Matiana and Joe Penna and Omer Levy},
    year={2023}
}
```
"""

with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        # type="pil" so the click handler receives PIL images, matching the
        # type hints above (gradio passes numpy arrays by default).
        first_image = gr.Image(height=400, width=400, type="pil", label="First Image")
        second_image = gr.Image(height=400, width=400, type="pil", label="Second Image")
    caption_box = gr.Textbox(lines=1, label="Caption")
    with gr.Row():
        image_button = gr.Button("Get Preference")
        random_example = gr.Button("Random Example")
    image_output = gr.Image(height=400, width=400, label="Preference")
    image_button.click(
        get_preference,
        inputs=[first_image, second_image, caption_box],
        outputs=image_output,
    )
    random_example.click(
        sample_example, outputs=[first_image, second_image, image_output, caption_box]
    )
    gr.Markdown(citation)

if __name__ == "__main__":
    demo.launch()