---
license: mit
---

## MIRAGE

**Model Type:** MIRAGE is an open-source visual-RAG model that can take more than 10,000 images as input. It pairs a query-aware retriever with a large multimodal model (LMM), so the LMM reasons only over the images most relevant to the question.

**Key Features:**

- **Compressor:** Compresses each image's tokens by 18x, which keeps the input manageable when the model handles large image collections.
- **Query-Aware Retriever:** Filters out images that are irrelevant to the query, so the LMM spends its compute only on images that help with the task.
- **Multi-Image LMM:** Trained on a tailored pretraining and instruction-tuning dataset designed to improve performance across a range of multimodal tasks.
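
To make the Compressor and Query-Aware Retriever concrete, here is a back-of-the-envelope sketch of how the two shrink what the LMM has to read. It assumes, hypothetically, 576 visual tokens per image before compression (typical of LLaVA-style CLIP encoders, but not stated in this card). The `select_images` filter and its `threshold` are illustrative stand-ins, not MIRAGE's actual retriever.

```python
# Back-of-the-envelope sketch with hypothetical numbers (not MIRAGE's code):
# per-image compression shrinks every image, retrieval drops whole images.

TOKENS_PER_IMAGE = 576   # assumption: LLaVA-style CLIP encoder output, not stated in this card
COMPRESSION_RATIO = 18   # from the card: image tokens are compressed 18x per image


def compressed_tokens(tokens_per_image: int = TOKENS_PER_IMAGE,
                      ratio: int = COMPRESSION_RATIO) -> int:
    """Visual tokens left for one image after compression."""
    return tokens_per_image // ratio


def select_images(scores: list[float], max_images: int, threshold: float = 0.5) -> list[int]:
    """Hypothetical query-aware filter: return at most `max_images` indices whose
    relevance score is at least `threshold`, highest score first."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [i for i in ranked if scores[i] >= threshold][:max_images]


def lmm_visual_tokens(num_selected: int) -> int:
    """Visual tokens the LMM actually receives for the selected images."""
    return num_selected * compressed_tokens()


if __name__ == "__main__":
    n = 10_000
    print("uncompressed:", n * TOKENS_PER_IMAGE)   # uncompressed: 5760000
    print("compressed:", n * compressed_tokens())  # compressed: 320000
    picked = select_images([0.1, 0.9, 0.7, 0.2], max_images=1)
    print("selected:", picked, "->", lmm_visual_tokens(len(picked)), "tokens")  # selected: [1] -> 32 tokens
```

Under these assumptions, compression alone takes 10,000 images from about 5.76M to 320K visual tokens, and retrieval then limits the LMM's input to the few images that pass the filter.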

**Performance:**

- MIRAGE sets a new state of the art among open-source models on the [Visual Haystacks (VHs) benchmark](https://huggingface.co/datasets/tsunghanwu/visual_haystacks).
- It delivers strong results on a range of single- and multi-image question-answering benchmarks, including RETVQA, MMBench, MMVet, and VQAv2.

**Usage:**

To get started with MIRAGE, follow the installation guide in our GitHub repository: [Installation Guide](https://github.com/visual-haystacks/mirage)

**Additional Resources:**

For detailed information and updates, visit our project page: [Visual Haystacks Project](https://visual-haystacks.github.io/)

**Support:**

For questions or comments about the model, please open an issue on our GitHub page: [GitHub Issues](https://github.com/visual-haystacks/mirage/issues)

**Intended Use:**

MIRAGE is primarily intended for research into large multimodal models (LMMs), long-context modeling, and retrieval-augmented generation (RAG).

### Example Usage Code

```python
import os

import torch
from PIL import Image

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from llava.utils import disable_torch_init


@torch.inference_mode()
def run(model_path, image_paths, prompt, num_retrievals=1):
    '''
    Executes MIRAGE with specified inputs to generate descriptive text based on the provided images.

    Args:
        model_path (str): Path to the MIRAGE model, e.g., 'tsunghanwu/mirage-llama3.1-8.3B'
        image_paths (list): List of paths to image files, e.g., images in 'assets/example'
        prompt (str): Text prompt for image description, e.g., 'Here are a set of random images in my photo album.
            If you can find a cat, tell me what's the cat doing and what's its color.'
        num_retrievals (int): Maximum number of images to retrieve and pass to the LMM

    Returns:
        output_text (str): Descriptive text generated by the LMM
        output_ret (list): List of images retrieved by the model
    '''
    # Load the model and prepare the environment
    disable_torch_init()
    model_path = os.path.expanduser(model_path)  # allow local paths such as '~/checkpoints/...'
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, _ = load_pretrained_model(
        model_path=model_path, model_base=None, model_name=model_name, device="cuda")
    model.eval_mode = True

    # Process the images: load each one as RGB and convert it to a float16 tensor
    clip_images = []
    for image_path in image_paths:
        image = Image.open(image_path).convert("RGB")
        image_tensor = process_images([image], image_processor, model.config)[0]
        clip_images.append(image_tensor.to(dtype=torch.float16))

    # Prepare text input and interaction
    # The raw prompt is also tokenized on its own and passed as `qformer_text_input`
    qformer_text_input = tokenizer(prompt, return_tensors='pt')["input_ids"].to(model.device)
    # One image placeholder token per input image, followed by the user prompt
    img_str = DEFAULT_IMAGE_TOKEN * len(clip_images) + "\n"
    inp = img_str + prompt
    # Build the chat prompt. Assumption: the template key is "llama3";
    # check llava/conversation.py in the MIRAGE repository for the exact name.
    conv = conv_templates["llama3"].copy()
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Generate model output
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    tokenizer.pad_token_id = 128002  # a reserved Llama 3 special token, used for padding
    batch_clip_images = [torch.stack(clip_images).to(model.device)]

    # MIRAGE's generate() returns the retrieval result together with the generated token ids
    output_ret, output_ids = model.generate(
        input_ids,
        pad_token_id=tokenizer.pad_token_id,
        clip_images=batch_clip_images,
        qformer_text_input=qformer_text_input,
        relevance=None,
        num_retrieval=num_retrievals,
        do_sample=False,  # greedy decoding
        max_new_tokens=512,
        use_cache=True)

    # Process output: decode the text and make sure the retrieval result is a plain list
    output_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    if not isinstance(output_ret[0], list):
        output_ret[0] = output_ret[0].tolist()
    return output_text, output_ret[0]
```
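
The `run` function expects `image_paths` as a list of files, e.g. the images in `assets/example`. A small helper such as the hypothetical `list_images` below can build that list from a folder. The accepted extensions are an assumption; adjust them to your data. The commented-out call mirrors the docstring example and is not executed here, because it needs a CUDA GPU, the `llava` package installed from the MIRAGE repository, and the model checkpoint.

```python
from pathlib import Path

# Assumption: extensions to treat as images; adjust to your data.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def list_images(folder: str) -> list[str]:
    """Hypothetical helper: return the image files directly inside `folder`, sorted by path."""
    return sorted(
        str(p) for p in Path(folder).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )


# Example call (requires a CUDA GPU, the MIRAGE `llava` package, and the checkpoint):
# text, retrieved = run(
#     "tsunghanwu/mirage-llama3.1-8.3B",
#     list_images("assets/example"),
#     "Here are a set of random images in my photo album. "
#     "If you can find a cat, tell me what's the cat doing and what's its color.",
#     num_retrievals=1,
# )
# print(text)
```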