Yuxiang Wang committed
Commit 0c61c42 · 1 Parent(s): 8c11be1

feat:add beit,rise xai;display closest imgs with gallery

Files changed (8)
  1. app.py +45 -29
  2. closest_sample.py +1 -0
  3. env.py +1 -0
  4. explanations.py +29 -18
  5. fossils_paths.csv +0 -0
  6. inference_beit.py +100 -186
  7. labels.py +144 -0
  8. update_csv.py +10 -0
app.py CHANGED
@@ -18,6 +18,7 @@ import glob
from inference_sam import segmentation_sam
from explanations import explain
from inference_resnet import get_triplet_model
+from inference_beit import get_triplet_model_beit
import pathlib
import tensorflow as tf
from closest_sample import get_images
@@ -26,6 +27,14 @@ if not os.path.exists('images'):
    REPO_ID='Serrelab/image_examples_gradio'
    snapshot_download(repo_id=REPO_ID, token=os.environ.get('READ_TOKEN'),repo_type='dataset',local_dir='images')

+if not os.path.exists('dataset'):
+    REPO_ID='Serrelab/Fossils'
+    token = os.environ.get('READ_TOKEN')
+    print(f"Read token:{token}")
+    if token is None:
+        print("warning! A read token in env variables is needed for authentication.")
+    snapshot_download(repo_id=REPO_ID, token=token,repo_type='dataset',local_dir='dataset')
+
def get_model(model_name):


@@ -45,25 +54,24 @@ def get_model(model_name):
                                  backbone_class=tf.keras.applications.ResNet50V2,
                                  nb_classes = n_classes,load_weights=False,finer_model=True,backbone_name ='Resnet50v2')
        model.load_weights('model_classification/rock-170.h5')
+    elif model_name == 'Fossils 142':
+        n_classes = 142
+        model = get_triplet_model_beit(input_shape = (384, 384, 3),
+                                       embedding_units = 256,
+                                       embedding_depth = 2,
+                                       n_classes = n_classes)
+        model.load_weights('model_classification/fossil-142.h5')
    else:
        raise ValueError(f"Model name '{model_name}' is not recognized")
    return model,n_classes

-'''
-elif model_name == 'Fossils 19':
-    n_classes = 19 or 23?
-    model = get_beit_model(input_shape=(600, 600, 3),
-                           num_labels=n_classes,
-                           load_weights=False,
-                           )
-    model.load_weights('model_classification/beit-fossils-19.h5')
-'''

def segment_image(input_image):
    img = segmentation_sam(input_image)
    return img

def classify_image(input_image, model_name):
+    #segmented_image = segment_image(input_image)
    if 'Rock 170' ==model_name:
        from inference_resnet import inference_resnet_finer
        model,n_classes= get_model(model_name)
@@ -74,10 +82,10 @@ def classify_image(input_image, model_name):
        model, n_classes= get_model(model_name)
        result = inference_resnet_finer(input_image,model,size=600,n_classes=n_classes)
        return result
-    if 'Fossils 19' ==model_name:
-        from inference_beit import inference_dino
+    if 'Fossils 142' ==model_name:
+        from inference_beit import inference_resnet_finer_beit
        model,n_classes = get_model(model_name)
-        result = inference_dino(input_image,model_name)
+        result = inference_resnet_finer_beit(input_image,model,size=384,n_classes=n_classes)
        return result
    return None

@@ -92,12 +100,10 @@ def get_embeddings(input_image,model_name):
        model, n_classes= get_model(model_name)
        result = inference_resnet_embedding(input_image,model,size=600,n_classes=n_classes)
        return result
-    if 'Fossils 19' ==model_name:
-        from inference_beit import inference_dino
+    if 'Fossils 142' ==model_name:
+        from inference_beit import inference_resnet_embedding_beit
        model,n_classes = get_model(model_name)
-        result = inference_dino(input_image,model_name)
-        #TODO
-        #result = inference_beit_embedding
+        result = inference_resnet_embedding_beit(input_image,model,size=384,n_classes=n_classes)
        return result
    return None

@@ -110,11 +116,16 @@ def find_closest(input_image,model_name):

def explain_image(input_image,model_name):
    model,n_classes= get_model(model_name)
+    if model_name=='Fossils 142':
+        size = 384
+    else:
+        size = 600
    #saliency, integrated, smoothgrad,
-    rise = explain(model,input_image,n_classes=n_classes)
+    rise,avg = explain(model,input_image,size = size, n_classes=n_classes)
    #original = saliency + integrated + smoothgrad
    print('done')
-    return rise
+    rise1,rise2,rise3,rise4,rise5,avg = rise[0],rise[1],rise[2],rise[3],rise[4],avg[0]
+    return rise1,rise2,rise3,rise4,rise5,avg

#minimalist theme
with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
@@ -126,17 +137,17 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
            input_image = gr.Image(label="Input")
            classify_image_button = gr.Button("Classify Image")

-        with gr.Column():
-            #segmented_image = gr.outputs.Image(label="SAM output",type='numpy')
-            segmented_image=gr.Image(label="Segmented Image", type='numpy')
-            segment_button = gr.Button("Segment Image")
-            #classify_segmented_button = gr.Button("Classify Segmented Image")
+        # with gr.Column():
+        #     #segmented_image = gr.outputs.Image(label="SAM output",type='numpy')
+        #     segmented_image=gr.Image(label="Segmented Image", type='numpy')
+        #     segment_button = gr.Button("Segment Image")
+        #     #classify_segmented_button = gr.Button("Classify Segmented Image")

        with gr.Column():
            model_name = gr.Dropdown(
-                ["Mummified 170", "Rock 170","Fossils 19"],
+                ["Mummified 170", "Rock 170","Fossils 142"],
                multiselect=False,
-                value="Rock 170", # default option
+                value="Fossils 142", # default option
                label="Model",
                interactive=True,
            )
@@ -168,7 +179,12 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
            #gradcam = gr.Image(label='integraged gradients')
            #guided_gradcam = gr.Image(label='gradcam')
            #guided_backprop = gr.Image(label='guided backprop')
-            rise = gr.Image(label = 'Rise')
+            rise1 = gr.Image(label = 'Rise1')
+            rise2 = gr.Image(label = 'Rise2')
+            rise3 = gr.Image(label = 'Rise3')
+            rise4 = gr.Image(label = 'Rise4')
+            rise5 = gr.Image(label = 'Rise5')
+            avg = gr.Image(label = 'Avg')
            generate_explanations = gr.Button("Generate Explanations")

            # with gr.Accordion('Closest Images'):
@@ -199,9 +215,9 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:

    find_closest_btn = gr.Button("Find Closest Images")

-    segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
+    #segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
    classify_image_button.click(classify_image, inputs=[input_image,model_name], outputs=class_predicted)
-    generate_explanations.click(explain_image, inputs=[input_image,model_name], outputs=[rise]) #saliency,gradcam,guided_gradcam
+    generate_explanations.click(explain_image, inputs=[input_image,model_name], outputs=[rise1,rise2,rise3,rise4,rise5,avg]) #
    #find_closest_btn.click(find_closest, inputs=[input_image,model_name], outputs=[label_closest_image_0,label_closest_image_1,label_closest_image_2,label_closest_image_3,label_closest_image_4,closest_image_0,closest_image_1,closest_image_2,closest_image_3,closest_image_4])
    def update_outputs(input_image,model_name):
        labels, images = find_closest(input_image,model_name)
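
Usage sketch (illustrative, not part of this commit) of the new 'Fossils 142' classification path added above, run outside the Gradio UI. It assumes the model_classification/fossil-142.h5 checkpoint and the local inference_beit module are available; the example image file name is hypothetical.

import numpy as np
from PIL import Image

from inference_beit import get_triplet_model_beit, inference_resnet_finer_beit

# mirrors get_model('Fossils 142') above
model = get_triplet_model_beit(input_shape=(384, 384, 3), embedding_units=256,
                               embedding_depth=2, n_classes=142)
model.load_weights('model_classification/fossil-142.h5')  # assumed to be present locally

img = np.array(Image.open('example_leaf.jpg'))  # hypothetical test image
# mirrors classify_image(): returns {family name: softmax score} for the top predictions
result = inference_resnet_finer_beit(img, model, size=384, n_classes=142)
print(result)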
closest_sample.py CHANGED
@@ -77,6 +77,7 @@ def get_images(embedding):
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/', folder_general)
        else:
            print("no match found")
+       print(public_path)
        download_public_image(public_path, local_file_path)
        names = []
        parts = [part for part in public_path.split('/') if part]
env.py CHANGED
@@ -17,6 +17,7 @@ def config_env():
        ('xplique', None),
        ('segment_anything', None),
        ('panopticapi', None),
+       ('keras_cv_attention_models',None)
    ]

    name_to_command = {'segment_anything':'git+https://github.com/facebookresearch/segment-anything.git',
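
Note (illustrative, not part of this commit): the new list entry makes config_env() also install the BEiT backbone dependency. Assuming entries without a custom command in name_to_command fall back to a plain pip install, the manual equivalent would be:

import subprocess, sys

# installs the keras_cv_attention_models package used by inference_beit.py
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'keras_cv_attention_models'])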
explanations.py CHANGED
@@ -54,35 +54,46 @@ def explain(model, input_image,size=600, n_classes=171) :
                 #IntegratedGradients(class_model, steps=50, batch_size=BATCH_SIZE),
                 #SmoothGrad(class_model, nb_samples=50, batch_size=BATCH_SIZE),
                 #GradCAM(class_model),
-                 Rise(class_model,nb_samples = 50, batch_size = BATCH_SIZE,grid_size=7,
+                 Rise(class_model,nb_samples = 50, batch_size = BATCH_SIZE,grid_size=15,
                      preservation_probability=0.5)
                 #
-    ]
+    ]
+    explainer = Rise(class_model,nb_samples = 50, batch_size = BATCH_SIZE,grid_size=15,
+                     preservation_probability=0.5)
+
    cropped,repetitions = _clever_crop(input_image,(size,size))
    size_repetitions = int(size//(repetitions.numpy()+1))
    X = preprocess(cropped,size=size)
-    Y = np.argmax(class_model.predict(np.array([X])))
+    predictions = class_model.predict(np.array([X]))
+    #Y = np.argmax(predictions)
+    top_5_indices = np.argsort(predictions[0])[-5:][::-1]
+    #print(top_5_indices)
    X = np.expand_dims(X, 0)
    explanations = []
-    for e,explainer in enumerate(explainers):
-        print(f'{e}/{len(explainers)}')
-        print('1')
+    for i,Y in enumerate(top_5_indices):
        Y = tf.one_hot([Y], n_classes)
+        print(f'{i}/{len(top_5_indices)}')
        phi = np.abs(explainer(X, Y))[0]
-        print('1')
        if len(phi.shape) == 3:
            phi = np.mean(phi, -1)
-        print('1')
        show(X[0][:,size_repetitions:2*size_repetitions,:])
        show(phi[:,size_repetitions:2*size_repetitions], p=1, alpha=0.4)
-        print('1')
-        plt.savefig(f'phi_{e}.png')
-        print('1')
-        explanations.append(f'phi_{e}.png')
-        print('1')
-        print(type(explanations))
-        print(len(explanations))
-
-        print('Done')
+        plt.savefig(f'phi_{i}.png')
+        explanations.append(f'phi_{i}.png')
+    avg=[]
+    for i,Y in enumerate(top_5_indices):
+        Y = tf.one_hot([Y], n_classes)
+        print(f'{i}/{len(top_5_indices)}')
+        phi = np.abs(explainer(X, Y))[0]
+        if len(phi.shape) == 3:
+            phi = np.mean(phi, -1)
+        show(X[0][:,size_repetitions:2*size_repetitions,:])
+        show(phi[:,size_repetitions:2*size_repetitions], p=1, alpha=0.4)
+        plt.savefig(f'phi_6.png')
+        avg.append(f'phi_6.png')
+
+    print('Done')
+    if len(explanations)==1:
+        explanations = explanations[0]

-    return explanations
+    return explanations,avg
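
Usage sketch (illustrative, not part of this commit) of the new explain() contract, matching how explain_image() in app.py unpacks it; model, input_image and n_classes are assumed to come from get_model() and the Gradio input.

from explanations import explain

# size is 384 for the BEiT-based 'Fossils 142' model, 600 for the ResNet models
rise, avg = explain(model, input_image, size=384, n_classes=n_classes)

# rise holds the saved RISE overlays for the top-5 predicted classes (phi_0.png ... phi_4.png);
# avg collects the extra overlay saved as phi_6.png, and app.py displays avg[0]
rise1, rise2, rise3, rise4, rise5, avg_img = rise[0], rise[1], rise[2], rise[3], rise[4], avg[0]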
fossils_paths.csv CHANGED
The diff for this file is too large to render. See raw diff
 
inference_beit.py CHANGED
@@ -9,195 +9,109 @@ import os
import numpy as np
import keras
from PIL import Image
-import keras_cv
from keras_cv_attention_models import beit
import matplotlib.pyplot as plt

-
-#preprocessing
-#TODO
-num_classes = len(class_names)
-AUTO = tf.data.AUTOTUNE
-rand_augment = keras_cv.layers.RandAugment(value_range = (-1, 1), augmentations_per_image = 3, magnitude=0.5)
-
-SIZE = 384
-debug = None
-
-def augmentations(x, crop_size=22, brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2):
-    x = tf.cast(x, tf.float32)
-    x = tf.image.random_crop(x, (tf.shape(x)[0], 100, 100, 3))
-    x = tf.image.random_brightness(x, max_delta=brightness)
-    x = tf.image.random_contrast(x, lower=1.0-contrast, upper=1+contrast)
-    x = tf.image.random_saturation(x, lower=1.0-saturation, upper=1.0+saturation)
-    x = tf.image.random_hue(x, max_delta=hue)
-    x = tf.image.resize(x, (128, 128))
-    x = tf.clip_by_value(x, 0.0, 255.0)
-    x = tf.keras.applications.resnet_v2.preprocess_input(x)
-    return x
-
-
-def pad_gt(x):
-    h, w = x.shape[-2:]
-    padh = sam.image_encoder.img_size - h
-    padw = sam.image_encoder.img_size - w
-    x = F.pad(x, (0, padw, 0, padh))
-    return x
-
-def preprocess(img):
-
-    img = np.array(img).astype(np.uint8)
-
-    #assert img.max() > 127.0
-
-    img_preprocess = predictor.transform.apply_image(img)
-    intermediate_shape = img_preprocess.shape
-
-    img_preprocess = torch.as_tensor(img_preprocess).cuda()
-    img_preprocess = img_preprocess.permute(2, 0, 1).contiguous()[None, :, :, :]
-
-    img_preprocess = sam.preprocess(img_preprocess)
-    if len(intermediate_shape) == 3:
-        intermediate_shape = intermediate_shape[:2]
-    elif len(intermediate_shape) == 4:
-        intermediate_shape = intermediate_shape[1:3]
-
-    return img_preprocess, intermediate_shape
-
-
-
-def normalize(img):
-    img = img - tf.math.reduce_min(img)
-    img = img / tf.math.reduce_max(img)
-    img = img * 2.0 - 1.0
-    return img
-
-def smooth_mask(mask, ds=20):
-    shape = tf.shape(mask)
-    w, h = shape[0], shape[1]
-    return tf.image.resize(tf.image.resize(mask, (ds, ds), method="bicubic"), (w, h), method="bicubic")
-
-def resize(img):
-    # default resize function for all pi outputs
-    return tf.image.resize(img, (SIZE, SIZE), method="bicubic")
-
-def pi(img, mask):
-    img = tf.cast(img, tf.float32)
-
-    shape = tf.shape(img)
-    w, h = tf.cast(shape[0], tf.int64), tf.cast(shape[1], tf.int64)
-
-    mask = smooth_mask(mask)
-    mask = tf.reduce_mean(mask, -1)
-
-    img = img * tf.cast(mask > 0.1, tf.float32)[:, :, None]
-
-    img_resize = tf.image.resize(img, (SIZE, SIZE), method="bicubic", antialias=True)
-    img_pad = tf.image.resize_with_pad(img, SIZE, SIZE, method="bicubic", antialias=True)
-
-    # building 2 anchors
-    anchors = tf.where(mask > 0.15)
-    anchor_xmin = tf.math.reduce_min(anchors[:, 0])
-    anchor_xmax = tf.math.reduce_max(anchors[:, 0])
-    anchor_ymin = tf.math.reduce_min(anchors[:, 1])
-    anchor_ymax = tf.math.reduce_max(anchors[:, 1])
-
-    if anchor_xmax - anchor_xmin > 50 and anchor_ymax - anchor_ymin > 50:
-
-        img_anchor_1 = resize(img[anchor_xmin:anchor_xmax, anchor_ymin:anchor_ymax])
-
-        delta_x = (anchor_xmax - anchor_xmin) // 4
-        delta_y = (anchor_ymax - anchor_ymin) // 4
-        img_anchor_2 = img[anchor_xmin+delta_x:anchor_xmax-delta_x,
-                           anchor_ymin+delta_y:anchor_ymax-delta_y]
-        img_anchor_2 = resize(img_anchor_2)
-    else:
-        img_anchor_1 = img_resize
-        img_anchor_2 = img_pad
-
-    # building the anchors max
-    anchor_max = tf.where(mask == tf.math.reduce_max(mask))[0]
-    anchor_max_x, anchor_max_y = anchor_max[0], anchor_max[1]
-
-    img_max_zoom1 = img[tf.math.maximum(anchor_max_x-SIZE, 0): tf.math.minimum(anchor_max_x+SIZE, w),
-                        tf.math.maximum(anchor_max_y-SIZE, 0): tf.math.minimum(anchor_max_y+SIZE, h)]
-
-    img_max_zoom1 = resize(img_max_zoom1)
-    img_max_zoom2 = img[anchor_max_x-SIZE//2:anchor_max_x+SIZE//2,
-                        anchor_max_y-SIZE//2:anchor_max_y+SIZE//2]
-    img_max_zoom2 = img[tf.math.maximum(anchor_max_x-SIZE//2, 0): tf.math.minimum(anchor_max_x+SIZE//2, w),
-                        tf.math.maximum(anchor_max_y-SIZE//2, 0): tf.math.minimum(anchor_max_y+SIZE//2, h)]
-    #tf.print(img_max_zoom2.shape)
-    #img_max_zoom2 = resize(img_max_zoom2)
-
-    return tf.cast(img_resize, tf.float32)
-
-def parse_img(element, split, randaugment,maskaugment=True):
-    #global debug
-    path, class_id = element[0], element[1]
-
-    data = tf.io.read_file(path)
-    img = tf.io.decode_jpeg(data)
-    img = tf.cast(img, tf.uint8)
-    img = normalize(img)
-    shape = tf.shape(img)
-
-    # data_mask = tf.io.read_file(path_mask)
-    # mask = tf.io.decode_jpeg(data_mask)
-
-    class_id = tf.strings.to_number(class_id)
-    class_id = tf.cast(class_id, tf.int32)
-
-    label = tf.one_hot(class_id, num_classes)
-
-    # img = pi(img, mask)
-    img = tf.image.resize_with_pad(img, SIZE, SIZE, method="bicubic", antialias=True)
-
-    return tf.cast(img, tf.float32), tf.cast(label, tf.int32)
-
-SIZE = 384
-wsize=hsize=SIZE
-def resize_images(batch_x, width=224, height=224):
-    return tf.image.resize(batch_x, (width, height))
-
-def load_img(image_path,gray=False):
-    img = tf.io.read_file(image_path)
-    img = tf.image.decode_jpeg(img, channels=3)
-    img = tf.image.convert_image_dtype(img, tf.float32)
-    if gray:
+import tensorflow as tf
+from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
+from typing import Tuple
+#from huggingface_hub import snapshot_download
+from labels import lookup_140
+
+
+def get_triplet_model_beit(input_shape = (600, 600, 3),
+                           embedding_units = 256,
+                           embedding_depth = 2,
+                           n_classes = 19,backbone_name ='Beit'):
+
+    backbone_class = beit.BeitBasePatch16(input_shape=input_shape, pretrained="imagenet21k-ft1k")
+
+    backbone_class = tf.keras.Model(backbone_class.input, backbone_class.layers[-2].output)
+    #features = GlobalAveragePooling2D()(backbone_class.output)
+    embedding_head = backbone_class.output
+
+    for embed_i in range(embedding_depth):
+        embedding_head = Dense(embedding_units, activation="relu" if embed_i < embedding_depth-1 else "linear")(embedding_head)
+    embedding_head = tf.nn.l2_normalize(embedding_head, -1, epsilon=1e-5)
+
+    logits_head = Dense(n_classes)(backbone_class.output)
+
+    model = tf.keras.Model(backbone_class.input, [embedding_head, logits_head])
+    model.compile(loss='cce',metrics=['accuracy'])
+    #model.summary()
+
+    return model
+
+
+
+
+load_size = 600
+crop_size = 600
+def _clever_crop(img: tf.Tensor,
+                 target_size: Tuple[int]=(128,128),
+                 grayscale: bool=False
+                 ) -> tf.Tensor:
+    """[summary]
+    Args:
+        img (tf.Tensor): [description]
+        target_size (Tuple[int], optional): [description]. Defaults to (128,128).
+        grayscale (bool, optional): [description]. Defaults to False.
+    Returns:
+        tf.Tensor: [description]
+    """
+    maxside = tf.math.maximum(tf.shape(img)[0],tf.shape(img)[1])
+    minside = tf.math.minimum(tf.shape(img)[0],tf.shape(img)[1])
+    new_img = img
+
+    if tf.math.divide(maxside,minside) > 1.2:
+        repeating = tf.math.floor(tf.math.divide(maxside,minside))
+        new_img = img
+        if tf.math.equal(tf.shape(img)[1],minside):
+            for _ in range(int(repeating)):
+                new_img = tf.concat((new_img, img), axis=1)
+
+        if tf.math.equal(tf.shape(img)[0],minside):
+            for _ in range(int(repeating)):
+                new_img = tf.concat((new_img, img), axis=0)
+            new_img = tf.image.rot90(new_img)
+    else:
+        new_img = img
+        repeating = 0
+    img = tf.image.resize(new_img, target_size)
+    if grayscale:
        img = tf.image.rgb_to_grayscale(img)
        img = tf.image.grayscale_to_rgb(img)
-    img = tf.image.resize(img,(wsize,hsize))
-    return img
-
-LR = 1e-3
-
-optimizer = tf.keras.optimizers.Adam(LR)
-cce = tf.keras.losses.categorical_crossentropy
-
-model_path = '/content/drive/MyDrive/Gg_Fossils_data_shared_copy/Fossils/models/model-13.h5'
-model = keras.models.load_model(model_path, custom_objects = {'cce': cce})
-
-outputs = model.predict(images)
-
-predictions = tf.math.top_k(outputs[1], k = 5)
-cid = 1
-dataset = np.array(dataset)
-final_predictions = []
-for ele in predictions[1]:
-    if cid in ele:
-        final_predictions.append(cid)
-    else:
-        final_predictions.append(cid+10)
-final_predictions = np.array(final_predictions)
-images2 = images[final_predictions == cid]
-image2_paths = dataset[final_predictions == cid][:,0]
-print(images2.shape)
-
-def get_beit_model(input_shape, num_labels, load_weights=False, ...):
-    pass
-
-def inference_dino(input_image, model_name):
-    pass

-def inference_beit_embedding(input_image, model, size=600):
-    pass
+    return img,repeating
+
+def preprocess(img,size=384):
+    img = np.array(img, np.float32) / 255.0
+    img = tf.image.resize(img, (size, size))
+    return np.array(img, np.float32)
+
+def select_top_n(preds,n=10):
+    top_n = np.argsort(preds)[-n:][::-1]
+    return top_n
+
+def parse_results(top_n,logits):
+    results = {}
+    for n in top_n:
+        label = lookup_140[n]
+        results[label] = float(logits[n])
+    return results
+
+def inference_resnet_embedding_beit(x,model,size=576,n_classes=142,n_top=10):
+    cropped = _clever_crop(x,(size,size))[0]
+    prep = preprocess(cropped,size=size)
+    embedding = model.predict(np.array([prep]))[0][0]
+
+
+    return embedding
+
+def inference_resnet_finer_beit(x,model,size=576,n_classes=142,n_top=10):
+    cropped = _clever_crop(x,(size,size))[0]
+    prep = preprocess(cropped,size=size)
+    logits = tf.nn.softmax(model.predict(np.array([prep]))[1][0]).cpu().numpy()
+    top_n = select_top_n(logits,n=n_top)
+
+    return parse_results(top_n,logits)
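
Illustrative sketch (not part of this commit) of the new embedding path that feeds the closest-image gallery, mirroring get_embeddings() and find_closest() in app.py. The weight file and the example image path are assumptions.

import numpy as np
from PIL import Image

from inference_beit import get_triplet_model_beit, inference_resnet_embedding_beit
from closest_sample import get_images

model = get_triplet_model_beit(input_shape=(384, 384, 3), embedding_units=256,
                               embedding_depth=2, n_classes=142)
model.load_weights('model_classification/fossil-142.h5')  # assumed present, as in app.py

img = np.array(Image.open('example_fossil.jpg'))  # hypothetical input
# 256-d L2-normalized output of the triplet head (first output of the model)
embedding = inference_resnet_embedding_beit(img, model, size=384, n_classes=142)
closest = get_images(embedding)  # nearest gallery samples; see closest_sample.py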
labels.py CHANGED
@@ -173,3 +173,147 @@ lookup_170 = {0: 'Anacardiaceae',
dict_lu ={}
for i in range(171):
    dict_lu[i] = lookup_170[i]
+
+
+lookup_140 = {0: 'Anacardiaceae',
+              1: 'Berberidaceae',
+              2: 'Betulaceae',
+              3: 'Cupressaceae',
+              4: 'Dryopteridaceae',
+              5: 'Fabaceae',
+              6: 'Fagaceae',
+              7: 'Juglandaceae',
+              8: 'Lauraceae',
+              9: 'Meliaceae',
+              10: 'Myrtaceae',
+              11: 'Pinaceae',
+              12: 'Rhamnaceae',
+              13: 'Rosaceae',
+              14: 'Salicaceae',
+              15: 'Sapindaceae',
+              16: 'Ulmaceae',
+              17: 'Viburnaceae',
+              18: 'Vitaceae',
+              19: 'Araceae',
+              20: 'Grossulariaceae',
+              21: 'Hydrangeaceae',
+              22: 'Taxaceae',
+              23: 'Achariaceae',
+              24: 'Actinidiaceae',
+              25: 'Altingiaceae',
+              26: 'Amaranthaceae',
+              27: 'Annonaceae',
+              28: 'Apiaceae',
+              29: 'Apocynaceae',
+              30: 'Aquifoliaceae',
+              31: 'Araliaceae',
+              32: 'Aristolochiaceae',
+              33: 'Asteraceae',
+              34: 'Bignoniaceae',
+              35: 'Boraginaceae',
+              36: 'Burseraceae',
+              37: 'Buxaceae',
+              38: 'Calophyllaceae',
+              39: 'Calycanthaceae',
+              40: 'Campanulaceae',
+              41: 'Canellaceae',
+              42: 'Cannabaceae',
+              43: 'Capparaceae',
+              44: 'Caprifoliaceae',
+              45: 'Cardiopteridaceae',
+              46: 'Celastraceae',
+              47: 'Chloranthaceae',
+              48: 'Chrysobalanaceae',
+              49: 'Clusiaceae',
+              50: 'Combretaceae',
+              51: 'Connaraceae',
+              52: 'Coriariaceae',
+              53: 'Cornaceae',
+              54: 'Crassulaceae',
+              55: 'Cucurbitaceae',
+              56: 'Cunoniaceae',
+              57: 'Dilleniaceae',
+              58: 'Dipterocarpaceae',
+              59: 'Ebenaceae',
+              60: 'Elaeagnaceae',
+              61: 'Elaeocarpaceae',
+              62: 'Ericaceae',
+              63: 'Escalloniaceae',
+              64: 'Euphorbiaceae',
+              65: 'Garryaceae',
+              66: 'Geraniaceae',
+              67: 'Gesneriaceae',
+              68: 'Gnetaceae',
+              69: 'Hamamelidaceae',
+              70: 'Humiriaceae',
+              71: 'Hypericaceae',
+              72: 'Icacinaceae',
+              73: 'Iteaceae',
+              74: 'Ixonanthaceae',
+              75: 'Lamiaceae',
+              76: 'Lardizabalaceae',
+              77: 'Lecythidaceae',
+              78: 'Linaceae',
+              79: 'Loganiaceae',
+              80: 'Loranthaceae',
+              81: 'Lythraceae',
+              82: 'Magnoliaceae',
+              83: 'Malpighiaceae',
+              84: 'Malvaceae',
+              85: 'Marantaceae',
+              86: 'Melastomataceae',
+              87: 'Menispermaceae',
+              88: 'Monimiaceae',
+              89: 'Moraceae',
+              90: 'Myricaceae',
+              91: 'Myristicaceae',
+              92: 'Nothofagaceae',
+              93: 'Nyctaginaceae',
+              94: 'Nyssaceae',
+              95: 'Ochnaceae',
+              96: 'Olacaceae',
+              97: 'Oleaceae',
+              98: 'Onagraceae',
+              99: 'Opiliaceae',
+              100: 'Oxalidaceae',
+              101: 'Paracryphiaceae',
+              102: 'Passifloraceae',
+              103: 'Penaeaceae',
+              104: 'Pentaphylacaceae',
+              105: 'Phyllanthaceae',
+              106: 'Phytolaccaceae',
+              107: 'Piperaceae',
+              108: 'Pittosporaceae',
+              109: 'Platanaceae',
+              110: 'Polemoniaceae',
+              111: 'Polygalaceae',
+              112: 'Polygonaceae',
+              113: 'Primulaceae',
+              114: 'Proteaceae',
+              115: 'Ranunculaceae',
+              116: 'Rhizophoraceae',
+              117: 'Rubiaceae',
+              118: 'Rutaceae',
+              119: 'Sabiaceae',
+              120: 'Santalaceae',
+              121: 'Sapotaceae',
+              122: 'Sarcolaenaceae',
+              123: 'Saxifragaceae',
+              124: 'Schisandraceae',
+              125: 'Scrophulariaceae',
+              126: 'Simaroubaceae',
+              127: 'Smilacaceae',
+              128: 'Solanaceae',
+              129: 'Staphyleaceae',
+              130: 'Stemonuraceae',
+              131: 'Styracaceae',
+              132: 'Symplocaceae',
+              133: 'Theaceae',
+              134: 'Thymelaeaceae',
+              135: 'Urticaceae',
+              136: 'Verbenaceae',
+              137: 'Violaceae',
+              138: 'Vochysiaceae',
+              139: 'Winteraceae',
+              140: 'Zygophyllaceae',
+              141: 'Uncertain'}
update_csv.py ADDED
@@ -0,0 +1,10 @@
+# import pandas as pd
+
+# # Load the CSV file into a DataFrame
+# fossils_pd = pd.read_csv('fossils_paths.csv')
+
+# # Replace '. ' with '' (effectively removing it) in the 'file_name' column
+# fossils_pd['file_name'] = fossils_pd['file_name'].str.replace('. ', '', regex=False)
+
+# # Optional: Save the updated DataFrame back to a CSV file if needed
+# fossils_pd.to_csv('fossils_paths.csv', index=False)