piperod91 committed
Commit 730f5a5 · 1 Parent(s): 77dfadf

adding closest samples

Files changed (4)
  1. app.py +81 -23
  2. inference_resnet.py +15 -14
  3. inference_sam.py +5 -4
  4. pre-requirements.txt +2 -1
app.py CHANGED
@@ -9,6 +9,7 @@ if os.getenv('SYSTEM') == 'spaces':
     subprocess.call('pip install git+https://github.com/cocodataset/panopticapi.git'.split())
     subprocess.call('pip install python-dotenv'.split())
     subprocess.call('pip install torch torchvision '.split())
+    subprocess.call('pip install xplique'.split())
 
 import gradio as gr
 from huggingface_hub import snapshot_download
@@ -19,13 +20,38 @@ import numpy as np
 import gradio as gr
 import glob
 from inference_sam import segmentation_sam
-
+from explanations import explain
+from inference_resnet import get_triplet_model
 import pathlib
+import tensorflow as tf
+from closest_sample import get_images
 
 if not os.path.exists('images'):
     REPO_ID='Serrelab/image_examples_gradio'
     snapshot_download(repo_id=REPO_ID, token=os.environ.get('READ_TOKEN'),repo_type='dataset',local_dir='images')
 
+def get_model(model_name):
+
+    if model_name=='Mummified 170':
+        n_classes = 170
+        model = get_triplet_model(input_shape = (600, 600, 3),
+                                  embedding_units = 256,
+                                  embedding_depth = 2,
+                                  backbone_class=tf.keras.applications.ResNet50V2,
+                                  nb_classes = n_classes,load_weights=False,finer_model=True,backbone_name ='Resnet50v2')
+        model.load_weights('model_classification/mummified-170.h5')
+    elif model_name=='Rock 170':
+        n_classes = 171
+        model = get_triplet_model(input_shape = (600, 600, 3),
+                                  embedding_units = 256,
+                                  embedding_depth = 2,
+                                  backbone_class=tf.keras.applications.ResNet50V2,
+                                  nb_classes = n_classes,load_weights=False,finer_model=True,backbone_name ='Resnet50v2')
+        model.load_weights('model_classification/rock-170.h5')
+    else:
+        return 'Error'
+    return model,n_classes
 
 def segment_image(input_image):
     img = segmentation_sam(input_image)
@@ -34,24 +60,54 @@ def segment_image(input_image):
 def classify_image(input_image, model_name):
     if 'Rock 170' ==model_name:
         from inference_resnet import inference_resnet_finer
-        result = inference_resnet_finer(input_image,model_name,n_classes=171)
+        model,n_classes= get_model(model_name)
+        result = inference_resnet_finer(input_image,model,size=600,n_classes=n_classes)
         return result
     elif 'Mummified 170' ==model_name:
         from inference_resnet import inference_resnet_finer
-        result = inference_resnet_finer(input_image,model_name,n_classes=170)
+        model, n_classes= get_model(model_name)
+        result = inference_resnet_finer(input_image,model,size=600,n_classes=n_classes)
         return result
     if 'Fossils 19' ==model_name:
         from inference_beit import inference_dino
+        model,n_classes = get_model(model_name)
         return inference_dino(input_image,model_name)
     return None
 
-def find_closest(input_image):
-    return None
+def get_embeddings(input_image,model_name):
+    if 'Rock 170' ==model_name:
+        from inference_resnet import inference_resnet_embedding
+        model,n_classes= get_model(model_name)
+        result = inference_resnet_embedding(input_image,model,size=600,n_classes=n_classes)
+        return result
+    elif 'Mummified 170' ==model_name:
+        from inference_resnet import inference_resnet_embedding
+        model, n_classes= get_model(model_name)
+        result = inference_resnet_embedding(input_image,model,size=600,n_classes=n_classes)
+        return result
+    if 'Fossils 19' ==model_name:
+        from inference_beit import inference_dino
+        model,n_classes = get_model(model_name)
+        return inference_dino(input_image,model_name)
+    return None
+
+def find_closest(input_image,model_name):
+    embedding = get_embeddings(input_image,model_name)
+    paths = get_images(embedding)
+    return paths
+
+def explain_image(input_image,model_name):
+    model,n_classes= get_model(model_name)
+    saliency, integrated, smoothgrad = explain(model,input_image,n_classes=n_classes)
+    #original = saliency + integrated + smoothgrad
+    print('done')
+    return saliency, integrated, smoothgrad,
+
 
 with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
 
-    with gr.Tab(" 19 Classes Support"):
+    with gr.Tab(" Florrissant Fossils"):
 
         with gr.Row():
             with gr.Column():
@@ -64,10 +120,10 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
             #classify_segmented_button = gr.Button("Classify Segmented Image")
 
         with gr.Column():
-            drop_2 = gr.Dropdown(
-                ["Mummified 170", "Rock 170", "Fossils 19"],
+            model_name = gr.Dropdown(
+                ["Mummified 170", "Rock 170"],
                 multiselect=False,
-                value=["Rock 170"],
+                value="Rock 170",
                 label="Model",
                 interactive=True,
             )
@@ -81,24 +137,24 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
         samples=[[path.as_posix()] for path in paths if 'leaves' in str(path) ][:19]
        examples_leaves = gr.Examples(samples, inputs=input_image,examples_per_page=5,label='Leaves Examples from the dataset')
 
-        with gr.Accordion("Using Diffuser"):
-            with gr.Column():
-                prompt = gr.Textbox(lines=1, label="Prompt")
-                output_image = gr.Image(label="Output")
-                generate_button = gr.Button("Generate Leave")
-            with gr.Column():
-                class_predicted2 = gr.Label(label='Class Predicted from diffuser')
-                classify_button = gr.Button("Classify Image")
+        # with gr.Accordion("Using Diffuser"):
+        #     with gr.Column():
+        #         prompt = gr.Textbox(lines=1, label="Prompt")
+        #         output_image = gr.Image(label="Output")
+        #         generate_button = gr.Button("Generate Leave")
+        #     with gr.Column():
+        #         class_predicted2 = gr.Label(label='Class Predicted from diffuser')
+        #         classify_button = gr.Button("Classify Image")
 
 
         with gr.Accordion("Explanations "):
             gr.Markdown("Computing Explanations from the model")
             with gr.Row():
-                original_input = gr.Image(label="Original Frame")
+                #original_input = gr.Image(label="Original Frame")
                 saliency = gr.Image(label="saliency")
-                gradcam = gr.Image(label='gradcam')
-                guided_gradcam = gr.Image(label='guided gradcam')
-                guided_backprop = gr.Image(label='guided backprop')
+                gradcam = gr.Image(label='integraged gradients')
+                guided_gradcam = gr.Image(label='gradcam')
+                #guided_backprop = gr.Image(label='guided backprop')
             generate_explanations = gr.Button("Generate Explanations")
 
         with gr.Accordion('Closest Images'):
@@ -112,8 +168,10 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
             find_closest_btn = gr.Button("Find Closest Images")
 
     segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
-    classify_image_button.click(classify_image, inputs=[input_image,drop_2], outputs=class_predicted)
-    #classify_segmented_button.click(classify_image, inputs=[segmented_image,drop_2], outputs=class_predicted)
+    classify_image_button.click(classify_image, inputs=[input_image,model_name], outputs=class_predicted)
+    generate_explanations.click(explain_image, inputs=[input_image,model_name], outputs=[saliency,gradcam,guided_gradcam])
+    find_closest_btn.click(find_closest, inputs=[input_image,model_name], outputs=[closest_image_0,closest_image_1,closest_image_2,closest_image_3,closest_image_4])
+    #classify_segmented_button.click(classify_image, inputs=[segmented_image,model_name], outputs=class_predicted)
 
 demo.queue()
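
closest_sample.py itself is not included in this commit, so get_images can only be inferred from its call site: it takes the 256-d embedding produced by get_embeddings and returns five image paths for the closest_image_0..4 outputs. A minimal sketch of what such a lookup typically looks like, where the gallery files (embeddings.npy, paths.npy) and the L2 metric are assumptions, not part of the commit:

import numpy as np

# Assumed: gallery embeddings precomputed offline with the same triplet model.
gallery_embeddings = np.load('embeddings.npy')            # hypothetical file, shape (N, 256)
gallery_paths = np.load('paths.npy', allow_pickle=True)   # hypothetical file, shape (N,)

def get_images(embedding, n=5):
    # L2 distance from the query embedding to every gallery embedding,
    # then return the paths of the n nearest images.
    dists = np.linalg.norm(gallery_embeddings - np.asarray(embedding)[None, :], axis=1)
    nearest = np.argsort(dists)[:n]
    return [str(gallery_paths[i]) for i in nearest]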
inference_resnet.py CHANGED
@@ -12,9 +12,10 @@ from huggingface_hub import snapshot_download
 from labels import lookup_170
 import numpy as np
 
-REPO_ID='Serrelab/fossil_classification_models'
-snapshot_download(repo_id=REPO_ID, token=os.environ.get('READ_TOKEN'),repo_type='model',local_dir='model_classification')
+if not os.path.exists('model_classification'):
+    REPO_ID='Serrelab/fossil_classification_models'
+    snapshot_download(repo_id=REPO_ID, token=os.environ.get('READ_TOKEN'),repo_type='model',local_dir='model_classification')
 
 
 def get_model(base_arch='Nasnet',weights='imagenet',input_shape=(600,600,3),classes=64500):
@@ -146,19 +147,19 @@ def parse_results(top_n,logits):
         results[label] = float(logits[n])
     return results
 
-def inference_resnet_finer(x,type_model,size=576,n_classes=170,n_top=10):
-
-    model = get_triplet_model(input_shape = (size, size, 3),
-                              embedding_units = 256,
-                              embedding_depth = 2,
-                              backbone_class=tf.keras.applications.ResNet50V2,
-                              nb_classes = n_classes,load_weights=False,finer_model=True,backbone_name ='Resnet50v2')
-    if type_model=='Mummified 170':
-        model.load_weights('model_classification/mummified-170.h5')
-    elif type_model=='Rock 170':
-        model.load_weights('model_classification/rock-170.h5')
-    else:
-        return 'Error'
+def inference_resnet_embedding(x,model,size=576,n_classes=170,n_top=10):
+
+    cropped = _clever_crop(x,(size,size))[0]
+    prep = preprocess(cropped,size=size)
+    embedding = model.predict(np.array([prep]))[0][0]
+
+    return embedding
+
+def inference_resnet_finer(x,model,size=576,n_classes=170,n_top=10):
+
     cropped = _clever_crop(x,(size,size))[0]
     prep = preprocess(cropped,size=size)
     logits = tf.nn.softmax(model.predict(np.array([prep]))[1][0]).cpu().numpy()
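
Both helpers now receive an already-built model instead of reconstructing and reloading it on every call: the triplet network returns a pair of outputs, index 0 being the embedding head (predict(...)[0][0], used by the new inference_resnet_embedding) and index 1 the classification logits (predict(...)[1][0], used by inference_resnet_finer). A minimal usage sketch, assuming the weights were already downloaded into model_classification/ as above and using a placeholder image:

import numpy as np
import tensorflow as tf
from inference_resnet import get_triplet_model, inference_resnet_embedding, inference_resnet_finer

# Build once, reuse for both embedding and classification.
model = get_triplet_model(input_shape=(600, 600, 3),
                          embedding_units=256, embedding_depth=2,
                          backbone_class=tf.keras.applications.ResNet50V2,
                          nb_classes=171, load_weights=False,
                          finer_model=True, backbone_name='Resnet50v2')
model.load_weights('model_classification/rock-170.h5')

image = np.zeros((600, 600, 3), dtype=np.uint8)  # stand-in for a real fossil image
embedding = inference_resnet_embedding(image, model, size=600, n_classes=171)  # 256-d vector
scores = inference_resnet_finer(image, model, size=600, n_classes=171)         # top-10 {label: prob}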
inference_sam.py CHANGED
@@ -12,10 +12,11 @@ from math import ceil
 import os
 from huggingface_hub import snapshot_download
 
-REPO_ID='Serrelab/SAM_Leaves'
-snapshot_download(repo_id=REPO_ID, token=os.environ.get('READ_TOKEN'),repo_type='model',local_dir='model')
+if not os.path.exists('model'):
+    REPO_ID='Serrelab/SAM_Leaves'
+    snapshot_download(repo_id=REPO_ID, token=os.environ.get('READ_TOKEN'),repo_type='model',local_dir='model')
 
-sam = sam_model_registry["default"]("model/sam_02-06_dice_mse_0.pth")
+sam = sam_model_registry["default"]("/home/irodri15/Documents/Projects/Fossils/fossil_app/model/sam_02-06_dice_mse_0.pth")
 sam.cuda()
 predictor = SamPredictor(sam)
@@ -172,4 +173,4 @@ def segmentation_sam(x,SIZE=384):
     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
     plt.close()
-    return data
+    return data
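
Note that the new checkpoint path is absolute to a developer machine, so the load will fail on Spaces, where the snapshot lands in the repo-local model/ directory. A hedged sketch of a portable alternative, assuming the checkpoint keeps the same file name (this is a suggestion, not what the commit ships):

import os
from segment_anything import sam_model_registry

# Resolve the checkpoint relative to this file so the app works both
# locally and on Spaces, wherever the process is launched from.
CKPT = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                    'model', 'sam_02-06_dice_mse_0.pth')
sam = sam_model_registry["default"](CKPT)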
pre-requirements.txt CHANGED
@@ -3,4 +3,5 @@ opencv-python-headless==4.5.5.64
 openmim==0.1.5
 torch==1.11.0
 torchvision==0.12.0
-tensorflow==2.8
+tensorflow==2.8
+xplique
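
xplique is the TensorFlow attribution library behind the new explanations.explain call in app.py (it is also pip-installed at runtime on Spaces, and is left unpinned here). explanations.py itself is not part of this diff; a plausible sketch of what explain might look like using xplique's Saliency, IntegratedGradients, and SmoothGrad classes, where the logits-head wrapper, the resizing, and all parameter choices are assumptions:

import numpy as np
import tensorflow as tf
from xplique.attributions import Saliency, IntegratedGradients, SmoothGrad

def explain(model, image, n_classes=170, size=600):
    # Assumed: the triplet model outputs [embedding, logits]; xplique wants
    # a single-output classifier, so wrap the logits head.
    clf = tf.keras.Model(model.input, model.output[1])
    x = tf.image.resize(np.array([image], dtype=np.float32), (size, size))
    targets = tf.one_hot(np.argmax(clf.predict(x), axis=-1), n_classes)
    # One heatmap per method, matching explain_image's three outputs.
    return (np.array(Saliency(clf)(x, targets))[0],
            np.array(IntegratedGradients(clf)(x, targets))[0],
            np.array(SmoothGrad(clf)(x, targets))[0])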