JoPmt commited on
Commit
d814db7
·
verified ·
1 Parent(s): b1bde60

Upload visual_qa (2).py

Browse files
Files changed (1) hide show
  1. visual_qa (2).py +192 -0
visual_qa (2).py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import base64
3
+ from io import BytesIO
4
+ import json
5
+ import os
6
+ import requests
7
+ from typing import Optional
8
+ from huggingface_hub import InferenceClient
9
+ from transformers import AutoProcessor, Tool
10
+ import uuid
11
+ import mimetypes
12
+ ##from dotenv import load_dotenv
13
+
14
+ ##load_dotenv(override=True)
15
+
16
+ idefics_processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
17
+
18
+ def process_images_and_text(image_path, query, client):
19
+ messages = [
20
+ {
21
+ "role": "user", "content": [
22
+ {"type": "image"},
23
+ {"type": "text", "text": query},
24
+ ]
25
+ },
26
+ ]
27
+
28
+ prompt_with_template = idefics_processor.apply_chat_template(messages, add_generation_prompt=True)
29
+
30
+ # load images from local directory
31
+
32
+ # encode images to strings which can be sent to the endpoint
33
+ def encode_local_image(image_path):
34
+ # load image
35
+ image = Image.open(image_path).convert('RGB')
36
+
37
+ # Convert the image to a base64 string
38
+ buffer = BytesIO()
39
+ image.save(buffer, format="JPEG") # Use the appropriate format (e.g., JPEG, PNG)
40
+ base64_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
41
+
42
+ # add string formatting required by the endpoint
43
+ image_string = f"data:image/jpeg;base64,{base64_image}"
44
+
45
+ return image_string
46
+
47
+
48
+ image_string = encode_local_image(image_path)
49
+ prompt_with_images = prompt_with_template.replace("<image>", "![]({}) ").format(image_string)
50
+
51
+
52
+ payload = {
53
+ "inputs": prompt_with_images,
54
+ "parameters": {
55
+ "return_full_text": False,
56
+ "max_new_tokens": 200,
57
+ }
58
+ }
59
+
60
+ return json.loads(client.post(json=payload).decode())[0]
61
+
62
+ # Function to encode the image
63
+ def encode_image(image_path):
64
+ if image_path.startswith("http"):
65
+ user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
66
+ request_kwargs = {
67
+ "headers": {"User-Agent": user_agent},
68
+ "stream": True,
69
+ }
70
+
71
+ # Send a HTTP request to the URL
72
+ response = requests.get(image_path, **request_kwargs)
73
+ response.raise_for_status()
74
+ content_type = response.headers.get("content-type", "")
75
+
76
+ extension = mimetypes.guess_extension(content_type)
77
+ if extension is None:
78
+ extension = ".download"
79
+
80
+ fname = str(uuid.uuid4()) + extension
81
+ download_path = os.path.abspath(os.path.join("downloads", fname))
82
+
83
+ with open(download_path, "wb") as fh:
84
+ for chunk in response.iter_content(chunk_size=512):
85
+ fh.write(chunk)
86
+
87
+ image_path = download_path
88
+
89
+ with open(image_path, "rb") as image_file:
90
+ return base64.b64encode(image_file.read()).decode('utf-8')
91
+
92
+ headers = {
93
+ "Content-Type": "application/json",
94
+ "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}"
95
+ }
96
+
97
+
98
+ def resize_image(image_path):
99
+ img = Image.open(image_path)
100
+ width, height = img.size
101
+ img = img.resize((int(width / 2), int(height / 2)))
102
+ new_image_path = f"resized_{image_path}"
103
+ img.save(new_image_path)
104
+ return new_image_path
105
+
106
+
107
+ class VisualQATool(Tool):
108
+ name = "visualizer"
109
+ description = "A tool that can answer questions about attached images."
110
+ inputs = {
111
+ "question": {"description": "the question to answer", "type": "text"},
112
+ "image_path": {
113
+ "description": "The path to the image on which to answer the question",
114
+ "type": "text",
115
+ },
116
+ }
117
+ output_type = "text"
118
+
119
+ client = InferenceClient("HuggingFaceM4/idefics2-8b-chatty")
120
+
121
+ def forward(self, image_path: str, question: Optional[str] = None) -> str:
122
+ add_note = False
123
+ if not question:
124
+ add_note = True
125
+ question = "Please write a detailed caption for this image."
126
+ try:
127
+ output = process_images_and_text(image_path, question, self.client)
128
+ except Exception as e:
129
+ print(e)
130
+ if "Payload Too Large" in str(e):
131
+ new_image_path = resize_image(image_path)
132
+ output = process_images_and_text(new_image_path, question, self.client)
133
+
134
+ if add_note:
135
+ output = f"You did not provide a particular question, so here is a detailed caption for the image: {output}"
136
+
137
+ return output
138
+
139
+ class VisualQAGPT4Tool(Tool):
140
+ name = "visualizer"
141
+ description = "A tool that can answer questions about attached images."
142
+ inputs = {
143
+ "question": {"description": "the question to answer", "type": "text"},
144
+ "image_path": {
145
+ "description": "The path to the image on which to answer the question. This should be a local path to downloaded image.",
146
+ "type": "text",
147
+ },
148
+ }
149
+ output_type = "text"
150
+
151
+ def forward(self, image_path: str, question: Optional[str] = None) -> str:
152
+ add_note = False
153
+ if not question:
154
+ add_note = True
155
+ question = "Please write a detailed caption for this image."
156
+ if not isinstance(image_path, str):
157
+ raise Exception("You should provide only one string as argument to this tool!")
158
+
159
+ base64_image = encode_image(image_path)
160
+
161
+ payload = {
162
+ "model": "gpt-4o",
163
+ "messages": [
164
+ {
165
+ "role": "user",
166
+ "content": [
167
+ {
168
+ "type": "text",
169
+ "text": question
170
+ },
171
+ {
172
+ "type": "image_url",
173
+ "image_url": {
174
+ "url": f"data:image/jpeg;base64,{base64_image}"
175
+ }
176
+ }
177
+ ]
178
+ }
179
+ ],
180
+ "max_tokens": 500
181
+ }
182
+ response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
183
+ try:
184
+ output = response.json()['choices'][0]['message']['content']
185
+ except Exception:
186
+ raise Exception(f"Response format unexpected: {response.json()}")
187
+
188
+ if add_note:
189
+ output = f"You did not provide a particular question, so here is a detailed caption for the image: {output}"
190
+
191
+ return output
192
+