sitammeur committed on
Commit
1790ec9
·
verified ·
1 Parent(s): e4b8814

Upload 3 files

Browse files
src/paligemma/__init__.py ADDED
File without changes
src/paligemma/model.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Necessary imports
2
+ import os
3
+ import sys
4
+ from dotenv import load_dotenv
5
+ from typing import Any
6
+ import torch
7
+ from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
8
+
9
+ # Local imports
10
+ from src.logger import logging
11
+ from src.exception import CustomExceptionHandling
12
+
13
+
14
# Read the variables defined in the .env file into the process environment
load_dotenv()

# Token forwarded to from_pretrained() when fetching the model and processor;
# this is None when ACCESS_TOKEN is not set
access_token = os.getenv("ACCESS_TOKEN")
19
+
20
+
21
def load_model_and_processor(model_name: str, device: str) -> Any:
    """
    Load the PaliGemma model and its matching processor.

    The model weights are loaded in bfloat16, the model is switched to
    evaluation mode and then moved to the requested device. Both downloads
    use the module-level access token.

    Args:
        - model_name (str): The name of the model to load.
        - device (str): The device to load the model onto.

    Returns:
        - model: The loaded model, in eval mode on `device`.
        - processor: The loaded processor.

    Raises:
        CustomExceptionHandling: Wraps any exception raised while loading.
    """
    try:
        # Fetch the pretrained weights in bfloat16
        pretrained = PaliGemmaForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, token=access_token
        )

        # Inference only: evaluation mode first, then move to the target device
        model = pretrained.eval().to(device)

        # The processor is loaded from the same checkpoint name
        processor = PaliGemmaProcessor.from_pretrained(model_name, token=access_token)

        # Record the successful load
        logging.info("Model and processor loaded successfully.")

        return model, processor

    except Exception as e:
        # Re-raise as the project's custom exception, keeping the original cause
        raise CustomExceptionHandling(e, sys) from e
src/paligemma/response.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Necessary imports
2
+ import sys
3
+ import PIL.Image
4
+ import torch
5
+ import gradio as gr
6
+ import spaces
7
+
8
+ # Local imports
9
+ from src.config import device, model_name
10
+ from src.paligemma.model import load_model_and_processor
11
+ from src.logger import logging
12
+ from src.exception import CustomExceptionHandling
13
+
14
+
15
# Maps language names to the language codes inserted into the caption prompt
language_dict = dict(English="en", Spanish="es", French="fr")

# Load the model and processor once, when this module is imported
model, processor = load_model_and_processor(model_name, device)
24
+
25
+
26
@spaces.GPU
def caption_image(image: PIL.Image.Image, max_new_tokens: int, language: str) -> str:
    """
    Generates a caption based on the given image using the model.

    Args:
        - image (PIL.Image.Image): The input image to be processed.
        - max_new_tokens (int): The maximum number of new tokens to generate.
        - language (str): The language of the generated caption. Must be a
          key of `language_dict`.

    Returns:
        str: The generated caption text, or an empty string when no image
        was provided (a warning is shown in that case).

    Raises:
        CustomExceptionHandling: Wraps any exception raised while preparing
        the inputs or generating the caption (including an unknown language).
    """
    try:
        # Without an image there is nothing to caption: show the warning and
        # stop here instead of passing None on to the processor.
        if image is None:
            gr.Warning("Please provide an image.")
            return ""

        # Prepare the inputs; the prompt carries the target language code
        language_code = language_dict[language]
        prompt = f"<image>caption {language_code}"
        model_inputs = (
            processor(text=prompt, images=image, return_tensors="pt")
            .to(torch.bfloat16)
            .to(device)
        )

        # Length of the prompt tokens, used below to strip them from the output
        input_len = model_inputs["input_ids"].shape[-1]

        # Generate the response with greedy decoding (do_sample=False)
        with torch.inference_mode():
            generation = model.generate(
                **model_inputs, max_new_tokens=max_new_tokens, do_sample=False
            )

        # Keep only the newly generated tokens of the first sequence
        new_tokens = generation[0][input_len:]
        decoded = processor.decode(new_tokens, skip_special_tokens=True)

        # Log the successful generation of the caption
        logging.info("Caption generated successfully.")

        # Return the generated caption
        return decoded

    # Handle exceptions that may occur during caption generation
    except Exception as e:
        # Custom exception handling
        raise CustomExceptionHandling(e, sys) from e