Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration | |
import torch | |
from transformers import BitsAndBytesConfig | |
from PIL import Image | |
import os | |
def load_model(repo_name="ighoshsubho/pali-gamma-finetuned-json"):
    """Load the fine-tuned PaliGemma model and its processor.

    On a machine with a CUDA GPU the weights are loaded with bitsandbytes
    4-bit NF4 quantization (double quantization, bfloat16 compute). The
    transformers/bitsandbytes 4-bit path requires a GPU, so on a CPU-only
    machine the model is loaded unquantized in float32 instead of failing
    at start-up.

    Args:
        repo_name: Hugging Face Hub repo id (or local directory) of the
            checkpoint. Defaults to the fine-tuned JSON-extraction model.

    Returns:
        A ``(model, processor)`` tuple.
    """
    processor = PaliGemmaProcessor.from_pretrained(repo_name)

    if torch.cuda.is_available():
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            repo_name,
            quantization_config=quantization_config,
            device_map="cuda",
            torch_dtype=torch.bfloat16,
        )
    else:
        # No GPU: 4-bit bitsandbytes loading would raise here, so load a plain
        # float32 model on CPU (slow, but it actually runs).
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            repo_name,
            device_map="cpu",
            torch_dtype=torch.float32,
        )

    return model, processor
# Load the model once at import time so every Gradio request reuses it.
# flush=True: when stdout is not a TTY (e.g. container logs) it is
# block-buffered, and these progress messages would otherwise only appear
# after the (potentially long) download/load has finished.
print("Loading model...", flush=True)
model, processor = load_model()
print("Model loaded successfully!", flush=True)
def process_image(image, prompt, max_new_tokens=512):
    """Run the model on one image/prompt pair and return the generated text.

    Args:
        image: A ``PIL.Image.Image`` (what ``gr.Image(type="pil")`` supplies),
            or anything ``PIL.Image.open`` accepts (path or file-like object).
            ``None`` means no image was uploaded.
        prompt: Task prompt; the ``<image>`` token is prepended to it.
        max_new_tokens: Maximum number of tokens to generate, not counting
            the prompt and image tokens already in the input.

    Returns:
        The decoded model continuation (the echoed prompt is stripped), or a
        human-readable error message. Errors are returned rather than raised
        so they are shown in the Gradio output box.
    """
    if image is None:
        return "Error processing image: no image was provided"

    try:
        if not isinstance(image, Image.Image):
            image = Image.open(image)
        # Palette / alpha / greyscale inputs would otherwise reach the
        # processor with an unexpected number of channels.
        image = image.convert("RGB")

        inputs = processor(
            text=[f"<image>{prompt}"],
            images=[image],
            return_tensors="pt",
            padding="longest",
        )
        # Move to the model's device, then cast floating tensors
        # (pixel_values) to the model dtype; BatchFeature.to(dtype) leaves
        # integer tensors such as input_ids untouched.
        inputs = inputs.to(model.device).to(model.dtype)
        prompt_len = inputs["input_ids"].shape[-1]

        with torch.inference_mode():
            # max_new_tokens (not max_length): the image tokens alone can
            # consume a fixed max_length budget. Beam search is deterministic,
            # so no sampling temperature is passed.
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=5,
                do_sample=False,
            )

        # generate() returns prompt + continuation; decode only the new tokens.
        return processor.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    except Exception as e:
        return f"Error processing image: {str(e)}"
# bitsandbytes 4-bit quantization is only possible on a CUDA device, so only
# advertise it when one is present.
_runtime_note = (
    "The model is running in 4-bit quantization mode."
    if torch.cuda.is_available()
    else "The model is running unquantized on CPU, so responses may be slow."
)

# Create Gradio interface: image + prompt in, generated text out.
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt here...",
            value="extract data in JSON format",
        ),
    ],
    outputs=gr.Textbox(label="Generated Output"),
    title="PaliGemma Image Analysis",
    description=(
        "Upload an image and get structured data extracted in JSON format. "
        + _runtime_note
    ),
)

if __name__ == "__main__":
    demo.launch()