import gradio as gr
import torch
from PIL import Image
from transformers import (
    BitsAndBytesConfig,
    PaliGemmaForConditionalGeneration,
    PaliGemmaProcessor,
)
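# Assumed dependencies: gradio, torch, transformers, pillow, plus
# accelerate and bitsandbytes for the 4-bit quantized load below.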

def load_model():
    """Load the processor and model (4-bit quantized on GPU, full precision on CPU)"""
    repo_name = "ighoshsubho/pali-gamma-finetuned-json"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load processor
    processor = PaliGemmaProcessor.from_pretrained(repo_name)

    if device == "cuda":
        # 4-bit NF4 quantization with double quantization (requires bitsandbytes)
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True
        )
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            repo_name,
            quantization_config=quantization_config,
            device_map=device,
            torch_dtype=torch.bfloat16
        )
    else:
        # bitsandbytes 4-bit loading requires a GPU, so fall back to an
        # unquantized load on CPU
        model = PaliGemmaForConditionalGeneration.from_pretrained(repo_name)

    return model, processor

# Load the model once at import time so every request reuses the same weights
print("Loading model...")
model, processor = load_model()
print("Model loaded successfully!")

def process_image(image, prompt):
    """Run the model on an image and return the decoded text output"""
    try:
        # Ensure the image is a PIL image in RGB mode
        if not isinstance(image, Image.Image):
            image = Image.open(image)
        image = image.convert("RGB")

        # Prepare inputs; the <image> placeholder precedes the text prompt
        inputs = processor(
            text=[f"<image>{prompt}"],
            images=[image],
            return_tensors="pt",
            padding="longest"
        ).to(model.device)

        # Generate with beam search. max_new_tokens bounds only the generated
        # tokens (max_length would also count the prompt and image tokens);
        # temperature is omitted because beam search ignores it without
        # do_sample=True
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            num_beams=5
        )

        # Decode only the newly generated tokens, dropping the echoed prompt
        input_len = inputs["input_ids"].shape[-1]
        return processor.decode(outputs[0][input_len:], skip_special_tokens=True)

    except Exception as e:
        return f"Error processing image: {str(e)}"

# Create Gradio interface
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt here...",
            value="extract data in JSON format"
        )
    ],
    outputs=gr.Textbox(label="Generated Output"),
    title="PaLI-GAMMA Image Analysis",
    description="Upload an image and get structured data extracted in JSON format. The model is running in 4-bit quantization mode.",
)

if __name__ == "__main__":
    demo.launch()
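
# Programmatic access (a sketch, assuming the gradio_client package and a
# local run on Gradio's default port 7860; "photo.jpg" is a placeholder):
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860")
#   print(client.predict(handle_file("photo.jpg"),
#                        "extract data in JSON format",
#                        api_name="/predict"))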