keremberke commited on
Commit
10f3130
·
1 Parent(s): 307beef

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +20 -16
  2. utils.py +63 -1
app.py CHANGED
@@ -4,7 +4,7 @@ import gradio as gr
4
  from datasets import load_dataset
5
  from ultralyticsplus import YOLO, render_result, postprocess_classify_output
6
 
7
- from utils import load_models_from_txt_files
8
 
9
  EXAMPLE_IMAGE_DIR = 'example_images'
10
 
@@ -17,6 +17,7 @@ DEFAULT_CLS_DATASET_ID = 'keremberke/chest-xray-classification'
17
 
18
  # load model ids and default models
19
  det_model_ids, seg_model_ids, cls_model_ids = load_models_from_txt_files()
 
20
  det_model = YOLO(DEFAULT_DET_MODEL_ID)
21
  det_model_id = DEFAULT_DET_MODEL_ID
22
  seg_model = YOLO(DEFAULT_SEG_MODEL_ID)
@@ -25,22 +26,25 @@ cls_model = YOLO(DEFAULT_CLS_MODEL_ID)
25
  cls_model_id = DEFAULT_CLS_MODEL_ID
26
 
27
 
28
- def get_examples(model_id, dataset_id, task):
29
  examples = []
30
- ds = load_dataset(dataset_id, name="mini")["validation"]
31
  Path(EXAMPLE_IMAGE_DIR).mkdir(parents=True, exist_ok=True)
32
- for ind in range(min(5, len(ds))):
33
- jpeg_image_file = ds[ind]["image"]
34
- image_file_path = str(Path(EXAMPLE_IMAGE_DIR) / f"{task}_example_{ind}.jpg")
35
- jpeg_image_file.save(image_file_path, format='JPEG', quality=100)
36
- image_path = os.path.abspath(image_file_path)
37
- examples.append([image_path, model_id, 0.25])
 
 
 
38
  return examples
39
 
 
40
  # load default examples using default datasets
41
- det_examples = get_examples(DEFAULT_DET_MODEL_ID, DEFAULT_DET_DATASET_ID, 'detect')
42
- seg_examples = get_examples(DEFAULT_SEG_MODEL_ID, DEFAULT_SEG_DATASET_ID, 'segment')
43
- cls_examples = get_examples(DEFAULT_CLS_MODEL_ID, DEFAULT_CLS_DATASET_ID, 'classification')
44
 
45
 
46
  def predict(image, model_id, threshold):
@@ -120,12 +124,12 @@ with gr.Blocks() as demo:
120
  with gr.Column():
121
  detect_output = gr.Image(label="Predictions:", interactive=False)
122
  with gr.Row():
123
- gr.Examples(
124
  det_examples,
125
  inputs=[detect_input, detect_model_id, detect_threshold],
126
  outputs=detect_output,
127
  fn=predict,
128
- cache_examples=True,
129
  )
130
  with gr.Tab("Segmentation"):
131
  with gr.Row():
@@ -137,7 +141,7 @@ with gr.Blocks() as demo:
137
  with gr.Column():
138
  segment_output = gr.Image(label="Predictions:", interactive=False)
139
  with gr.Row():
140
- gr.Examples(
141
  seg_examples,
142
  inputs=[segment_input, segment_model_id, segment_threshold],
143
  outputs=segment_output,
@@ -156,7 +160,7 @@ with gr.Blocks() as demo:
156
  label="Predictions:", show_label=True, num_top_classes=5
157
  )
158
  with gr.Row():
159
- gr.Examples(
160
  cls_examples,
161
  inputs=[classify_input, classify_model_id, classify_threshold],
162
  outputs=classify_output,
 
4
  from datasets import load_dataset
5
  from ultralyticsplus import YOLO, render_result, postprocess_classify_output
6
 
7
+ from utils import load_models_from_txt_files, get_dataset_id_from_model_id, get_task_from_readme
8
 
9
  EXAMPLE_IMAGE_DIR = 'example_images'
10
 
 
17
 
18
  # load model ids and default models
19
  det_model_ids, seg_model_ids, cls_model_ids = load_models_from_txt_files()
20
+ task_to_model_ids = {'detect': det_model_ids, 'segment': seg_model_ids, 'classify': cls_model_ids}
21
  det_model = YOLO(DEFAULT_DET_MODEL_ID)
22
  det_model_id = DEFAULT_DET_MODEL_ID
23
  seg_model = YOLO(DEFAULT_SEG_MODEL_ID)
 
26
  cls_model_id = DEFAULT_CLS_MODEL_ID
27
 
28
 
29
+ def get_examples(task):
30
  examples = []
 
31
  Path(EXAMPLE_IMAGE_DIR).mkdir(parents=True, exist_ok=True)
32
+ for model_id in task_to_model_ids[task]:
33
+ dataset_id = get_dataset_id_from_model_id(model_id)
34
+ ds = load_dataset(dataset_id, name="mini")["validation"]
35
+ for ind in range(min(2, len(ds))):
36
+ jpeg_image_file = ds[ind]["image"]
37
+ image_file_path = str(Path(EXAMPLE_IMAGE_DIR) / f"{task}_example_{ind}.jpg")
38
+ jpeg_image_file.save(image_file_path, format='JPEG', quality=100)
39
+ image_path = os.path.abspath(image_file_path)
40
+ examples.append([image_path, model_id, 0.25])
41
  return examples
42
 
43
+
44
  # load default examples using default datasets
45
+ det_examples = get_examples('detect')
46
+ seg_examples = get_examples('segment')
47
+ cls_examples = get_examples('classify')
48
 
49
 
50
  def predict(image, model_id, threshold):
 
124
  with gr.Column():
125
  detect_output = gr.Image(label="Predictions:", interactive=False)
126
  with gr.Row():
127
+ detect_examples = gr.Examples(
128
  det_examples,
129
  inputs=[detect_input, detect_model_id, detect_threshold],
130
  outputs=detect_output,
131
  fn=predict,
132
+ cache_examples=False,
133
  )
134
  with gr.Tab("Segmentation"):
135
  with gr.Row():
 
141
  with gr.Column():
142
  segment_output = gr.Image(label="Predictions:", interactive=False)
143
  with gr.Row():
144
+ segment_examples = gr.Examples(
145
  seg_examples,
146
  inputs=[segment_input, segment_model_id, segment_threshold],
147
  outputs=segment_output,
 
160
  label="Predictions:", show_label=True, num_top_classes=5
161
  )
162
  with gr.Row():
163
+ classify_examples = gr.Examples(
164
  cls_examples,
165
  inputs=[classify_input, classify_model_id, classify_threshold],
166
  outputs=classify_output,
utils.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  DET_MODELS_FILENAME = 'det_models.txt'
2
  SEG_MODELS_FILENAME = 'seg_models.txt'
3
  CLS_MODELS_FILENAME = 'cls_models.txt'
@@ -11,4 +15,62 @@ def load_models_from_txt_files():
11
  seg_models = [line.strip() for line in file]
12
  with open(CLS_MODELS_FILENAME, 'r') as file:
13
  cls_models = [line.strip() for line in file]
14
- return det_models, seg_models, cls_models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import re
3
+
4
+
5
  DET_MODELS_FILENAME = 'det_models.txt'
6
  SEG_MODELS_FILENAME = 'seg_models.txt'
7
  CLS_MODELS_FILENAME = 'cls_models.txt'
 
15
  seg_models = [line.strip() for line in file]
16
  with open(CLS_MODELS_FILENAME, 'r') as file:
17
  cls_models = [line.strip() for line in file]
18
+ return det_models, seg_models, cls_models
19
+
20
+
21
+ def get_dataset_id_from_model_id(model_id):
22
+ """
23
+ Gets the dataset ID from the README file for a given Hugging Face model ID.
24
+
25
+ Args:
26
+ model_id (str): The Hugging Face model ID.
27
+
28
+ Returns:
29
+ The dataset ID as a string, or None if the dataset ID cannot be found.
30
+ """
31
+ # Define the URL of the README file for the model
32
+ readme_url = f"https://huggingface.co/{model_id}/raw/main/README.md"
33
+
34
+ # Make a GET request to the README URL and get the contents
35
+ response = requests.get(readme_url)
36
+ readme_contents = response.text
37
+
38
+ # Use regular expressions to search for the dataset ID in the README file
39
+ match = re.search(r"datasets:\s*\n- (\S+)", readme_contents)
40
+
41
+ # If a match is found, extract the dataset ID and return it. Otherwise, return None.
42
+ if match is not None:
43
+ dataset_id = match.group(1)
44
+ return dataset_id
45
+ else:
46
+ return None
47
+
48
+
49
+ def get_task_from_readme(model_id):
50
+ """
51
+ Gets the task from the README file for a given Hugging Face model ID.
52
+
53
+ Args:
54
+ model_id (str): The Hugging Face model ID.
55
+
56
+ Returns:
57
+ The task as a string ("detect", "segment", or "classify"), or None if the task cannot be found.
58
+ """
59
+ # Define the URL of the README file for the model
60
+ readme_url = f"https://huggingface.co/{model_id}/raw/main/README.md"
61
+
62
+ # Make a GET request to the README URL and get the contents
63
+ response = requests.get(readme_url)
64
+ readme_contents = response.text
65
+
66
+ # Use regular expressions to search for the task in the tags section of the README file
67
+ if re.search(r"tags:", readme_contents):
68
+ if re.search(r"object-detection", readme_contents):
69
+ return "detect"
70
+ elif re.search(r"image-segmentation", readme_contents):
71
+ return "segment"
72
+ elif re.search(r"image-classification", readme_contents):
73
+ return "classify"
74
+
75
+ # If the task cannot be found, return None
76
+ return None