ighoshsubho's picture
app.py created (#1)
73684bb verified
import gradio as gr
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
import torch
from transformers import BitsAndBytesConfig
from PIL import Image
import os
def load_model():
    """Load the fine-tuned PaliGemma model and its processor.

    When CUDA is available the model is loaded on the GPU with 4-bit NF4
    quantization (double quantization, bfloat16 compute). When it is not,
    the model is loaded on the CPU without a quantization config, because
    bitsandbytes 4-bit loading generally requires a GPU and would otherwise
    make the CPU fallback fail at load time.

    Returns:
        tuple: ``(model, processor)`` loaded from the same repository.
    """
    repo_name = "ighoshsubho/pali-gamma-finetuned-json"
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    # Keyword arguments shared by both paths; GPU-only options are added below.
    model_kwargs = {"device_map": device}
    if use_cuda:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
        model_kwargs["torch_dtype"] = torch.bfloat16
    else:
        # Leaving torch_dtype unset matches the original CPU behaviour
        # (torch_dtype=None); only the quantization config is dropped.
        print("CUDA not available; loading model on CPU without 4-bit quantization.")

    # Load processor and model
    processor = PaliGemmaProcessor.from_pretrained(repo_name)
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        repo_name,
        **model_kwargs,
    )
    return model, processor
# Load the model and processor once at import time; process_image reads the
# resulting module-level `model` and `processor` on every call.
print("Loading model...")
model, processor = load_model()
print("Model loaded successfully!")
def process_image(image, prompt):
    """Run the model on one image and prompt and return the generated text.

    Args:
        image: A PIL image, or anything ``PIL.Image.open`` accepts (for
            example a file path). ``None`` yields an error message.
        prompt: Instruction text appended after the ``<image>`` placeholder.
            ``None`` is treated as an empty prompt.

    Returns:
        str: The decoded text generated by the model (without the prompt),
        or a message starting with ``"Error processing image:"`` if anything
        fails. Exceptions are not propagated so the UI always gets a string.
    """
    if image is None:
        # Submitting the form without an upload passes None; report it plainly
        # instead of surfacing an opaque error from PIL.
        return "Error processing image: no image was provided"

    try:
        # Ensure image is in PIL format
        if not isinstance(image, Image.Image):
            image = Image.open(image)
        # Images opened from a path may be RGBA, palette or grayscale;
        # normalise to three channels before handing them to the processor.
        if image.mode != "RGB":
            image = image.convert("RGB")

        prompt_text = prompt or ""

        # Prepare inputs
        inputs = processor(
            text=[f"<image>{prompt_text}"],
            images=[image],
            return_tensors="pt",
            padding="longest",
        ).to(model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        # max_new_tokens bounds only the generated part. max_length would also
        # count the prompt, whose expanded image tokens can use up most of a
        # 512-token budget. temperature is omitted because sampling is not
        # enabled in this call, so it would have no effect on beam search.
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            num_beams=5,
        )

        # The returned sequence starts with the input tokens; decode only what
        # the model generated after them.
        generated_tokens = outputs[0][prompt_length:]
        return processor.decode(generated_tokens, skip_special_tokens=True)
    except Exception as e:
        # UI boundary: turn any failure into a readable message.
        return f"Error processing image: {str(e)}"
def _build_demo():
    """Assemble the Gradio interface that wraps ``process_image``.

    The interface takes an uploaded image (delivered as a PIL image) and a
    prompt textbox pre-filled with a default instruction, and shows the
    result in a single output textbox.
    """
    image_input = gr.Image(type="pil", label="Upload Image")
    prompt_input = gr.Textbox(
        label="Prompt",
        placeholder="Enter your prompt here...",
        value="extract data in JSON format",
    )
    output_box = gr.Textbox(label="Generated Output")

    return gr.Interface(
        fn=process_image,
        inputs=[image_input, prompt_input],
        outputs=output_box,
        title="PaLI-GAMMA Image Analysis",
        description="Upload an image and get structured data extracted in JSON format. The model is running in 4-bit quantization mode.",
    )


# Module-level handle to the app, created at import time.
demo = _build_demo()

if __name__ == "__main__":
    demo.launch()