# Necessary imports
import os
import sys
from dotenv import load_dotenv
from typing import Any
import torch
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor

# Local imports
from src.logger import logging
from src.exception import CustomExceptionHandling

# Load the environment variables from the .env file
load_dotenv()

# Access token for using the model
access_token = os.environ.get("ACCESS_TOKEN")


def load_model_and_processor(model_name: str, device: str) -> Any:
    """
    Load the model and processor.

    Args:
        - model_name (str): The name of the model to load.
        - device (str): The device to load the model onto.

    Returns:
        - model: The loaded model, in eval mode on the requested device.
        - processor: The loaded processor.
    """
    try:
        # Load the model in bfloat16, switch to eval mode, and move it to the target device
        model = (
            PaliGemmaForConditionalGeneration.from_pretrained(
                model_name, torch_dtype=torch.bfloat16, token=access_token
            )
            .eval()
            .to(device)
        )
        processor = PaliGemmaProcessor.from_pretrained(model_name, token=access_token)

        # Log the successful loading of the model and processor
        logging.info("Model and processor loaded successfully.")

        # Return the model and processor
        return model, processor

    # Handle exceptions that may occur during model and processor loading
    except Exception as e:
        # Custom exception handling
        raise CustomExceptionHandling(e, sys) from e
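

# Example usage (a minimal sketch, not part of the module's public interface): the
# checkpoint name below is a placeholder assumption; substitute the PaliGemma
# checkpoint this Space actually uses, and set the device based on available hardware.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, processor = load_model_and_processor(
        "google/paligemma-3b-mix-224",  # placeholder checkpoint name
        device,
    )
    logging.info(f"Loaded model and processor on {device}.")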