Kaushik Bar commited on
Commit
2d4f9f1
·
1 Parent(s): 007aa6e

removing yoloxl

Browse files
Files changed (3) hide show
  1. app.py +19 -26
  2. app_bk.py +0 -144
  3. requirements.txt +1 -1
app.py CHANGED
@@ -6,7 +6,6 @@ import torch
6
  import pathlib
7
  from PIL import Image
8
  from transformers import AutoFeatureExtractor, DetrForObjectDetection, YolosForObjectDetection
9
- from keras_cv_attention_models.yolox import *
10
 
11
  import os
12
 
@@ -20,18 +19,12 @@ COLORS = [
20
  [0.301, 0.745, 0.933]
21
  ]
22
 
23
- def make_prediction(img, feature_extractor, model, model_name):
24
- if 'yolox' in model_name:
25
- inputs = feature_extractor(img)
26
- outputs = model(**inputs)
27
- processed_outputs = {}
28
- processed_outputs['boxes'], processed_outputs['labels'], processed_outputs['scores'] = model.decode_predictions(outputs)[0]
29
- else:
30
- inputs = feature_extractor(img, return_tensors="pt")
31
- outputs = model(**inputs)
32
- img_size = torch.tensor([tuple(reversed(img.size))])
33
- processed_outputs = feature_extractor.post_process(outputs, img_size)[0]
34
- return processed_outputs
35
 
36
  def fig2img(fig):
37
  buf = io.BytesIO()
@@ -60,23 +53,26 @@ def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
60
  return fig2img(plt.gcf())
61
 
62
  def detect_objects(model_name,url_input,image_input,threshold):
 
 
 
 
63
  if 'detr' in model_name:
 
64
  model = DetrForObjectDetection.from_pretrained(model_name)
65
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
66
  elif 'yolos' in model_name:
 
67
  model = YolosForObjectDetection.from_pretrained(model_name)
68
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
69
- elif 'yolox' in model_name:
70
- model = YOLOXL(pretrained="coco")
71
- feature_extractor = model.preprocess_input
72
 
73
  if validators.url(url_input):
74
- image = Image.open(requests.get(url_input, stream=True).raw)
 
75
  elif image_input:
76
  image = image_input
77
 
78
  #Make prediction
79
- processed_outputs = make_prediction(image, feature_extractor, model, model_name)
80
 
81
  #Visualize prediction
82
  viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
@@ -94,15 +90,13 @@ title = """<h1 id="title">Object Detection App with DETR and YOLOS</h1>"""
94
 
95
  description = """
96
  Links to HuggingFace Models:
97
-
98
  - [facebook/detr-resnet-50](https://huggingface.co/facebook/detr-resnet-50)
99
  - [facebook/detr-resnet-101](https://huggingface.co/facebook/detr-resnet-101)
100
  - [hustvl/yolos-small](https://huggingface.co/hustvl/yolos-small)
101
  - [hustvl/yolos-tiny](https://huggingface.co/hustvl/yolos-tiny)
102
-
103
  """
104
 
105
- models = ["facebook/detr-resnet-50","facebook/detr-resnet-101","hustvl/yolos-small","hustvl/yolos-tiny","yoloxl"]
106
  urls = ["https://c8.alamy.com/comp/J2AB4K/the-new-york-stock-exchange-on-the-wall-street-in-new-york-J2AB4K.jpg"]
107
 
108
  css = '''
@@ -114,7 +108,7 @@ demo = gr.Blocks(css=css)
114
 
115
  with demo:
116
  gr.Markdown(title)
117
- #gr.Markdown(description)
118
  options = gr.Dropdown(choices=models,label='Select Object Detection Model',show_label=True)
119
  slider_input = gr.Slider(minimum=0.2,maximum=1,value=0.5,label='Prediction Threshold')
120
 
@@ -131,8 +125,7 @@ with demo:
131
 
132
  with gr.TabItem('Image Upload'):
133
  with gr.Row():
134
- img_input = gr.Image()
135
- #img_input = gr.Image(type='pil')
136
  img_output_from_upload= gr.Image(shape=(650,650))
137
 
138
  with gr.Row():
 
6
  import pathlib
7
  from PIL import Image
8
  from transformers import AutoFeatureExtractor, DetrForObjectDetection, YolosForObjectDetection
 
9
 
10
  import os
11
 
 
19
  [0.301, 0.745, 0.933]
20
  ]
21
 
22
+ def make_prediction(img, feature_extractor, model):
23
+ inputs = feature_extractor(img, return_tensors="pt")
24
+ outputs = model(**inputs)
25
+ img_size = torch.tensor([tuple(reversed(img.size))])
26
+ processed_outputs = feature_extractor.post_process(outputs, img_size)
27
+ return processed_outputs[0]
 
 
 
 
 
 
28
 
29
  def fig2img(fig):
30
  buf = io.BytesIO()
 
53
  return fig2img(plt.gcf())
54
 
55
  def detect_objects(model_name,url_input,image_input,threshold):
56
+
57
+ #Extract model and feature extractor
58
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
59
+
60
  if 'detr' in model_name:
61
+
62
  model = DetrForObjectDetection.from_pretrained(model_name)
63
+
64
  elif 'yolos' in model_name:
65
+
66
  model = YolosForObjectDetection.from_pretrained(model_name)
 
 
 
 
67
 
68
  if validators.url(url_input):
69
+ image = Image.open(requests.get(url_input, stream=True).raw)
70
+
71
  elif image_input:
72
  image = image_input
73
 
74
  #Make prediction
75
+ processed_outputs = make_prediction(image, feature_extractor, model)
76
 
77
  #Visualize prediction
78
  viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
 
90
 
91
  description = """
92
  Links to HuggingFace Models:
 
93
  - [facebook/detr-resnet-50](https://huggingface.co/facebook/detr-resnet-50)
94
  - [facebook/detr-resnet-101](https://huggingface.co/facebook/detr-resnet-101)
95
  - [hustvl/yolos-small](https://huggingface.co/hustvl/yolos-small)
96
  - [hustvl/yolos-tiny](https://huggingface.co/hustvl/yolos-tiny)
 
97
  """
98
 
99
+ models = ["facebook/detr-resnet-50","facebook/detr-resnet-101",'hustvl/yolos-small','hustvl/yolos-tiny']
100
  urls = ["https://c8.alamy.com/comp/J2AB4K/the-new-york-stock-exchange-on-the-wall-street-in-new-york-J2AB4K.jpg"]
101
 
102
  css = '''
 
108
 
109
  with demo:
110
  gr.Markdown(title)
111
+ gr.Markdown(description)
112
  options = gr.Dropdown(choices=models,label='Select Object Detection Model',show_label=True)
113
  slider_input = gr.Slider(minimum=0.2,maximum=1,value=0.5,label='Prediction Threshold')
114
 
 
125
 
126
  with gr.TabItem('Image Upload'):
127
  with gr.Row():
128
+ img_input = gr.Image(type='pil')
 
129
  img_output_from_upload= gr.Image(shape=(650,650))
130
 
131
  with gr.Row():
app_bk.py DELETED
@@ -1,144 +0,0 @@
1
- import io
2
- import gradio as gr
3
- import matplotlib.pyplot as plt
4
- import requests, validators
5
- import torch
6
- import pathlib
7
- from PIL import Image
8
- from transformers import AutoFeatureExtractor, DetrForObjectDetection, YolosForObjectDetection
9
-
10
- import os
11
-
12
- # colors for visualization
13
- COLORS = [
14
- [0.000, 0.447, 0.741],
15
- [0.850, 0.325, 0.098],
16
- [0.929, 0.694, 0.125],
17
- [0.494, 0.184, 0.556],
18
- [0.466, 0.674, 0.188],
19
- [0.301, 0.745, 0.933]
20
- ]
21
-
22
- def make_prediction(img, feature_extractor, model):
23
- inputs = feature_extractor(img, return_tensors="pt")
24
- outputs = model(**inputs)
25
- img_size = torch.tensor([tuple(reversed(img.size))])
26
- processed_outputs = feature_extractor.post_process(outputs, img_size)
27
- return processed_outputs[0]
28
-
29
- def fig2img(fig):
30
- buf = io.BytesIO()
31
- fig.savefig(buf)
32
- buf.seek(0)
33
- img = Image.open(buf)
34
- return img
35
-
36
-
37
- def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
38
- keep = output_dict["scores"] > threshold
39
- boxes = output_dict["boxes"][keep].tolist()
40
- scores = output_dict["scores"][keep].tolist()
41
- labels = output_dict["labels"][keep].tolist()
42
- if id2label is not None:
43
- labels = [id2label[x] for x in labels]
44
-
45
- plt.figure(figsize=(16, 10))
46
- plt.imshow(pil_img)
47
- ax = plt.gca()
48
- colors = COLORS * 100
49
- for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
50
- ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=3))
51
- ax.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
52
- plt.axis("off")
53
- return fig2img(plt.gcf())
54
-
55
- def detect_objects(model_name,url_input,image_input,threshold):
56
-
57
- #Extract model and feature extractor
58
- feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
59
-
60
- if 'detr' in model_name:
61
-
62
- model = DetrForObjectDetection.from_pretrained(model_name)
63
-
64
- elif 'yolos' in model_name:
65
-
66
- model = YolosForObjectDetection.from_pretrained(model_name)
67
-
68
- if validators.url(url_input):
69
- image = Image.open(requests.get(url_input, stream=True).raw)
70
-
71
- elif image_input:
72
- image = image_input
73
-
74
- #Make prediction
75
- processed_outputs = make_prediction(image, feature_extractor, model)
76
-
77
- #Visualize prediction
78
- viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
79
-
80
- return viz_img
81
-
82
- def set_example_image(example: list) -> dict:
83
- return gr.Image.update(value=example[0])
84
-
85
- def set_example_url(example: list) -> dict:
86
- return gr.Textbox.update(value=example[0])
87
-
88
-
89
- title = """<h1 id="title">Object Detection App with DETR and YOLOS</h1>"""
90
-
91
- description = """
92
- Links to HuggingFace Models:
93
- - [facebook/detr-resnet-50](https://huggingface.co/facebook/detr-resnet-50)
94
- - [facebook/detr-resnet-101](https://huggingface.co/facebook/detr-resnet-101)
95
- - [hustvl/yolos-small](https://huggingface.co/hustvl/yolos-small)
96
- - [hustvl/yolos-tiny](https://huggingface.co/hustvl/yolos-tiny)
97
- """
98
-
99
- models = ["facebook/detr-resnet-50","facebook/detr-resnet-101",'hustvl/yolos-small','hustvl/yolos-tiny']
100
- urls = ["https://c8.alamy.com/comp/J2AB4K/the-new-york-stock-exchange-on-the-wall-street-in-new-york-J2AB4K.jpg"]
101
-
102
- css = '''
103
- h1#title {
104
- text-align: center;
105
- }
106
- '''
107
- demo = gr.Blocks(css=css)
108
-
109
- with demo:
110
- gr.Markdown(title)
111
- gr.Markdown(description)
112
- options = gr.Dropdown(choices=models,label='Select Object Detection Model',show_label=True)
113
- slider_input = gr.Slider(minimum=0.2,maximum=1,value=0.5,label='Prediction Threshold')
114
-
115
- with gr.Tabs():
116
- with gr.TabItem('Image URL'):
117
- with gr.Row():
118
- url_input = gr.Textbox(lines=2,label='Enter valid image URL here..')
119
- img_output_from_url = gr.Image(shape=(650,650))
120
-
121
- with gr.Row():
122
- example_url = gr.Dataset(components=[url_input],samples=[[str(url)] for url in urls])
123
-
124
- url_but = gr.Button('Detect')
125
-
126
- with gr.TabItem('Image Upload'):
127
- with gr.Row():
128
- img_input = gr.Image(type='pil')
129
- img_output_from_upload= gr.Image(shape=(650,650))
130
-
131
- with gr.Row():
132
- example_images = gr.Dataset(components=[img_input],
133
- samples=[[path.as_posix()]
134
- for path in sorted(pathlib.Path('images').rglob('*.JPG'))])
135
-
136
- img_but = gr.Button('Detect')
137
-
138
-
139
- url_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=img_output_from_url,queue=True)
140
- img_but.click(detect_objects,inputs=[options,url_input,img_input,slider_input],outputs=img_output_from_upload,queue=True)
141
- example_images.click(fn=set_example_image,inputs=[example_images],outputs=[img_input])
142
- example_url.click(fn=set_example_url,inputs=[example_url],outputs=[url_input])
143
-
144
- demo.launch(enable_queue=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -5,4 +5,4 @@ torch==1.10.1
5
  validators==0.18.2
6
  timm==0.5.4
7
  transformers
8
- keras_cv_attention_models==1.2.9
 
5
  validators==0.18.2
6
  timm==0.5.4
7
  transformers
8
+ #keras_cv_attention_models==1.2.9