img2text / app.py
Nerva1228's picture
Update app.py
9409942 verified
import gradio as gr
import spaces
from PIL import Image
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
import requests
import json
from io import BytesIO
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True)
SERVER_URL = 'http://43.156.72.113:8188'
FETCH_TASKS_URL = SERVER_URL + '/fetch/'
UPDATE_TASK_STATUS_URL = SERVER_URL + '/update/'
def fetch_task(category, fetch_all=False):
params = {'fetch_all': 'true' if fetch_all else 'false'}
response = requests.post(FETCH_TASKS_URL + category, params=params)
if response.status_code == 200:
return response.json()
else:
print(f"Failed to fetch tasks: {response.status_code} - {response.text}")
return None
def update_task_status(category, task_id, status, result=None):
data = {'status': status}
if result:
data['result'] = result
response = requests.post(UPDATE_TASK_STATUS_URL + category + f'/{task_id}', json=data)
if response.status_code == 200:
print(f"Task {task_id} updated successfully: {json.dumps(response.json(), indent=4)}")
else:
print(f"Failed to update task {task_id}: {response.status_code} - {response.text}")
@spaces.GPU(duration=150)
def infer():
img2text_tasks = fetch_task('img2text', fetch_all=True)
if not img2text_tasks:
return "No tasks found."
for task in img2text_tasks:
try:
image_url = task['content']['url']
prompt = task['content']['prompt']
image_response = requests.get(image_url)
image = Image.open(BytesIO(image_response.content)).convert("RGB")
max_size = 256
width, height = image.size
if width > height:
new_width = max_size
new_height = int((new_width / width) * height)
else:
new_height = max_size
new_width = int((new_height / height) * width)
image = image.resize((new_width, new_height), Image.LANCZOS)
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
generated_ids = model.generate(
input_ids=inputs["input_ids"],
pixel_values=inputs["pixel_values"],
max_new_tokens=1024,
do_sample=False,
num_beams=3
)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
update_task_status('img2text', task['id'], 'Successed', {"text": parsed_answer})
except Exception as e:
print(f"Error processing task {task['id']}: {e}")
update_task_status('img2text', task['id'], 'Failed', {"error": str(e)})
return f"Error processing task {task['id']}: {e}"
return "Successed! No pending tasks found."
css = """
#col-container {
margin: 0 auto;
max-width: 800px;
}
"""
with gr.Blocks(css=css) as app:
with gr.Column(elem_id="col-container"):
gr.Markdown(f"""# Tag The Image
Get tag based on images using the Florence-2-base-PromptGen-v1.5 model.
""")
run_button = gr.Button("Submit", scale=0, elem_id="run-button")
result = gr.Textbox(label="Generated Text", show_label=False)
gr.on(
triggers=[run_button.click],
fn=infer,
inputs=[],
outputs=[result]
)
app.queue()
app.launch(show_error=True)