import torch
from transformers import AutoTokenizer, AutoModel
from PIL import Image
import gradio as gr
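# Note: the model is moved to the GPU with .cuda() below, so this app
# assumes a CUDA-capable device (e.g. a GPU-backed Space) is available.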
# Load the model and tokenizer
model_path = "ByteDance/Sa2VA-4B"
model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    use_flash_attn=True,
    trust_remote_code=True,
).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
    use_fast=False,
)
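# Single-image inference. Sa2VA's remote-code interface expects the text
# prompt to begin with an "<image>" placeholder token marking where the
# image is injected, as built in image_vision below.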
def image_vision(image_input_path, prompt):
    text_prompts = f"<image>{prompt}"
    image = Image.open(image_input_path).convert('RGB')
    input_dict = {
        'image': image,
        'text': text_prompts,
        'past_text': '',
        'mask_prompts': None,
        'tokenizer': tokenizer,
    }
    return_dict = model.predict_forward(**input_dict)
    answer = return_dict["prediction"]  # the text-format answer
    print(answer)
    return answer
def main_infer(image_input_path, prompt):
    response = image_vision(image_input_path, prompt)
    return response
# Gradio UI
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Image IN")
                with gr.Row():
                    instruction = gr.Textbox(label="Instruction")
                    submit_btn = gr.Button("Submit", scale=1)
            with gr.Column():
                output_res = gr.Textbox(label="Response")

    submit_btn.click(
        fn=main_infer,
        inputs=[image_input, instruction],
        outputs=[output_res],
    )
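# queue() serializes incoming requests so the single GPU model is not called
# concurrently; show_api=False hides the auto-generated API page and
# show_error=True surfaces Python exceptions in the UI.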
demo.queue().launch(show_api=False, show_error=True)