davanstrien's picture
davanstrien HF staff
link
b42e8c6
raw
history blame
6.27 kB
import os
import subprocess  # 🥲

# flash-attn is installed at start-up, before the model code below is imported.
# FLASH_ATTENTION_SKIP_CUDA_BUILD is handed to pip through the environment
# (presumably so the flash-attn build skips compiling CUDA kernels — confirm
# against the flash-attn build documentation).
#
# The flag is merged into a copy of the current environment on purpose: a bare
# dict passed as `env` REPLACES the child's environment, which drops PATH,
# HOME, proxy settings and anything else pip relies on.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
import spaces
import gradio as gr
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
import os
import json
from pydantic import BaseModel
from typing import Tuple
# Request the hf_transfer download backend from huggingface_hub.
# NOTE(review): this is set after `transformers` has already been imported; if
# huggingface_hub reads the flag at import time it has no effect here — confirm.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# One checkpoint id shared by the model weights and the matching processor.
MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"

# Loaded once at import time so every request reuses the same weights.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    device_map="auto",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
processor = AutoProcessor.from_pretrained(MODEL_ID)
class GeneralRetrievalQuery(BaseModel):
    """Schema of the JSON object that the "general" retrieval prompt requests.

    The field names mirror the keys listed in the prompt's JSON template: three
    query kinds, each paired with a short explanation of why it would retrieve
    the document.
    """

    # Query covering the main subject of the document.
    broad_topical_query: str
    # Why the broad topical query would retrieve this document.
    broad_topical_explanation: str
    # Query focused on a particular fact, figure, or point in the document.
    specific_detail_query: str
    # Why the specific detail query would retrieve this document.
    specific_detail_explanation: str
    # Query related to a chart, graph, image, or other visual component.
    visual_element_query: str
    # Why the visual element query would retrieve this document.
    visual_element_explanation: str
def get_retrieval_prompt(prompt_name: str) -> Tuple[str, GeneralRetrievalQuery]:
    """Look up a retrieval prompt and its response schema by name.

    Args:
        prompt_name: Name of the prompt to fetch. Only ``"general"`` exists.

    Returns:
        A ``(prompt_text, schema)`` pair. ``schema`` is the
        ``GeneralRetrievalQuery`` class itself, not an instance of it.

    Raises:
        ValueError: If ``prompt_name`` is anything other than ``"general"``.
    """
    if prompt_name == "general":
        # Continuation lines stay at column 0: they are part of the string.
        general_prompt = """You are an AI assistant specialized in document retrieval tasks. Given an image of a document page, your task is to generate retrieval queries that someone might use to find this document in a large corpus.
Please generate 3 different types of retrieval queries:
1. A broad topical query: This should cover the main subject of the document.
2. A specific detail query: This should focus on a particular fact, figure, or point made in the document.
3. A visual element query: This should reference a chart, graph, image, or other visual component in the document, if present. Don't just reference the name of the visual element but generate a query which this illustration may help answer or be related to.
Important guidelines:
- Ensure the queries are relevant for retrieval tasks, not just describing the page content.
- Frame the queries as if someone is searching for this document, not asking questions about its content.
- Make the queries diverse and representative of different search strategies.
For each query, also provide a brief explanation of why this query would be effective in retrieving this document.
Format your response as a JSON object with the following structure:
{
"broad_topical_query": "Your query here",
"broad_topical_explanation": "Brief explanation",
"specific_detail_query": "Your query here",
"specific_detail_explanation": "Brief explanation",
"visual_element_query": "Your query here",
"visual_element_explanation": "Brief explanation"
}
If there are no relevant visual elements, replace the third query with another specific detail query.
Here is the document image to analyze:
<image>
Generate the queries based on this image and provide the response in the specified JSON format."""
        return general_prompt, GeneralRetrievalQuery

    raise ValueError("Only 'general' prompt is available in this version")
# The active prompt and its schema are resolved through the lookup function so
# that more prompting options can be added later without changing this line.
# `prompt` is read by `_prep_data_for_input`; `pydantic_model` is not
# referenced anywhere else in the visible code.
prompt, pydantic_model = get_retrieval_prompt("general")
def _prep_data_for_input(image):
    """Build the processor inputs for a single document image.

    The image and the module-level ``prompt`` are combined into one user turn.
    That turn is rendered to text through the processor's chat template (with
    the generation prompt appended) and encoded together with the vision
    inputs extracted from the same turn.

    Args:
        image: The document page image to analyse.

    Returns:
        The processor output for a batch of one, padded and returned as
        PyTorch tensors (``return_tensors="pt"``).
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }
    conversation = [user_turn]

    chat_text = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    images, videos = process_vision_info(conversation)

    model_inputs = processor(
        text=[chat_text],
        images=images,
        videos=videos,
        padding=True,
        return_tensors="pt",
    )
    return model_inputs
def _parse_json_reply(reply: str):
    """Parse the model's reply as JSON, tolerating text around the JSON object.

    The reply is parsed verbatim first, so anything that is already valid JSON
    is returned unchanged. Only if that fails is the span from the first ``{``
    to the last ``}`` parsed instead, which recovers replies wrapped in a
    Markdown code fence or surrounded by extra prose.

    Args:
        reply: Decoded text produced by the model.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If neither attempt yields valid JSON.
    """
    try:
        return json.loads(reply)
    except json.JSONDecodeError:
        start, end = reply.find("{"), reply.rfind("}")
        if start == -1 or end <= start:
            # No brace-delimited span to fall back on: surface the original error.
            raise
        return json.loads(reply[start : end + 1])


@spaces.GPU
def generate_response(image):
    """Generate retrieval queries for one document image.

    Args:
        image: The document page image supplied by the Gradio input.

    Returns:
        The JSON value parsed from the model's reply, or an empty dict (after
        raising a Gradio warning) when the reply cannot be parsed as JSON.
    """
    inputs = _prep_data_for_input(image).to("cuda")
    generated_ids = model.generate(**inputs, max_new_tokens=200)

    # Drop the leading len(prompt) ids from each sequence so that only the
    # tokens following the prompt are decoded.
    completion_ids = [
        output_ids[len(prompt_ids) :]
        for prompt_ids, output_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        completion_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    try:
        return _parse_json_reply(output_text[0])
    except json.JSONDecodeError:
        gr.Warning("Failed to parse JSON from output")
        return {}
# Passed to gr.Interface as the `title` argument.
title = "ColPali fine-tuning Query Generator"
# Markdown text passed to gr.Interface as the `description` argument.
# NOTE(review): several links below use the host huggingface.co — confirm
# that this host is intended rather than huggingface.co.
# NOTE(review): the user-facing text contains "an VLM" and "improving to
# prompts"; left untouched here because it is runtime text.
description = """[ColPali](https://huggingface.co/papers/2407.01449) is a very exciting new approach to multimodal document retrieval which aims to replace existing document retrievers which often rely on an OCR step with an end-to-end multimodal approach.
To train or fine-tune a ColPali model, we need a dataset of image-text pairs which represent the document images and the relevant text queries which those documents should match.
To make the ColPali models work even better we might want a dataset of query/image document pairs related to our domain or task.
One way in which we might go about generating such a dataset is to use an VLM to generate synthetic queries for us.
This space uses the [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) VLM model to generate queries for a document, based on an input document image.
**Note** there is a lot of scope for improving to prompts and the quality of the generated queries! If you have any suggestions for improvements please [open a Discussion](https://huggingface.co/spaces/davanstrien/ColPali-Query-Generator/discussions/new)!
This [blog post](https://danielvanstrien.xyz/posts/post-with-code/colpali/2024-09-23-generate_colpali_dataset.html) gives an overview of how you can use this kind of approach to generate a full dataset for fine-tuning ColPali models.
If you want to convert a PDF(s) to a dataset of page images you can try out the [ PDFs to Page Images Converter](https://huggingface.co/spaces/Dataset-Creation-Tools/pdf-to-page-images-dataset) Space.
"""
# Sample page images passed to the interface as its `examples`.
examples = [
    "examples/Approche_no_13_1977.pdf_page_22.jpg",
    "examples/SRCCL_Technical-Summary.pdf_page_7.jpg",
]

# One PIL image in, one JSON payload out.
page_image_input = gr.Image(type="pil")
queries_output = gr.Json()

demo = gr.Interface(
    fn=generate_response,
    inputs=page_image_input,
    outputs=queries_output,
    examples=examples,
    title=title,
    description=description,
)
demo.launch()