chiayewken committed
Commit d65c3c0 · 1 Parent(s): 3a8c647

Support doc loading, detector and retriever

Files changed (1)
  1. app.py +357 -67
app.py CHANGED
@@ -1,15 +1,22 @@
 import base64
 import hashlib
 import io
+import json
 import os
+import tempfile
+from collections import OrderedDict as CollectionsOrderedDict
 from pathlib import Path
 from threading import Thread
-from typing import Iterator, Optional, List, Union
+from typing import Iterator, Optional, List, Union, OrderedDict
 
+import fitz
 import gradio as gr
+import requests
 import spaces
 import torch
 from PIL import Image
+from colpali_engine import ColPali, ColPaliProcessor
+from huggingface_hub import hf_hub_download
 from pydantic import BaseModel
 from qwen_vl_utils import process_vision_info
 from swift.llm import (
@@ -20,6 +27,7 @@ from swift.llm import (
     inference,
     inference_stream,
 )
+from tqdm import tqdm
 from transformers import (
     Qwen2VLForConditionalGeneration,
     PreTrainedTokenizer,
@@ -27,6 +35,8 @@ from transformers import (
     TextIteratorStreamer,
     AutoTokenizer,
 )
+from ultralytics import YOLO
+from ultralytics.engine.results import Results
 
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
@@ -49,6 +59,264 @@ this demo is governed by the original [license](https://huggingface.co/meta-llam
 """
 
 
+class MultimodalSample(BaseModel):
+    question: str
+    answer: str
+    category: str
+    evidence_pages: List[int] = []
+    raw_output: str = ""
+    pred: str = ""
+    source: str = ""
+    annotator: str = ""
+    generator: str = ""
+    retrieved_pages: List[int] = []
+
+
+class MultimodalObject(BaseModel):
+    id: str = ""
+    page: int = 0
+    text: str = ""
+    image_string: str = ""
+    snippet: str = ""
+    score: float = 0.0
+    source: str = ""
+    category: str = ""
+
+    def get_image(self) -> Optional[Image.Image]:
+        if self.image_string:
+            return convert_text_to_image(self.image_string)
+
+    @classmethod
+    def from_image(cls, image: Image.Image, **kwargs):
+        return cls(image_string=convert_image_to_text(image), **kwargs)
+
+
+class ObjectDetector(BaseModel, arbitrary_types_allowed=True):
+    def run(self, image: Image.Image) -> List[MultimodalObject]:
+        raise NotImplementedError()
+
+
+class YoloDetector(ObjectDetector):
+    repo_id: str = "DILHTWD/documentlayoutsegmentation_YOLOv8_ondoclaynet"
+    filename: str = "yolov8x-doclaynet-epoch64-imgsz640-initiallr1e-4-finallr1e-5.pt"
+    local_dir: str = "data/yolo"
+    client: Optional[YOLO] = None
+
+    def load(self):
+        if self.client is None:
+            if not Path(self.local_dir, self.filename).exists():
+                hf_hub_download(
+                    repo_id=self.repo_id,
+                    filename=self.filename,
+                    local_dir=self.local_dir,
+                )
+            self.client = YOLO(Path(self.local_dir, self.filename))
+
+    def save_image(self, image: Image.Image) -> str:
+        text = convert_image_to_text(image)
+        hash_id = hashlib.md5(text.encode()).hexdigest()
+        path = Path(self.local_dir, f"{hash_id}.png")
+        image.save(path)
+        return str(path)
+
+    @staticmethod
+    def extract_subimage(image: Image.Image, box: List[float]) -> Image.Image:
+        return image.crop((round(box[0]), round(box[1]), round(box[2]), round(box[3])))
+
+    def run(self, image: Image.Image) -> List[MultimodalObject]:
+        self.load()
+        path = self.save_image(image)
+        results: List[Results] = self.client(source=[path])
+        assert len(results) == 1
+        objects = []
+
+        for i, label_id in enumerate(results[0].boxes.cls):
+            label = results[0].names[label_id.item()]
+            score = results[0].boxes.conf[i].item()
+            box: List[float] = results[0].boxes.xyxy[i].tolist()
+            subimage = self.extract_subimage(image, box)
+            objects.append(
+                MultimodalObject(
+                    image_string=convert_image_to_text(subimage),
+                    category=label,
+                    score=score,
+                )
+            )
+
+        return objects
+
+
+class MultimodalPage(BaseModel):
+    number: int
+    objects: List[MultimodalObject]
+    text: str
+    image_string: str
+    source: str
+    score: float = 0.0
+
+    def get_tables_and_figures(self) -> List[MultimodalObject]:
+        return [o for o in self.objects if o.category in ["Table", "Picture"]]
+
+    def get_full_image(self) -> Image.Image:
+        return convert_text_to_image(self.image_string)
+
+    @classmethod
+    def from_text(cls, text: str):
+        return MultimodalPage(
+            text=text, number=0, objects=[], image_string="", source=""
+        )
+
+    @classmethod
+    def from_image(cls, image: Image.Image):
+        return MultimodalPage(
+            image_string=convert_image_to_text(image),
+            number=0,
+            objects=[],
+            text="",
+            source="",
+        )
+
+
+class MultimodalDocument(BaseModel):
+    pages: List[MultimodalPage]
+
+    def get_page(self, i: int) -> MultimodalPage:
+        pages = [p for p in self.pages if p.number == i]
+        assert len(pages) == 1
+        return pages[0]
+
+    @classmethod
+    def load_from_pdf(cls, path: str, dpi: int = 150, detector: ObjectDetector = None):
+        # Each page as an image (with optional extracted text)
+        doc = fitz.open(path)
+        pages = []
+
+        for i, page in enumerate(tqdm(doc.pages(), desc=path)):
+            text = page.get_text()
+            zoom = dpi / 72  # 72 is the default DPI
+            matrix = fitz.Matrix(zoom, zoom)
+            pix = page.get_pixmap(matrix=matrix)
+            image = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
+
+            objects = []
+            if detector:
+                objects = detector.run(image)
+                for o in objects:
+                    o.page, o.source = i + 1, path
+
+            pages.append(
+                MultimodalPage(
+                    number=i + 1,
+                    objects=objects,
+                    text=text,
+                    image_string=convert_image_to_text(image),
+                    source=path,
+                )
+            )
+
+        return cls(pages=pages)
+
+    @classmethod
+    def load(cls, path: str):
+        pages = []
+        with open(path) as f:
+            for line in f:
+                pages.append(MultimodalPage(**json.loads(line)))
+        return cls(pages=pages)
+
+    def save(self, path: str):
+        Path(path).parent.mkdir(exist_ok=True, parents=True)
+        with open(path, "w") as f:
+            for o in self.pages:
+                print(o.model_dump_json(), file=f)
+
+    def get_domain(self) -> str:
+        filename = Path(self.pages[0].source).name
+        if filename.startswith("NYSE"):
+            return "Financial<br>Report"
+        elif filename[:4].isdigit() and filename[4] == "." and filename[5].isdigit():
+            return "Academic<br>Paper"
+        else:
+            return "Technical<br>Manuals"
+
+
+class MultimodalRetriever(BaseModel, arbitrary_types_allowed=True):
+    def run(self, query: str, doc: MultimodalDocument) -> MultimodalDocument:
+        raise NotImplementedError
+
+    @staticmethod
+    def get_top_pages(doc: MultimodalDocument, k: int) -> List[int]:
+        # Get top-k in terms of score but maintain the original order
+        doc = doc.copy(deep=True)
+        pages = sorted(doc.pages, key=lambda x: x.score, reverse=True)
+        threshold = pages[:k][-1].score
+        return [p.number for p in doc.pages if p.score >= threshold]
+
+
+class ColpaliRetriever(MultimodalRetriever):
+    path: str = "vidore/colpali-v1.2"
+    model: Optional[ColPali] = None
+    processor: Optional[ColPaliProcessor] = None
+    device: str = "cuda"
+    cache: OrderedDict[str, torch.Tensor] = CollectionsOrderedDict()
+
+    def load(self):
+        if self.model is None:
+            self.model = ColPali.from_pretrained(
+                self.path, torch_dtype=torch.bfloat16, device_map=self.device
+            )
+            self.model = self.model.eval()
+            self.processor = ColPaliProcessor.from_pretrained(self.path)
+
+    def encode_document(self, doc: MultimodalDocument) -> torch.Tensor:
+        hash_id = hashlib.md5(doc.json().encode()).hexdigest()
+        if len(self.cache) > 100:
+            self.cache.popitem(last=False)
+        if hash_id not in self.cache:
+            images = [page.get_full_image() for page in doc.pages]
+            batch_size = 8
+
+            ds: List[torch.Tensor] = []
+            for i in tqdm(range(0, len(images), batch_size), desc="Encoding document"):
+                batch = self.processor.process_images(images[i : i + batch_size])
+                with torch.no_grad():
+                    # noinspection PyTypeChecker
+                    ds.append(self.model(**batch.to(self.device)).cpu())
+
+            lengths = [x.shape[1] for x in ds]
+            if len(set(lengths)) != 1:
+                print("Warning: Inconsistent lengths from colqwen", set(lengths))
+                assert "colqwen" in self.path
+                for i, x in enumerate(ds):
+                    ds[i] = x[:, : min(lengths), :]
+            self.cache[hash_id] = torch.cat(ds)
+        return self.cache[hash_id]
+
+    def run(self, query: str, doc: MultimodalDocument) -> MultimodalDocument:
+        doc = doc.copy(deep=True)
+        self.load()
+        ds = self.encode_document(doc)
+        with torch.no_grad():
+            # noinspection PyTypeChecker
+            qs = self.model(**self.processor.process_queries([query]).to(self.device))
+
+        # noinspection PyTypeChecker
+        scores = self.processor.score_multi_vector(qs.cpu(), ds).squeeze()
+        assert len(scores) == len(doc.pages)
+        for i, page in enumerate(doc.pages):
+            page.score = scores[i].item()
+
+        return doc
+
+
+class DummyRetriever(MultimodalRetriever):
+    def run(self, query: str, doc: MultimodalDocument) -> MultimodalDocument:
+        doc = doc.copy(deep=True)
+        for i, page in enumerate(doc.pages):
+            page.score = i
+        return doc
+
+
 def convert_image_to_text(image: Image) -> str:
     # This is also how OpenAI encodes images: https://platform.openai.com/docs/guides/vision
     with io.BytesIO() as output:
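The hunk above adds the document loader (`MultimodalDocument`), the DocLayNet-based layout detector (`YoloDetector`), and the ColPali page retriever (`ColpaliRetriever`). A minimal usage sketch of how these pieces compose, not part of the commit: it assumes the classes are importable from `app`, that a local `sample.pdf` exists, and that a CUDA device is available for ColPali; the query string is made up.

# Illustrative sketch only -- assumes the classes defined in this commit are
# importable from app.py and that "sample.pdf" exists locally.
from app import ColpaliRetriever, MultimodalDocument, YoloDetector

detector = YoloDetector()
detector.load()  # downloads the DocLayNet YOLOv8 weights on first use

# Render each PDF page to an image and run layout detection per page
doc = MultimodalDocument.load_from_pdf("sample.pdf", detector=detector)

retriever = ColpaliRetriever()
retriever.load()  # loads vidore/colpali-v1.2 (defaults to a CUDA device)

# Score every page against a (hypothetical) query, then keep the top-k
# page numbers while preserving their original order in the document
scored = retriever.run("What is the total revenue?", doc)
top_pages = retriever.get_top_pages(scored, k=5)
print(top_pages)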
@@ -272,6 +540,24 @@ class QwenModel(EvalModel):
         yield "".join(outputs)
 
 
+class DummyModel(EvalModel):
+    engine: str = ""
+
+    def run(self, inputs: List[Union[str, Image.Image]]) -> str:
+        return " ".join(inputs)
+
+    def run_stream(self, inputs: List[Union[str, Image.Image]]) -> Iterator[str]:
+        assert self is not None
+        text = " ".join([x for x in inputs if isinstance(x, str)])
+        num_images = sum(1 for x in inputs if isinstance(x, Image.Image))
+        tokens = f"Hello this is your message: {text}, images: {num_images}".split()
+        for i in range(len(tokens)):
+            yield " ".join(tokens[: i + 1])
+            import time
+
+            time.sleep(0.05)
+
+
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
@@ -279,86 +565,90 @@ if not torch.cuda.is_available():
 if torch.cuda.is_available():
     model = QwenModel()
     model.load()
+    detect_model = YoloDetector()
+    detect_model.load()
+    retriever = ColpaliRetriever()
+    retriever.load()
+else:
+    model = DummyModel()
+    detect_model = None
+    retriever = DummyRetriever()
+
+
+def get_file_path(file: gr.File = None, url: str = None) -> Optional[str]:
+    if file is not None:
+        # noinspection PyUnresolvedReferences
+        return file.name
+
+    if url is not None:
+        response = requests.get(url)
+        response.raise_for_status()
+        save_path = Path(tempfile.mkdtemp(), url.split("/")[-1])
+
+        if "application/pdf" in response.headers.get("Content-Type", ""):
+            # Open the file in binary write mode
+            with open(save_path, "wb") as file:
+                file.write(response.content)
+            return str(save_path)
 
 
 @spaces.GPU
 def generate(
-    message: str,
-    chat_history: list[dict],
-    system_prompt: str = "",
-    max_new_tokens: int = 1024,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
+    query: str, file: gr.File = None, url: str = None, top_k=5
 ) -> Iterator[str]:
-    for text in model.run_stream([message]):
-        yield text
+    sample = MultimodalSample(question=query, answer="", category="")
+    path = get_file_path(file, url)
+
+    if path is not None:
+        doc = MultimodalDocument.load_from_pdf(path, detector=detect_model)
+        output = retriever.run(sample.question, doc)
+        sorted_pages = sorted(output.pages, key=lambda p: p.score, reverse=True)
+        sample.retrieved_pages = sorted([p.number for p in sorted_pages][:top_k])
+
+        context = []
+        for p in doc.pages:
+            if p.number in sample.retrieved_pages:
+                if p.text:
+                    context.append(p.text)
+                context.extend(o.get_image() for o in p.get_tables_and_figures())
+
+        inputs = [
+            "Context:",
+            *context,
+            f"Answer the following question in 200 words or less: {sample.question}",
+        ]
+    else:
+        inputs = [
+            "Context:",
+            f"Answer the following question in 200 words or less: {sample.question}",
+        ]
 
+    for text in model.run_stream(inputs):
+        yield text
 
-chat_interface = gr.ChatInterface(
-    fn=generate,
-    additional_inputs=[
-        gr.Textbox(label="System prompt", lines=6),
-        gr.Slider(
-            label="Max new tokens",
-            minimum=1,
-            maximum=MAX_MAX_NEW_TOKENS,
-            step=1,
-            value=DEFAULT_MAX_NEW_TOKENS,
-        ),
-        gr.Slider(
-            label="Temperature",
-            minimum=0.1,
-            maximum=4.0,
-            step=0.1,
-            value=0.6,
-        ),
-        gr.Slider(
-            label="Top-p (nucleus sampling)",
-            minimum=0.05,
-            maximum=1.0,
-            step=0.05,
-            value=0.9,
-        ),
-        gr.Slider(
-            label="Top-k",
-            minimum=1,
-            maximum=1000,
-            step=1,
-            value=50,
-        ),
-        gr.Slider(
-            label="Repetition penalty",
-            minimum=1.0,
-            maximum=2.0,
-            step=0.05,
-            value=1.2,
-        ),
-    ],
-    stop_btn=None,
-    examples=[
-        [
-            "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?"
-        ],
-        [
-            "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"
-        ],
-        [
-            "Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. How much did she earn?"
-        ],
-    ],
-    cache_examples=False,
-    type="messages",
-)
 
 with gr.Blocks(css_paths="style.css", fill_height=True) as demo:
     gr.Markdown(DESCRIPTION)
     gr.DuplicateButton(
         value="Duplicate Space for private use", elem_id="duplicate-button"
     )
-    chat_interface.render()
+
+    with gr.Row():
+        pdf_upload = gr.File(label="Upload PDF (optional)", file_types=[".pdf"])
+        with gr.Column():
+            url_input = gr.Textbox(label="Enter PDF URL (optional)")
+            text_input = gr.Textbox(label="Enter your message", lines=3)
+
+    submit_button = gr.Button("Submit")
+    result = gr.Textbox(label="Response", lines=10)
+
+    submit_button.click(
+        generate, inputs=[text_input, pdf_upload, url_input], outputs=result
+    )
+
     gr.Markdown(LICENSE)
 
+demo.launch()
+
 if __name__ == "__main__":
     demo.queue(max_size=20).launch()
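The final hunk swaps the plain `gr.ChatInterface` for a PDF-aware flow: `generate` resolves an upload or URL via `get_file_path`, loads the pages with `MultimodalDocument.load_from_pdf`, keeps the `top_k` highest-scoring pages from the retriever, and streams an answer over their text plus any detected tables and figures. A rough sketch of exercising the new `generate` outside the Gradio UI follows; the URL and question are hypothetical, and the download only succeeds if the server returns a `Content-Type` of `application/pdf` (otherwise the question is answered without document context).

# Illustrative sketch only -- not part of the commit. generate() yields
# progressively longer partial answers, so the last yield is the full answer.
from app import generate

answer = ""
for partial in generate(
    "Summarise the key findings of this report.",  # hypothetical question
    file=None,                                     # no Gradio upload
    url="https://example.com/report.pdf",          # hypothetical PDF URL
    top_k=3,                                       # keep the 3 best-scoring pages
):
    answer = partial
print(answer)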
 