import gradio as gr
import spaces
from PIL import Image
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
import requests
import json
from io import BytesIO
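
# Use CUDA when available, otherwise fall back to CPU.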
device = "cuda" if torch.cuda.is_available() else "cpu"
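
# Florence-2 ships custom modeling code on the Hub, so trust_remote_code=True
# is required for both the model and the processor.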
model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True)
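
# External task queue: pending tasks are fetched per category, and each task's
# status is reported back through the /update/ endpoint.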
SERVER_URL = 'http://43.156.72.113:8188'
FETCH_TASKS_URL = SERVER_URL + '/fetch/'
UPDATE_TASK_STATUS_URL = SERVER_URL + '/update/'


def fetch_task(category, fetch_all=False):
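    """Fetch pending tasks for a category; returns the parsed JSON response or None."""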
    params = {'fetch_all': 'true' if fetch_all else 'false'}
    response = requests.post(FETCH_TASKS_URL + category, params=params, timeout=30)
    if response.status_code == 200:
        return response.json()
    else:
        print(f"Failed to fetch tasks: {response.status_code} - {response.text}")
        return None


def update_task_status(category, task_id, status, result=None):
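    """Report a task's status (and an optional result payload) back to the server."""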
    data = {'status': status}
    if result:
        data['result'] = result

    response = requests.post(UPDATE_TASK_STATUS_URL + category + f'/{task_id}', json=data, timeout=30)
    if response.status_code == 200:
        print(f"Task {task_id} updated successfully: {json.dumps(response.json(), indent=4)}")
    else:
        print(f"Failed to update task {task_id}: {response.status_code} - {response.text}")
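

# ZeroGPU: allocate a GPU for up to 150 seconds per call.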
@spaces.GPU(duration=150)
def infer():
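    """Drain all pending img2text tasks: caption each image and report the result."""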
    img2text_tasks = fetch_task('img2text', fetch_all=True)

    if not img2text_tasks:
        return "No tasks found."

    for task in img2text_tasks:
        try:
            image_url = task['content']['url']
            prompt = task['content']['prompt']

            image_response = requests.get(image_url, timeout=30)
            image_response.raise_for_status()
            image = Image.open(BytesIO(image_response.content)).convert("RGB")
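
            # Downscale so the longer side is at most max_size, preserving aspect ratio.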
            max_size = 256
            width, height = image.size
            if width > height:
                new_width = max_size
                new_height = int((new_width / width) * height)
            else:
                new_height = max_size
                new_width = int((new_height / height) * width)
            image = image.resize((new_width, new_height), Image.LANCZOS)
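
            # Deterministic beam-search decoding (no sampling), up to 1024 new tokens.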
            inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                do_sample=False,
                num_beams=3
            )
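
            # Keep special tokens so post_process_generation can parse the task output.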
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
            parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
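
            # 'Successed' (sic) appears to be the status token the task server
            # expects, so it is kept verbatim.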
            update_task_status('img2text', task['id'], 'Successed', {"text": parsed_answer})
        except Exception as e:
            print(f"Error processing task {task['id']}: {e}")
            update_task_status('img2text', task['id'], 'Failed', {"error": str(e)})
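            # Stop at the first failure so the error surfaces in the UI.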
            return f"Error processing task {task['id']}: {e}"

    return "Success! All pending tasks were processed."


css = """
#col-container {
    margin: 0 auto;
    max-width: 800px;
}
"""

with gr.Blocks(css=css) as app:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# Tag The Image
        Generate tags for images using the Florence-2-base-PromptGen-v1.5 model.
        """)

        run_button = gr.Button("Submit", scale=0, elem_id="run-button")
        result = gr.Textbox(label="Generated Text", show_label=False)

        gr.on(
            triggers=[run_button.click],
            fn=infer,
            inputs=[],
            outputs=[result]
        )

app.queue()
app.launch(show_error=True)