Merve Noyan commited on
Commit
e352103
·
1 Parent(s): e716569

initial commit

Browse files
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoProcessor, Idefics3ForConditionalGeneration
3
+ import re
4
+ import time
5
+ from PIL import Image
6
+ import torch
7
+ import spaces
8
+ import subprocess
9
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
10
+
11
+
12
+ processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics3-8b-new")
13
+
14
+ model = Idefics3ForConditionalGeneration.from_pretrained("HuggingFaceM4/idefics3-8b-new",
15
+ torch_dtype=torch.bfloat16,
16
+ #_attn_implementation="flash_attention_2",
17
+ trust_remote_code=True).to("cuda")
18
+
19
+ BAD_WORDS_IDS = processor.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
20
+ EOS_WORDS_IDS = [processor.tokenizer.eos_token_id]
21
+
22
+ #@spaces.GPU
23
+ def model_inference(
24
+ images, text, decoding_strategy, temperature, max_new_tokens,
25
+ repetition_penalty, top_p
26
+ ):
27
+ if text == "" and not images:
28
+ gr.Error("Please input a query and optionally image(s).")
29
+
30
+ if text == "" and images:
31
+ gr.Error("Please input a text query along the image(s).")
32
+
33
+ if isinstance(images, Image.Image):
34
+ images = [images]
35
+
36
+ if isinstance(text, str):
37
+ text = "<image>" + text
38
+ text = [text]
39
+
40
+ inputs = processor(text=text, images=images, padding=True, return_tensors="pt").to("cuda")
41
+ print("inputs",inputs)
42
+
43
+ assert decoding_strategy in [
44
+ "Greedy",
45
+ "Top P Sampling",
46
+ ]
47
+ if decoding_strategy == "Greedy":
48
+ do_sample = False
49
+ elif decoding_strategy == "Top P Sampling":
50
+ do_sample = True
51
+
52
+ # Generate
53
+
54
+ generated_ids = model.generate(**inputs, bad_words_ids=BAD_WORDS_IDS, max_new_tokens=max_new_tokens,
55
+ temperature=temperature, do_sample=do_sample, repetition_penalty=repetition_penalty,
56
+ top_p=top_p),
57
+ generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
58
+ #generated_texts = processor.batch_decode(generated_ids[:, generation_args["input_ids"].size(1):], skip_special_tokens=True)
59
+ print("INPUT:", text, "|OUTPUT:", generated_texts)
60
+ return generated_texts[0]
61
+
62
+
63
+ with gr.Blocks(fill_height=True) as demo:
64
+ gr.Markdown("## IDEFICS2Llama 🐶")
65
+ gr.Markdown("Play with [IDEFICS2Llama](https://huggingface.co/HuggingFaceM4/idefics2-8b) in this demo. To get started, upload an image and text or try one of the examples.")
66
+ gr.Markdown("**Important note**: This model is not made for chatting, the chatty IDEFICS2 will be released in the upcoming days. **This model is very strong on various tasks, including visual question answering, document retrieval and more, you can see it through the examples.**")
67
+ gr.Markdown("Learn more about IDEFICS2 in this [blog post](https://huggingface.co/blog/idefics2).")
68
+
69
+
70
+ with gr.Column():
71
+ image_input = gr.Image(label="Upload your Image", type="pil")
72
+ query_input = gr.Textbox(label="Prompt")
73
+ submit_btn = gr.Button("Submit")
74
+ output = gr.Textbox(label="Output")
75
+
76
+ with gr.Accordion(label="Example Inputs and Advanced Generation Parameters"):
77
+ examples=[["example_images/travel_tips.jpg", "I want to go somewhere similar to the one in the photo. Give me destinations and travel tips.", "Greedy", 0.4, 512, 1.2, 0.8],
78
+ ["example_images/dummy_pdf.png", "How much percent is the order status?", "Greedy", 0.4, 512, 1.2, 0.8],
79
+ ["example_images/art_critic.png", "As an art critic AI assistant, could you describe this painting in details and make a thorough critic?.", "Greedy", 0.4, 512, 1.2, 0.8],
80
+ ["example_images/s2w_example.png", "What is this UI about?", "Greedy", 0.4, 512, 1.2, 0.8]]
81
+
82
+ # Hyper-parameters for generation
83
+ max_new_tokens = gr.Slider(
84
+ minimum=8,
85
+ maximum=1024,
86
+ value=512,
87
+ step=1,
88
+ interactive=True,
89
+ label="Maximum number of new tokens to generate",
90
+ )
91
+ repetition_penalty = gr.Slider(
92
+ minimum=0.01,
93
+ maximum=5.0,
94
+ value=1.2,
95
+ step=0.01,
96
+ interactive=True,
97
+ label="Repetition penalty",
98
+ info="1.0 is equivalent to no penalty",
99
+ )
100
+ temperature = gr.Slider(
101
+ minimum=0.0,
102
+ maximum=5.0,
103
+ value=0.4,
104
+ step=0.1,
105
+ interactive=True,
106
+ label="Sampling temperature",
107
+ info="Higher values will produce more diverse outputs.",
108
+ )
109
+ top_p = gr.Slider(
110
+ minimum=0.01,
111
+ maximum=0.99,
112
+ value=0.8,
113
+ step=0.01,
114
+ interactive=True,
115
+ label="Top P",
116
+ info="Higher values is equivalent to sampling more low-probability tokens.",
117
+ )
118
+ decoding_strategy = gr.Radio(
119
+ [
120
+ "Greedy",
121
+ "Top P Sampling",
122
+ ],
123
+ value="Greedy",
124
+ label="Decoding strategy",
125
+ interactive=True,
126
+ info="Higher values is equivalent to sampling more low-probability tokens.",
127
+ )
128
+ decoding_strategy.change(
129
+ fn=lambda selection: gr.Slider(
130
+ visible=(
131
+ selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"]
132
+ )
133
+ ),
134
+ inputs=decoding_strategy,
135
+ outputs=temperature,
136
+ )
137
+
138
+ decoding_strategy.change(
139
+ fn=lambda selection: gr.Slider(
140
+ visible=(
141
+ selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"]
142
+ )
143
+ ),
144
+ inputs=decoding_strategy,
145
+ outputs=repetition_penalty,
146
+ )
147
+ decoding_strategy.change(
148
+ fn=lambda selection: gr.Slider(visible=(selection in ["Top P Sampling"])),
149
+ inputs=decoding_strategy,
150
+ outputs=top_p,
151
+ )
152
+ gr.Examples(
153
+ examples = examples,
154
+ inputs=[image_input, query_input, decoding_strategy, temperature,
155
+ max_new_tokens, repetition_penalty, top_p],
156
+ outputs=output,
157
+ fn=model_inference
158
+ )
159
+
160
+ submit_btn.click(model_inference, inputs = [image_input, query_input, decoding_strategy, temperature,
161
+ max_new_tokens, repetition_penalty, top_p], outputs=output)
162
+
163
+
164
+ demo.launch(debug=True)
example_images/art_critic.png ADDED
example_images/dummy_pdf.png ADDED
example_images/s2w_example.png ADDED
example_images/travel_tips.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ accelerate
3
+ huggingface_hub
4
+ gradio
5
+ git+https://github.com/andimarafioti/transformers.git@idefics3
6
+ spaces