chiayewken committed
Commit · d65c3c0 · 1 Parent(s): 3a8c647
Support doc loading, detector and retriever

app.py CHANGED
@@ -1,15 +1,22 @@
 import base64
 import hashlib
 import io
+import json
 import os
+import tempfile
+from collections import OrderedDict as CollectionsOrderedDict
 from pathlib import Path
 from threading import Thread
-from typing import Iterator, Optional, List, Union
+from typing import Iterator, Optional, List, Union, OrderedDict
 
+import fitz
 import gradio as gr
+import requests
 import spaces
 import torch
 from PIL import Image
+from colpali_engine import ColPali, ColPaliProcessor
+from huggingface_hub import hf_hub_download
 from pydantic import BaseModel
 from qwen_vl_utils import process_vision_info
 from swift.llm import (
@@ -20,6 +27,7 @@ from swift.llm import (
     inference,
     inference_stream,
 )
+from tqdm import tqdm
 from transformers import (
     Qwen2VLForConditionalGeneration,
     PreTrainedTokenizer,
@@ -27,6 +35,8 @@ from transformers import (
     TextIteratorStreamer,
     AutoTokenizer,
 )
+from ultralytics import YOLO
+from ultralytics.engine.results import Results
 
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
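
The aliased `OrderedDict` import is deliberate: `typing.OrderedDict` serves as the class annotation while `collections.OrderedDict` provides the runtime instance (pydantic needs both). The same pattern in isolation, as a standalone sketch that is not part of the commit:

    from collections import OrderedDict as CollectionsOrderedDict
    from typing import OrderedDict

    # Annotation uses the typing alias; the default value is the real class.
    cache: OrderedDict[str, int] = CollectionsOrderedDict()
    cache["a"] = 1
    cache.popitem(last=False)  # evicts the oldest entry first, as ColpaliRetriever.cache does below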
@@ -49,6 +59,264 @@ this demo is governed by the original [license](https://huggingface.co/meta-llam
 """
 
 
+class MultimodalSample(BaseModel):
+    question: str
+    answer: str
+    category: str
+    evidence_pages: List[int] = []
+    raw_output: str = ""
+    pred: str = ""
+    source: str = ""
+    annotator: str = ""
+    generator: str = ""
+    retrieved_pages: List[int] = []
+
+
+class MultimodalObject(BaseModel):
+    id: str = ""
+    page: int = 0
+    text: str = ""
+    image_string: str = ""
+    snippet: str = ""
+    score: float = 0.0
+    source: str = ""
+    category: str = ""
+
+    def get_image(self) -> Optional[Image.Image]:
+        if self.image_string:
+            return convert_text_to_image(self.image_string)
+
+    @classmethod
+    def from_image(cls, image: Image.Image, **kwargs):
+        return cls(image_string=convert_image_to_text(image), **kwargs)
+
+
+class ObjectDetector(BaseModel, arbitrary_types_allowed=True):
+    def run(self, image: Image.Image) -> List[MultimodalObject]:
+        raise NotImplementedError()
+
+
+class YoloDetector(ObjectDetector):
+    repo_id: str = "DILHTWD/documentlayoutsegmentation_YOLOv8_ondoclaynet"
+    filename: str = "yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"
+    local_dir: str = "data/yolo"
+    client: Optional[YOLO] = None
+
+    def load(self):
+        if self.client is None:
+            if not Path(self.local_dir, self.filename).exists():
+                hf_hub_download(
+                    repo_id=self.repo_id,
+                    filename=self.filename,
+                    local_dir=self.local_dir,
+                )
+            self.client = YOLO(Path(self.local_dir, self.filename))
+
+    def save_image(self, image: Image.Image) -> str:
+        text = convert_image_to_text(image)
+        hash_id = hashlib.md5(text.encode()).hexdigest()
+        path = Path(self.local_dir, f"{hash_id}.png")
+        image.save(path)
+        return str(path)
+
+    @staticmethod
+    def extract_subimage(image: Image.Image, box: List[float]) -> Image.Image:
+        return image.crop((round(box[0]), round(box[1]), round(box[2]), round(box[3])))
+
+    def run(self, image: Image.Image) -> List[MultimodalObject]:
+        self.load()
+        path = self.save_image(image)
+        results: List[Results] = self.client(source=[path])
+        assert len(results) == 1
+        objects = []
+
+        for i, label_id in enumerate(results[0].boxes.cls):
+            label = results[0].names[label_id.item()]
+            score = results[0].boxes.conf[i].item()
+            box: List[float] = results[0].boxes.xyxy[i].tolist()
+            subimage = self.extract_subimage(image, box)
+            objects.append(
+                MultimodalObject(
+                    image_string=convert_image_to_text(subimage),
+                    category=label,
+                    score=score,
+                )
+            )
+
+        return objects
+
+
+class MultimodalPage(BaseModel):
+    number: int
+    objects: List[MultimodalObject]
+    text: str
+    image_string: str
+    source: str
+    score: float = 0.0
+
+    def get_tables_and_figures(self) -> List[MultimodalObject]:
+        return [o for o in self.objects if o.category in ["Table", "Picture"]]
+
+    def get_full_image(self) -> Image.Image:
+        return convert_text_to_image(self.image_string)
+
+    @classmethod
+    def from_text(cls, text: str):
+        return MultimodalPage(
+            text=text, number=0, objects=[], image_string="", source=""
+        )
+
+    @classmethod
+    def from_image(cls, image: Image.Image):
+        return MultimodalPage(
+            image_string=convert_image_to_text(image),
+            number=0,
+            objects=[],
+            text="",
+            source="",
+        )
+
+
+class MultimodalDocument(BaseModel):
+    pages: List[MultimodalPage]
+
+    def get_page(self, i: int) -> MultimodalPage:
+        pages = [p for p in self.pages if p.number == i]
+        assert len(pages) == 1
+        return pages[0]
+
+    @classmethod
+    def load_from_pdf(cls, path: str, dpi: int = 150, detector: ObjectDetector = None):
+        # Each page as an image (with optional extracted text)
+        doc = fitz.open(path)
+        pages = []
+
+        for i, page in enumerate(tqdm(doc.pages(), desc=path)):
+            text = page.get_text()
+            zoom = dpi / 72  # 72 is the default DPI
+            matrix = fitz.Matrix(zoom, zoom)
+            pix = page.get_pixmap(matrix=matrix)
+            image = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
+
+            objects = []
+            if detector:
+                objects = detector.run(image)
+                for o in objects:
+                    o.page, o.source = i + 1, path
+
+            pages.append(
+                MultimodalPage(
+                    number=i + 1,
+                    objects=objects,
+                    text=text,
+                    image_string=convert_image_to_text(image),
+                    source=path,
+                )
+            )
+
+        return cls(pages=pages)
+
+    @classmethod
+    def load(cls, path: str):
+        pages = []
+        with open(path) as f:
+            for line in f:
+                pages.append(MultimodalPage(**json.loads(line)))
+        return cls(pages=pages)
+
+    def save(self, path: str):
+        Path(path).parent.mkdir(exist_ok=True, parents=True)
+        with open(path, "w") as f:
+            for o in self.pages:
+                print(o.model_dump_json(), file=f)
+
+    def get_domain(self) -> str:
+        filename = Path(self.pages[0].source).name
+        if filename.startswith("NYSE"):
+            return "Financial<br>Report"
+        elif filename[:4].isdigit() and filename[4] == "." and filename[5].isdigit():
+            return "Academic<br>Paper"
+        else:
+            return "Technical<br>Manuals"
+
+
+class MultimodalRetriever(BaseModel, arbitrary_types_allowed=True):
+    def run(self, query: str, doc: MultimodalDocument) -> MultimodalDocument:
+        raise NotImplementedError
+
+    @staticmethod
+    def get_top_pages(doc: MultimodalDocument, k: int) -> List[int]:
+        # Get top-k in terms of score but maintain the original order
+        doc = doc.copy(deep=True)
+        pages = sorted(doc.pages, key=lambda x: x.score, reverse=True)
+        threshold = pages[:k][-1].score
+        return [p.number for p in doc.pages if p.score >= threshold]
+
+
+class ColpaliRetriever(MultimodalRetriever):
+    path: str = "vidore/colpali-v1.2"
+    model: Optional[ColPali] = None
+    processor: Optional[ColPaliProcessor] = None
+    device: str = "cuda"
+    cache: OrderedDict[str, torch.Tensor] = CollectionsOrderedDict()
+
+    def load(self):
+        if self.model is None:
+            self.model = ColPali.from_pretrained(
+                self.path, torch_dtype=torch.bfloat16, device_map=self.device
+            )
+            self.model = self.model.eval()
+            self.processor = ColPaliProcessor.from_pretrained(self.path)
+
+    def encode_document(self, doc: MultimodalDocument) -> torch.Tensor:
+        hash_id = hashlib.md5(doc.json().encode()).hexdigest()
+        if len(self.cache) > 100:
+            self.cache.popitem(last=False)
+        if hash_id not in self.cache:
+            images = [page.get_full_image() for page in doc.pages]
+            batch_size = 8
+
+            ds: List[torch.Tensor] = []
+            for i in tqdm(range(0, len(images), batch_size), desc="Encoding document"):
+                batch = self.processor.process_images(images[i : i + batch_size])
+                with torch.no_grad():
+                    # noinspection PyTypeChecker
+                    ds.append(self.model(**batch.to(self.device)).cpu())
+
+            lengths = [x.shape[1] for x in ds]
+            if len(set(lengths)) != 1:
+                print("Warning: Inconsistent lengths from colqwen", set(lengths))
+                assert "colqwen" in self.path
+                for i, x in enumerate(ds):
+                    ds[i] = x[:, : min(lengths), :]
+            self.cache[hash_id] = torch.cat(ds)
+        return self.cache[hash_id]
+
+    def run(self, query: str, doc: MultimodalDocument) -> MultimodalDocument:
+        doc = doc.copy(deep=True)
+        self.load()
+        ds = self.encode_document(doc)
+        with torch.no_grad():
+            # noinspection PyTypeChecker
+            qs = self.model(**self.processor.process_queries([query]).to(self.device))
+
+        # noinspection PyTypeChecker
+        scores = self.processor.score_multi_vector(qs.cpu(), ds).squeeze()
+        assert len(scores) == len(doc.pages)
+        for i, page in enumerate(doc.pages):
+            page.score = scores[i].item()
+
+        return doc
+
+
+class DummyRetriever(MultimodalRetriever):
+    def run(self, query: str, doc: MultimodalDocument) -> MultimodalDocument:
+        doc = doc.copy(deep=True)
+        for i, page in enumerate(doc.pages):
+            page.score = i
+        return doc
+
+
 def convert_image_to_text(image: Image) -> str:
     # This is also how OpenAI encodes images: https://platform.openai.com/docs/guides/vision
     with io.BytesIO() as output:
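
Taken together, the classes added above form a small retrieval pipeline: `MultimodalDocument.load_from_pdf` rasterizes each PDF page with PyMuPDF, `YoloDetector` crops layout regions such as tables and figures using a DocLayNet-trained YOLOv8, and `ColpaliRetriever` scores every page image against a query. A minimal standalone sketch of that flow — the PDF path, query, and output path are placeholders, and `ColpaliRetriever` assumes a CUDA device by default:

    # Sketch: index one PDF and rank its pages for a query.
    detector = YoloDetector()  # weights auto-download from the Hub on first run
    doc = MultimodalDocument.load_from_pdf("sample.pdf", detector=detector)

    retriever = ColpaliRetriever()  # device="cuda" by default
    scored = retriever.run("What was the total revenue?", doc)

    # Top-k page numbers, keeping the document's original page order.
    print(MultimodalRetriever.get_top_pages(scored, k=5))
    doc.save("data/sample_index.jsonl")  # one MultimodalPage JSON per line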
@@ -272,6 +540,24 @@ class QwenModel(EvalModel):
         yield "".join(outputs)
 
 
+class DummyModel(EvalModel):
+    engine: str = ""
+
+    def run(self, inputs: List[Union[str, Image.Image]]) -> str:
+        return " ".join(inputs)
+
+    def run_stream(self, inputs: List[Union[str, Image.Image]]) -> Iterator[str]:
+        assert self is not None
+        text = " ".join([x for x in inputs if isinstance(x, str)])
+        num_images = sum(1 for x in inputs if isinstance(x, Image.Image))
+        tokens = f"Hello this is your message: {text}, images: {num_images}".split()
+        for i in range(len(tokens)):
+            yield " ".join(tokens[: i + 1])
+            import time
+
+            time.sleep(0.05)
+
+
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
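
`DummyModel` is the CPU fallback wired in below: `run_stream` counts any image inputs, then streams back growing prefixes of a canned echo message, throttled by a short sleep, so the UI stays testable without a GPU. What it yields for a mixed text/image input, as a sketch:

    # Sketch: what the CPU fallback streams.
    model = DummyModel()
    inputs = ["What is shown here?", Image.new("RGB", (64, 64))]
    for partial in model.run_stream(inputs):
        print(partial)  # cumulative prefixes of the echo message
    # Final yield: "Hello this is your message: What is shown here?, images: 1"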
@@ -279,86 +565,90 @@ if not torch.cuda.is_available():
 if torch.cuda.is_available():
     model = QwenModel()
     model.load()
+    detect_model = YoloDetector()
+    detect_model.load()
+    retriever = ColpaliRetriever()
+    retriever.load()
+else:
+    model = DummyModel()
+    detect_model = None
+    retriever = DummyRetriever()
+
+
+def get_file_path(file: gr.File = None, url: str = None) -> Optional[str]:
+    if file is not None:
+        # noinspection PyUnresolvedReferences
+        return file.name
+
+    if url is not None:
+        response = requests.get(url)
+        response.raise_for_status()
+        save_path = Path(tempfile.mkdtemp(), url.split("/")[-1])
+
+        if "application/pdf" in response.headers.get("Content-Type", ""):
+            # Open the file in binary write mode
+            with open(save_path, "wb") as file:
+                file.write(response.content)
+            return str(save_path)
 
 
 @spaces.GPU
 def generate(
-    message: str,
-    chat_history: list[dict],
-    system_prompt: str = "",
-    max_new_tokens: int = 1024,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
+    query: str, file: gr.File = None, url: str = None, top_k=5
 ) -> Iterator[str]:
-
-
+    sample = MultimodalSample(question=query, answer="", category="")
+    path = get_file_path(file, url)
+
+    if path is not None:
+        doc = MultimodalDocument.load_from_pdf(path, detector=detect_model)
+        output = retriever.run(sample.question, doc)
+        sorted_pages = sorted(output.pages, key=lambda p: p.score, reverse=True)
+        sample.retrieved_pages = sorted([p.number for p in sorted_pages][:top_k])
+
+        context = []
+        for p in doc.pages:
+            if p.number in sample.retrieved_pages:
+                if p.text:
+                    context.append(p.text)
+                context.extend(o.get_image() for o in p.get_tables_and_figures())
+
+        inputs = [
+            "Context:",
+            *context,
+            f"Answer the following question in 200 words or less: {sample.question}",
+        ]
+    else:
+        inputs = [
+            "Context:",
+            f"Answer the following question in 200 words or less: {sample.question}",
+        ]
 
+    for text in model.run_stream(inputs):
+        yield text
 
-chat_interface = gr.ChatInterface(
-    fn=generate,
-    additional_inputs=[
-        gr.Textbox(label="System prompt", lines=6),
-        gr.Slider(
-            label="Max new tokens",
-            minimum=1,
-            maximum=MAX_MAX_NEW_TOKENS,
-            step=1,
-            value=DEFAULT_MAX_NEW_TOKENS,
-        ),
-        gr.Slider(
-            label="Temperature",
-            minimum=0.1,
-            maximum=4.0,
-            step=0.1,
-            value=0.6,
-        ),
-        gr.Slider(
-            label="Top-p (nucleus sampling)",
-            minimum=0.05,
-            maximum=1.0,
-            step=0.05,
-            value=0.9,
-        ),
-        gr.Slider(
-            label="Top-k",
-            minimum=1,
-            maximum=1000,
-            step=1,
-            value=50,
-        ),
-        gr.Slider(
-            label="Repetition penalty",
-            minimum=1.0,
-            maximum=2.0,
-            step=0.05,
-            value=1.2,
-        ),
-    ],
-    stop_btn=None,
-    examples=[
-        [
-            "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?"
-        ],
-        [
-            "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
-        ],
-        [
-            "Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?"
-        ],
-    ],
-    cache_examples=False,
-    type="messages",
-)
 
 with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
     gr.Markdown(DESCRIPTION)
     gr.DuplicateButton(
         value="Duplicate Space for private use", elem_id="duplicate-button"
     )
-    chat_interface.render()
+
+    with gr.Row():
+        pdf_upload = gr.File(label="Upload PDF (optional)", file_types=[".pdf"])
+        with gr.Column():
+            url_input = gr.Textbox(label="Enter PDF URL (optional)")
+            text_input = gr.Textbox(label="Enter your message", lines=3)
+
+    submit_button = gr.Button("Submit")
+    result = gr.Textbox(label="Response", lines=10)
+
+    submit_button.click(
+        generate, inputs=[text_input, pdf_upload, url_input], outputs=result
+    )
+
     gr.Markdown(LICENSE)
 
+demo.launch()
+
 if __name__ == "__main__":
     demo.queue(max_size=20).launch()
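
End to end, a submitted question now flows through `get_file_path` → `MultimodalDocument.load_from_pdf` → `retriever.run`, and the top-k pages' text plus their detected tables and figures become the interleaved context passed to the model's `run_stream`. Calling the new endpoint outside Gradio, as a sketch — the URL is a placeholder and must serve `Content-Type: application/pdf`, and the full pipeline assumes a GPU:

    # Sketch: invoke the new generate() directly.
    answer = ""
    for partial in generate(
        "Summarize the key findings of this report.",
        file=None,
        url="https://example.com/annual_report.pdf",
        top_k=5,
    ):
        answer = partial  # each yield is the cumulative response so far
    print(answer)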
|