File size: 3,740 Bytes
0fc5095
 
 
 
 
 
 
 
 
 
0375f07
9467c94
d8addc5
0375f07
 
9467c94
 
 
 
 
0375f07
d8addc5
 
0375f07
0fc5095
d8addc5
 
 
 
 
 
0fc5095
 
 
 
 
 
 
 
 
d8addc5
0fc5095
 
0375f07
0fc5095
 
 
 
 
 
 
 
 
d8addc5
 
 
 
 
 
 
 
0fc5095
d8addc5
9467c94
d8addc5
 
 
 
 
 
 
 
 
 
9467c94
d8addc5
9467c94
c9933c7
9467c94
 
0fc5095
 
 
 
 
d8addc5
0fc5095
 
 
 
 
 
 
 
d8addc5
 
9467c94
 
0fc5095
 
 
 
 
 
 
 
9467c94
 
 
 
 
0fc5095
 
9467c94
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import base64, os
import io
from typing import Optional, Tuple

import gradio as gr
import numpy as np
import torch
from PIL import Image
from PIL import Image

from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img

from ultralytics import YOLO

# Icon-detection model, loaded from the local checkpoint under weights/icon_detect.
yolo_model = YOLO('weights/icon_detect/best.pt')

from transformers import AutoProcessor, AutoModelForCausalLM

# Caption model. The processor comes from the public "microsoft/Florence-2-base"
# repo, while the model weights are read from the local
# weights/icon_caption_florence directory in full float32 precision.
# Both loads pass trust_remote_code=True, which lets transformers run model
# code shipped alongside the checkpoint.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "weights/icon_caption_florence",
    torch_dtype=torch.float32,
    trust_remote_code=True,
)

# Bundle handed to get_som_labeled_img via its caption_model_processor argument.
caption_model_processor = dict(processor=processor, model=model)
print('Finished loading model.')

# Platform tag; not read by any code shown in this file (kept for compatibility).
platform = 'pc'

# Drawing options for the labeled output image; forwarded to
# get_som_labeled_img as draw_bbox_config.
draw_bbox_config = dict(
    text_scale=0.8,
    text_thickness=2,
    text_padding=2,
    thickness=2,
)

# Page header rendered by gr.Markdown at the top of the demo.
MARKDOWN = """
# OmniParser for Pure Vision Based General GUI Agent 🔥
<div>
    <a href="https://arxiv.org/pdf/2408.00203">
        <img src="https://img.shields.io/badge/arXiv-2408.00203-b31b1b.svg" alt="Arxiv" style="display:inline-block;">
    </a>
</div>

OmniParser is a screen parsing tool to convert general GUI screens to structured elements.
"""

@torch.inference_mode()
def process(
    image_input,
    box_threshold,
    iou_threshold
) -> Tuple[Image.Image, str, str]:
    """Parse an uploaded screenshot into labeled GUI elements for the demo.

    The image is first written to ``imgs/saved_image_demo.png`` because the
    OCR and labeling helpers take a file path. OCR results, the YOLO icon
    detector and the caption model are then handed to ``get_som_labeled_img``.

    Args:
        image_input: PIL image from the ``gr.Image(type='pil')`` input, or
            ``None`` when the user submits without uploading anything.
        box_threshold: Detection threshold, forwarded to
            ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: IoU threshold, forwarded to ``get_som_labeled_img``.

    Returns:
        ``(annotated_image, parsed_elements, coordinates)``: the labeled image
        decoded from the helper's base64 output, the parsed element
        descriptions joined one per line, and the string form of the label
        coordinates.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if image_input is None:
        # Without this guard the .save() call below fails with AttributeError;
        # gr.Error is shown to the user as a readable message instead.
        raise gr.Error('Please upload an image before submitting.')

    image_save_path = 'imgs/saved_image_demo.png'
    # The target directory may not exist in a fresh checkout; saving would
    # then raise FileNotFoundError.
    os.makedirs(os.path.dirname(image_save_path), exist_ok=True)
    image_input.save(image_save_path)

    ocr_bbox_rslt, _is_goal_filtered = check_ocr_box(
        image_save_path,
        display_img=False,
        output_bb_format='xyxy',
        goal_filtering=None,
        easyocr_args={'paragraph': False, 'text_threshold': 0.9},
        use_paddleocr=True
    )
    text, ocr_bbox = ocr_bbox_rslt

    dino_labeled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
        image_save_path,
        yolo_model,
        BOX_TRESHOLD=box_threshold,
        output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox,
        draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor,
        ocr_text=text,
        iou_threshold=iou_threshold
    )
    # The labeled image comes back base64-encoded; decode it into a PIL image.
    image = Image.open(io.BytesIO(base64.b64decode(dino_labeled_img)))
    print('Finished processing.')
    parsed_content_list_str = '\n'.join(parsed_content_list)
    # The output Textbox displays str(value) anyway; converting here keeps the
    # declared return type honest.
    label_coordinates_str = str(label_coordinates)

    return image, parsed_content_list_str, label_coordinates_str

# Two-column page: inputs on the left, parsing results on the right.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        # Left column: screenshot upload, the two thresholds, and the submit button.
        with gr.Column():
            image_input_component = gr.Image(label='Upload Image', type='pil')
            box_threshold_component = gr.Slider(
                minimum=0.01,
                maximum=1.0,
                step=0.01,
                value=0.05,
                label='Box Threshold',
            )
            iou_threshold_component = gr.Slider(
                minimum=0.01,
                maximum=1.0,
                step=0.01,
                value=0.1,
                label='IOU Threshold',
            )
            submit_button_component = gr.Button(variant='primary', value='Submit')
        # Right column: labeled image, parsed element text, and coordinates.
        with gr.Column():
            image_output_component = gr.Image(label='Image Output', type='pil')
            text_output_component = gr.Textbox(
                placeholder='Text Output', label='Parsed Screen Elements')
            coordinates_output_component = gr.Textbox(
                placeholder='Coordinates Output', label='Coordinates')

    # process() takes (image, box threshold, iou threshold) and returns
    # (image, text, coordinates); the component lists mirror that order.
    submit_button_component.click(
        fn=process,
        inputs=[image_input_component, box_threshold_component, iou_threshold_component],
        outputs=[image_output_component, text_output_component, coordinates_output_component],
    )

# Blocks.queue() returns the same Blocks object, so enabling the queue and
# launching are written as two separate statements on `demo`.
demo.queue()
demo.launch(share=False)