TotoB12 committed · verified
Commit d8addc5 · 1 Parent(s): 39f8e6b

Update app.py

Files changed (1)
  1. app.py +36 -55
app.py CHANGED
@@ -1,54 +1,30 @@
 from typing import Optional
-import spaces
-
 import gradio as gr
 import numpy as np
 import torch
 from PIL import Image
 import io
-
-
 import base64, os
 from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
-import torch
 from PIL import Image
 
-# yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
-# caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption_florence")
-
 from ultralytics import YOLO
-yolo_model = YOLO('weights/icon_detect/best.pt').to('cuda')
+yolo_model = YOLO('weights/icon_detect/best.pt') # Removed .to('cuda')
+
 from transformers import AutoProcessor, AutoModelForCausalLM
 processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained("weights/icon_caption_florence", torch_dtype=torch.float16, trust_remote_code=True).to('cuda')
-caption_model_processor = {'processor': processor, 'model': model}
-print('finish loading model!!!')
+model = AutoModelForCausalLM.from_pretrained("weights/icon_caption_florence", torch_dtype=torch.float32, trust_remote_code=True) # Changed dtype to float32 and removed .to('cuda')
+
+caption_model_processor = {'processor': processor, 'model': model}
+print('Finished loading model.')
 
-
 platform = 'pc'
-if platform == 'pc':
-    draw_bbox_config = {
-        'text_scale': 0.8,
-        'text_thickness': 2,
-        'text_padding': 2,
-        'thickness': 2,
-    }
-elif platform == 'web':
-    draw_bbox_config = {
-        'text_scale': 0.8,
-        'text_thickness': 2,
-        'text_padding': 3,
-        'thickness': 3,
-    }
-elif platform == 'mobile':
-    draw_bbox_config = {
-        'text_scale': 0.8,
-        'text_thickness': 2,
-        'text_padding': 3,
-        'thickness': 3,
-    }
-
-
+draw_bbox_config = {
+    'text_scale': 0.8,
+    'text_thickness': 2,
+    'text_padding': 2,
+    'thickness': 2,
+}
 
 MARKDOWN = """
 # OmniParser for Pure Vision Based General GUI Agent 🔥
@@ -58,15 +34,10 @@ MARKDOWN = """
     </a>
 </div>
 
-OmniParser is a screen parsing tool to convert general GUI screen to structured elements.
+OmniParser is a screen parsing tool to convert general GUI screens to structured elements.
 """
 
-# DEVICE = torch.device('cuda')
-
-# @spaces.GPU
 @torch.inference_mode()
-# @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
-@spaces.GPU(duration=65)
 def process(
     image_input,
     box_threshold,
@@ -75,36 +46,48 @@ def process(
 
     image_save_path = 'imgs/saved_image_demo.png'
     image_input.save(image_save_path)
-    # import pdb; pdb.set_trace()
 
-    ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_save_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}, use_paddleocr=True)
+    ocr_bbox_rslt, is_goal_filtered = check_ocr_box(
+        image_save_path,
+        display_img=False,
+        output_bb_format='xyxy',
+        goal_filtering=None,
+        easyocr_args={'paragraph': False, 'text_threshold': 0.9},
+        use_paddleocr=True
+    )
     text, ocr_bbox = ocr_bbox_rslt
-    # print('prompt:', prompt)
-    dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_save_path, yolo_model, BOX_TRESHOLD = box_threshold, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,iou_threshold=iou_threshold)
+
+    dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
+        image_save_path,
+        yolo_model,
+        BOX_TRESHOLD=box_threshold,
+        output_coord_in_ratio=True,
+        ocr_bbox=ocr_bbox,
+        draw_bbox_config=draw_bbox_config,
+        caption_model_processor=caption_model_processor,
+        ocr_text=text,
+        iou_threshold=iou_threshold
+    )
     image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
-    print('finish processing')
+    print('Finished processing.')
     parsed_content_list = '\n'.join(parsed_content_list)
     return image, str(parsed_content_list)
 
-
-
 with gr.Blocks() as demo:
     gr.Markdown(MARKDOWN)
     with gr.Row():
         with gr.Column():
-            image_input_component = gr.Image(
-                type='pil', label='Upload image')
-            # set the threshold for removing the bounding boxes with low confidence, default is 0.05
+            image_input_component = gr.Image(type='pil', label='Upload Image')
             box_threshold_component = gr.Slider(
                 label='Box Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.05)
-            # set the threshold for removing the bounding boxes with large overlap, default is 0.1
             iou_threshold_component = gr.Slider(
                 label='IOU Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.1)
             submit_button_component = gr.Button(
                 value='Submit', variant='primary')
         with gr.Column():
             image_output_component = gr.Image(type='pil', label='Image Output')
-            text_output_component = gr.Textbox(label='Parsed screen elements', placeholder='Text Output')
+            text_output_component = gr.Textbox(
+                label='Parsed Screen Elements', placeholder='Text Output')
 
     submit_button_component.click(
         fn=process,
@@ -116,6 +99,4 @@ with gr.Blocks() as demo:
         outputs=[image_output_component, text_output_component]
     )
 
-# demo.launch(debug=False, show_error=True, share=True)
-# demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
 demo.queue().launch(share=False)
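
The change set above pins inference to CPU: both `.to('cuda')` calls and the `@spaces.GPU` decorator are removed, and the Florence-2 caption model is loaded in `torch.float32` because half-precision kernels are generally unsupported or slow on CPU. Below is a minimal sketch, not part of this commit, of how the same loading code could instead pick the device and dtype at runtime rather than hard-coding either choice; it assumes the same weight paths used in the diff.

```python
import torch
from ultralytics import YOLO
from transformers import AutoProcessor, AutoModelForCausalLM

# Choose the device at runtime instead of hard-coding CPU or CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# float16 only pays off on GPU; CPU inference generally wants float32.
dtype = torch.float16 if device == 'cuda' else torch.float32

# Same weight paths as in the commit above.
yolo_model = YOLO('weights/icon_detect/best.pt').to(device)
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "weights/icon_caption_florence",
    torch_dtype=dtype,
    trust_remote_code=True,
).to(device)
caption_model_processor = {'processor': processor, 'model': model}
```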