Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
3
+ import torch
4
+ from transformers import BitsAndBytesConfig
5
+ from PIL import Image
6
+ import os
7
+
8
def load_model():
    """Load the fine-tuned PaliGemma model and its processor.

    When CUDA is available the weights are loaded with 4-bit NF4
    quantization (bfloat16 compute, double quantization) and placed on the
    GPU. Without CUDA the model is loaded unquantized on the CPU with the
    checkpoint's default dtype.

    Returns:
        tuple: ``(model, processor)``.
    """
    repo_name = "ighoshsubho/pali-gamma-finetuned-json"
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    model_kwargs = {"device_map": device}
    if use_cuda:
        # NOTE(review): the bitsandbytes 4-bit path typically requires a CUDA
        # GPU, so quantization is only requested when one is present; passing
        # it unconditionally makes the CPU fallback fail at load time.
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
        model_kwargs["torch_dtype"] = torch.bfloat16

    # Load processor and model
    processor = PaliGemmaProcessor.from_pretrained(repo_name)
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        repo_name,
        **model_kwargs,
    )

    return model, processor
31
+
32
# Load the model and processor once, at import time, so every request handled
# by process_image reuses the same module-level `model` and `processor`.
print("Loading model...")
model, processor = load_model()
print("Model loaded successfully!")
36
+
37
def process_image(image, prompt):
    """Run the model on an image/prompt pair and return the generated text.

    Args:
        image: A PIL image, or anything ``Image.open`` accepts. ``None``
            (no image supplied) is reported as an error.
        prompt: Instruction text appended after the ``<image>`` placeholder.

    Returns:
        str: The decoded text of the newly generated tokens, or a message
        starting with ``"Error processing image:"`` if anything fails.
    """
    # Nothing was uploaded: report it clearly instead of surfacing the opaque
    # exception Image.open(None) would raise.
    if image is None:
        return "Error processing image: no image provided"

    try:
        # Ensure image is in PIL format
        if not isinstance(image, Image.Image):
            image = Image.open(image)
        # Normalise palette / alpha / grayscale images to three channels.
        image = image.convert("RGB")

        # Prepare inputs
        inputs = processor(
            text=[f"<image>{prompt}"],
            images=[image],
            return_tensors="pt",
            padding="longest",
        ).to(model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        # Generate output. max_new_tokens bounds only the generated part;
        # max_length would also count the prompt and image tokens and could
        # leave little or no room for the answer. Beam search does not sample,
        # so no temperature is passed.
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            num_beams=5,
        )

        # Decode only the tokens produced after the prompt so the result does
        # not echo the prompt text back to the user.
        result = processor.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True,
        )
        return result

    except Exception as e:
        # UI boundary: show the failure as text rather than crash the app.
        return f"Error processing image: {str(e)}"
66
+
67
# Build the Gradio UI: an image upload and a prompt box feed process_image,
# whose returned text is shown in a single output textbox.
_image_input = gr.Image(type="pil", label="Upload Image")
_prompt_input = gr.Textbox(
    label="Prompt",
    placeholder="Enter your prompt here...",
    value="extract data in JSON format",
)
_output_box = gr.Textbox(label="Generated Output")

demo = gr.Interface(
    fn=process_image,
    inputs=[_image_input, _prompt_input],
    outputs=_output_box,
    title="PaLI-GAMMA Image Analysis",
    description="Upload an image and get structured data extracted in JSON format. The model is running in 4-bit quantization mode.",
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()