Spaces:
Sleeping
Sleeping
import torch | |
from transformers import ( | |
PaliGemmaProcessor, | |
PaliGemmaForConditionalGeneration, | |
) | |
import streamlit as st | |
from PIL import Image | |
import os | |
# Hugging Face access token, read from the HF_TOKEN environment variable
# (on Spaces, store it as a secret); None when the variable is unset.
token = os.getenv('HF_TOKEN')
# PaliGemma 2 checkpoint loaded by this app
model_id = "google/paligemma2-3b-pt-896"
@st.cache_resource
def model_setup(model_id):
    """Load the PaliGemma model and its processor from the Hugging Face Hub.

    Streamlit re-executes the whole script on every user interaction, and this
    function is called at the script's top level. Without caching, the
    multi-GB weights would be reloaded (and re-allocated on the device) on
    every rerun. ``st.cache_resource`` keeps one loaded copy per ``model_id``
    and reuses it across reruns and sessions.

    Args:
        model_id: Hub repository id of the checkpoint to load.

    Returns:
        ``(model, processor)``: the model in eval mode, loaded with bfloat16
        weights and placed with ``device_map="auto"``, and its processor.
        Both are authenticated with the module-level ``token``.
    """
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=token,
    ).eval()
    processor = PaliGemmaProcessor.from_pretrained(model_id, token=token)
    return model, processor
def runModel(prompt, image):
    """Generate text from the PaliGemma model for ``prompt`` about ``image``.

    Relies on the module-level ``model`` and ``processor``. Decoding is greedy
    (``do_sample=False``) and capped at 1000 new tokens. The prompt tokens are
    sliced off the generated sequence before decoding.

    Args:
        prompt: Text prompt handed to the processor.
        image: Image handed to the processor (a PIL RGB image in this app).

    Returns:
        The decoded continuation with special tokens removed.
    """
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    # Cast to the model's bfloat16 weights, then move onto the model's device.
    inputs = inputs.to(torch.bfloat16)
    inputs = inputs.to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        output_ids = model.generate(**inputs, max_new_tokens=1000, do_sample=False)
    # Keep only the newly generated tokens of the first (only) sequence.
    new_tokens = output_ids[0][prompt_length:]
    return processor.decode(new_tokens, skip_special_tokens=True)
def initialize():
    """Reset the chat history stored in Streamlit's session state.

    Used as the file uploader's ``on_change`` callback, so each newly chosen
    image starts from an empty conversation.
    """
    st.session_state["messages"] = []
### load model
# Streamlit reruns this whole script top to bottom on every interaction.
model, processor = model_setup(model_id)

# The history is otherwise only created by the uploader callback; make sure
# it exists before it is read below.
if "messages" not in st.session_state:
    initialize()


def _chat_turn(prompt, image):
    """Show ``prompt``, run the VLM on it, then show and record both turns."""
    # display user message in chat message container
    with st.chat_message("user"):
        st.markdown(prompt)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})
    # run the VLM
    response = runModel(prompt, image)
    # display assistant response in chat message container
    with st.chat_message("assistant"):
        st.markdown(response)
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})


### upload a file
uploaded_file = st.file_uploader("Choose an image", on_change=initialize)
if uploaded_file:
    st.image(uploaded_file)
    image = Image.open(uploaded_file).convert("RGB")
    # tasks
    task = st.radio(
        "Task",
        ('Caption', 'OCR', 'Segment', 'Enter your prompt'),
        horizontal=True,
    )
    # display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    if task == 'Enter your prompt':
        if prompt := st.chat_input("Type here!", key="question"):
            _chat_turn(prompt, image)
    elif task != st.session_state.get("last_task") or not st.session_state.messages:
        # A preset task runs only when it is newly selected or the history was
        # just reset (new image). Otherwise every rerun would re-run inference
        # and append duplicate turns to the history.
        _chat_turn(task, image)
    st.session_state.last_task = task