File size: 1,842 Bytes
668e702
bfcd10f
668e702
 
bfcd10f
 
 
668e702
 
 
bfcd10f
 
668e702
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bfcd10f
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import gradio as gr
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
import torch


# Hugging Face Hub checkpoint served by this demo.
# NOTE(review): this is the base ``idefics2-8b`` checkpoint, while the Gradio title and
# description below advertise a DPO fine-tune -- confirm which weights are intended.
model_id = "HuggingFaceM4/idefics2-8b"

# 4-bit NF4 weights with nested (double) quantization; compute dtype is fp16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# The processor is used below for chat templating, input preparation and decoding.
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)


def respond(multimodal_input, max_new_tokens=500):
    """Generate the model's reply to one multimodal chat turn.

    Args:
        multimodal_input: Value produced by ``gr.MultimodalTextbox``: a dict with a
            ``"text"`` entry and a ``"files"`` list. The files are handed to the
            processor as images (presumably local file paths -- confirm for the
            installed Gradio version).
        max_new_tokens: Maximum number of tokens to generate for the reply.

    Returns:
        The decoded reply text, with the prompt tokens removed.
    """
    images = multimodal_input.get("files") or []
    text = multimodal_input.get("text") or ""

    # One image placeholder per attached file, followed by the user's text.
    content = [{"type": "image"} for _ in images]
    content.append({"type": "text", "text": text})
    messages = [{"role": "user", "content": content}]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)

    # The processor takes one list of images per prompt. For a text-only turn pass
    # None rather than [[]]: an empty per-prompt image list makes the image
    # preprocessing fail instead of simply skipping it.
    inputs = processor(
        text=prompt,
        images=[images] if images else None,
        return_tensors="pt",
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_length = inputs["input_ids"].shape[1]

    with torch.inference_mode():
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + continuation; decode only the new tokens.
    new_tokens = generated_ids[:, prompt_length:]
    return processor.batch_decode(new_tokens, skip_special_tokens=True)[0]


# Demo UI: a multimodal textbox (text plus image uploads) in, plain text out.
# The app is bound to ``demo`` so it stays reachable after launch, for example by
# Gradio's reload mode (``gradio app.py``), which looks for a variable named ``demo``.
# Example image paths are relative, so they resolve against the working directory.
demo = gr.Interface(
    respond,
    inputs=[gr.MultimodalTextbox(file_types=["image"], show_label=False)],
    outputs="text",
    title="IDEFICS2-8B DPO",
    description="Try IDEFICS2-8B fine-tuned using direct preference optimization (DPO) in this demo. Learn more about vision language model DPO integration of TRL [here](https://huggingface.co/blog/dpo_vlm).",
    examples=[
        {"text": "What is the type of flower in the image and what insect is on it?", "files": ["./bee.jpg"]},
        {"text": "Describe the image", "files": ["./howl.jpg"]},
    ],
)
demo.launch()