Yuxiang Wang committed · Commit 0c61c42 · 1 Parent(s): 8c11be1
feat:add beit,rise xai;display closest imgs with gallery

Files changed:
- app.py +45 -29
- closest_sample.py +1 -0
- env.py +1 -0
- explanations.py +29 -18
- fossils_paths.csv +0 -0
- inference_beit.py +100 -186
- labels.py +144 -0
- update_csv.py +10 -0
app.py
CHANGED
@@ -18,6 +18,7 @@ import glob
 from inference_sam import segmentation_sam
 from explanations import explain
 from inference_resnet import get_triplet_model
+from inference_beit import get_triplet_model_beit
 import pathlib
 import tensorflow as tf
 from closest_sample import get_images
@@ -26,6 +27,14 @@ if not os.path.exists('images'):
     REPO_ID='Serrelab/image_examples_gradio'
     snapshot_download(repo_id=REPO_ID, token=os.environ.get('READ_TOKEN'),repo_type='dataset',local_dir='images')
 
+if not os.path.exists('dataset'):
+    REPO_ID='Serrelab/Fossils'
+    token = os.environ.get('READ_TOKEN')
+    print(f"Read token:{token}")
+    if token is None:
+        print("warning! A read token in env variables is needed for authentication.")
+    snapshot_download(repo_id=REPO_ID, token=token,repo_type='dataset',local_dir='dataset')
+
 def get_model(model_name):
 
 
@@ -45,25 +54,24 @@ def get_model(model_name):
                                   backbone_class=tf.keras.applications.ResNet50V2,
                                   nb_classes = n_classes,load_weights=False,finer_model=True,backbone_name ='Resnet50v2')
         model.load_weights('model_classification/rock-170.h5')
+    elif model_name == 'Fossils 142':
+        n_classes = 142
+        model = get_triplet_model_beit(input_shape = (384, 384, 3),
+                                       embedding_units = 256,
+                                       embedding_depth = 2,
+                                       n_classes = n_classes)
+        model.load_weights('model_classification/fossil-142.h5')
     else:
        raise ValueError(f"Model name '{model_name}' is not recognized")
     return model,n_classes
 
-    '''
-    elif model_name == 'Fossils 19':
-        n_classes = 19 or 23?
-        model = get_beit_model(input_shape=(600, 600, 3),
-                               num_labels=n_classes,
-                               load_weights=False,
-                               )
-        model.load_weights('model_classification/beit-fossils-19.h5')
-    '''
 
 def segment_image(input_image):
     img = segmentation_sam(input_image)
     return img
 
 def classify_image(input_image, model_name):
+    #segmented_image = segment_image(input_image)
     if 'Rock 170' ==model_name:
         from inference_resnet import inference_resnet_finer
         model,n_classes= get_model(model_name)
@@ -74,10 +82,10 @@ def classify_image(input_image, model_name):
         model, n_classes= get_model(model_name)
         result = inference_resnet_finer(input_image,model,size=600,n_classes=n_classes)
         return result
-    if 'Fossils
-        from inference_beit import
+    if 'Fossils 142' ==model_name:
+        from inference_beit import inference_resnet_finer_beit
         model,n_classes = get_model(model_name)
-        result =
+        result = inference_resnet_finer_beit(input_image,model,size=384,n_classes=n_classes)
         return result
     return None
 
@@ -92,12 +100,10 @@ def get_embeddings(input_image,model_name):
         model, n_classes= get_model(model_name)
         result = inference_resnet_embedding(input_image,model,size=600,n_classes=n_classes)
         return result
-    if 'Fossils
-        from inference_beit import
+    if 'Fossils 142' ==model_name:
+        from inference_beit import inference_resnet_embedding_beit
         model,n_classes = get_model(model_name)
-        result =
-        #TODO
-        #result = inference_beit_embedding
+        result = inference_resnet_embedding_beit(input_image,model,size=384,n_classes=n_classes)
         return result
     return None
 
@@ -110,11 +116,16 @@ def find_closest(input_image,model_name):
 
 def explain_image(input_image,model_name):
     model,n_classes= get_model(model_name)
+    if model_name=='Fossils 142':
+        size = 384
+    else:
+        size = 600
     #saliency, integrated, smoothgrad,
-    rise = explain(model,input_image,n_classes=n_classes)
+    rise,avg = explain(model,input_image,size = size, n_classes=n_classes)
     #original = saliency + integrated + smoothgrad
     print('done')
-
+    rise1,rise2,rise3,rise4,rise5,avg = rise[0],rise[1],rise[2],rise[3],rise[4],avg[0]
+    return rise1,rise2,rise3,rise4,rise5,avg
 
 #minimalist theme
 with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
@@ -126,17 +137,17 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
             input_image = gr.Image(label="Input")
             classify_image_button = gr.Button("Classify Image")
 
-        with gr.Column():
-
-
-
-
+        # with gr.Column():
+        #     #segmented_image = gr.outputs.Image(label="SAM output",type='numpy')
+        #     segmented_image=gr.Image(label="Segmented Image", type='numpy')
+        #     segment_button = gr.Button("Segment Image")
+        #     #classify_segmented_button = gr.Button("Classify Segmented Image")
 
         with gr.Column():
             model_name = gr.Dropdown(
-                ["Mummified 170", "Rock 170","Fossils
+                ["Mummified 170", "Rock 170","Fossils 142"],
                 multiselect=False,
-                value="
+                value="Fossils 142", # default option
                 label="Model",
                 interactive=True,
             )
@@ -168,7 +179,12 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
             #gradcam = gr.Image(label='integraged gradients')
             #guided_gradcam = gr.Image(label='gradcam')
             #guided_backprop = gr.Image(label='guided backprop')
-
+            rise1 = gr.Image(label = 'Rise1')
+            rise2 = gr.Image(label = 'Rise2')
+            rise3 = gr.Image(label = 'Rise3')
+            rise4 = gr.Image(label = 'Rise4')
+            rise5 = gr.Image(label = 'Rise5')
+            avg = gr.Image(label = 'Avg')
             generate_explanations = gr.Button("Generate Explanations")
 
             # with gr.Accordion('Closest Images'):
@@ -199,9 +215,9 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
 
     find_closest_btn = gr.Button("Find Closest Images")
 
-    segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
+    #segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
    classify_image_button.click(classify_image, inputs=[input_image,model_name], outputs=class_predicted)
-    generate_explanations.click(explain_image, inputs=[input_image,model_name], outputs=[
+    generate_explanations.click(explain_image, inputs=[input_image,model_name], outputs=[rise1,rise2,rise3,rise4,rise5,avg]) #
    #find_closest_btn.click(find_closest, inputs=[input_image,model_name], outputs=[label_closest_image_0,label_closest_image_1,label_closest_image_2,label_closest_image_3,label_closest_image_4,closest_image_0,closest_image_1,closest_image_2,closest_image_3,closest_image_4])
    def update_outputs(input_image,model_name):
        labels, images = find_closest(input_image,model_name)
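For orientation, the new "Fossils 142" option routes a query image through the BEiT-based classifier and the RISE explainer added in this commit; the following is a minimal sketch of exercising those callbacks outside the Gradio UI (the sample image path is a placeholder, and importing app will run its snapshot_download calls):

    import numpy as np
    from PIL import Image

    from app import classify_image, explain_image  # functions defined in this diff

    img = np.array(Image.open("sample_fossil.jpg").convert("RGB"))  # placeholder path

    # Dict of {family name: softmax score} for the top-10 classes (BEiT backbone).
    scores = classify_image(img, "Fossils 142")

    # Paths to five saved RISE heatmaps (one per top-5 class) plus the image shown as 'Avg' in the UI.
    rise1, rise2, rise3, rise4, rise5, avg = explain_image(img, "Fossils 142")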
closest_sample.py
CHANGED
@@ -77,6 +77,7 @@ def get_images(embedding):
         public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/', folder_general)
     else:
         print("no match found")
+    print(public_path)
     download_public_image(public_path, local_file_path)
     names = []
     parts = [part for part in public_path.split('/') if part]
env.py
CHANGED
@@ -17,6 +17,7 @@ def config_env():
         ('xplique', None),
         ('segment_anything', None),
         ('panopticapi', None),
+        ('keras_cv_attention_models',None)
     ]
 
     name_to_command = {'segment_anything':'git+https://github.com/facebookresearch/segment-anything.git',
explanations.py
CHANGED
@@ -54,35 +54,46 @@ def explain(model, input_image,size=600, n_classes=171) :
         #IntegratedGradients(class_model, steps=50, batch_size=BATCH_SIZE),
         #SmoothGrad(class_model, nb_samples=50, batch_size=BATCH_SIZE),
         #GradCAM(class_model),
-        Rise(class_model,nb_samples = 50, batch_size = BATCH_SIZE,grid_size=
+        Rise(class_model,nb_samples = 50, batch_size = BATCH_SIZE,grid_size=15,
             preservation_probability=0.5)
         #
-
+    ]
+    explainer = Rise(class_model,nb_samples = 50, batch_size = BATCH_SIZE,grid_size=15,
+                     preservation_probability=0.5)
+
     cropped,repetitions = _clever_crop(input_image,(size,size))
     size_repetitions = int(size//(repetitions.numpy()+1))
     X = preprocess(cropped,size=size)
-
+    predictions = class_model.predict(np.array([X]))
+    #Y = np.argmax(predictions)
+    top_5_indices = np.argsort(predictions[0])[-5:][::-1]
+    #print(top_5_indices)
     X = np.expand_dims(X, 0)
     explanations = []
-    for
-        print(f'{e}/{len(explainers)}')
-        print('1')
+    for i,Y in enumerate(top_5_indices):
         Y = tf.one_hot([Y], n_classes)
+        print(f'{i}/{len(top_5_indices)}')
         phi = np.abs(explainer(X, Y))[0]
-        print('1')
         if len(phi.shape) == 3:
             phi = np.mean(phi, -1)
-        print('1')
         show(X[0][:,size_repetitions:2*size_repetitions,:])
         show(phi[:,size_repetitions:2*size_repetitions], p=1, alpha=0.4)
-
-
-
-
-
-        print(
-
-
-
+        plt.savefig(f'phi_{i}.png')
+        explanations.append(f'phi_{i}.png')
+    avg=[]
+    for i,Y in enumerate(top_5_indices):
+        Y = tf.one_hot([Y], n_classes)
+        print(f'{i}/{len(top_5_indices)}')
+        phi = np.abs(explainer(X, Y))[0]
+        if len(phi.shape) == 3:
+            phi = np.mean(phi, -1)
+        show(X[0][:,size_repetitions:2*size_repetitions,:])
+        show(phi[:,size_repetitions:2*size_repetitions], p=1, alpha=0.4)
+        plt.savefig(f'phi_6.png')
+        avg.append(f'phi_6.png')
+
+    print('Done')
+    if len(explanations)==1:
+        explanations = explanations[0]
 
-    return explanations
+    return explanations,avg
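The explainer instantiated above is RISE from the xplique library; the following is a rough sketch of the call pattern it relies on (class_model and x are placeholders for the classifier head and one preprocessed image, not names from the repo):

    import numpy as np
    import tensorflow as tf
    from xplique.attributions import Rise

    explainer = Rise(class_model, nb_samples=50, batch_size=32,
                     grid_size=15, preservation_probability=0.5)

    X = np.expand_dims(x, 0)                       # batch of one image
    probs = class_model.predict(X)[0]
    target = tf.one_hot([int(np.argmax(probs))], probs.shape[-1])

    phi = np.abs(explainer(X, target))[0]          # importance map for the top-1 class
    if phi.ndim == 3:
        phi = phi.mean(-1)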
fossils_paths.csv
CHANGED
The diff for this file is too large to render.
See raw diff
inference_beit.py
CHANGED
@@ -9,195 +9,109 @@ import os
 import numpy as np
 import keras
 from PIL import Image
-import keras_cv
 from keras_cv_attention_models import beit
 import matplotlib.pyplot as plt
 
-def pi(img, mask):
-    img = tf.cast(img, tf.float32)
-
-    shape = tf.shape(img)
-    w, h = tf.cast(shape[0], tf.int64), tf.cast(shape[1], tf.int64)
-
-    mask = smooth_mask(mask)
-    mask = tf.reduce_mean(mask, -1)
-
-    img = img * tf.cast(mask > 0.1, tf.float32)[:, :, None]
-
-    img_resize = tf.image.resize(img, (SIZE, SIZE), method="bicubic", antialias=True)
-    img_pad = tf.image.resize_with_pad(img, SIZE, SIZE, method="bicubic", antialias=True)
-
-    # building 2 anchors
-    anchors = tf.where(mask > 0.15)
-    anchor_xmin = tf.math.reduce_min(anchors[:, 0])
-    anchor_xmax = tf.math.reduce_max(anchors[:, 0])
-    anchor_ymin = tf.math.reduce_min(anchors[:, 1])
-    anchor_ymax = tf.math.reduce_max(anchors[:, 1])
-
-    if anchor_xmax - anchor_xmin > 50 and anchor_ymax - anchor_ymin > 50:
-
-        img_anchor_1 = resize(img[anchor_xmin:anchor_xmax, anchor_ymin:anchor_ymax])
-
-        delta_x = (anchor_xmax - anchor_xmin) // 4
-        delta_y = (anchor_ymax - anchor_ymin) // 4
-        img_anchor_2 = img[anchor_xmin+delta_x:anchor_xmax-delta_x,
-                           anchor_ymin+delta_y:anchor_ymax-delta_y]
-        img_anchor_2 = resize(img_anchor_2)
-    else:
-        img_anchor_1 = img_resize
-        img_anchor_2 = img_pad
-
-    # building the anchors max
-    anchor_max = tf.where(mask == tf.math.reduce_max(mask))[0]
-    anchor_max_x, anchor_max_y = anchor_max[0], anchor_max[1]
-
-    img_max_zoom1 = img[tf.math.maximum(anchor_max_x-SIZE, 0): tf.math.minimum(anchor_max_x+SIZE, w),
-                        tf.math.maximum(anchor_max_y-SIZE, 0): tf.math.minimum(anchor_max_y+SIZE, h)]
-
-    img_max_zoom1 = resize(img_max_zoom1)
-    img_max_zoom2 = img[anchor_max_x-SIZE//2:anchor_max_x+SIZE//2,
-                        anchor_max_y-SIZE//2:anchor_max_y+SIZE//2]
-    img_max_zoom2 = img[tf.math.maximum(anchor_max_x-SIZE//2, 0): tf.math.minimum(anchor_max_x+SIZE//2, w),
-                        tf.math.maximum(anchor_max_y-SIZE//2, 0): tf.math.minimum(anchor_max_y+SIZE//2, h)]
-    #tf.print(img_max_zoom2.shape)
-    #img_max_zoom2 = resize(img_max_zoom2)
-
-    return tf.cast(img_resize, tf.float32)
-
-def parse_img(element, split, randaugment,maskaugment=True):
-    #global debug
-    path, class_id = element[0], element[1]
-
-    data = tf.io.read_file(path)
-    img = tf.io.decode_jpeg(data)
-    img = tf.cast(img, tf.uint8)
-    img = normalize(img)
-    shape = tf.shape(img)
-
-    # data_mask = tf.io.read_file(path_mask)
-    # mask = tf.io.decode_jpeg(data_mask)
-
-    class_id = tf.strings.to_number(class_id)
-    class_id = tf.cast(class_id, tf.int32)
-
-    label = tf.one_hot(class_id, num_classes)
-
-    # img = pi(img, mask)
-    img = tf.image.resize_with_pad(img, SIZE, SIZE, method="bicubic", antialias=True)
-
-    return tf.cast(img, tf.float32), tf.cast(label, tf.int32)
-
-SIZE = 384
-wsize=hsize=SIZE
-def resize_images(batch_x, width=224, height=224):
-    return tf.image.resize(batch_x, (width, height))
-
-def load_img(image_path,gray=False):
-    img = tf.io.read_file(image_path)
-    img = tf.image.decode_jpeg(img, channels=3)
-    img = tf.image.convert_image_dtype(img, tf.float32)
-    if gray:
+import tensorflow as tf
+from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
+from typing import Tuple
+#from huggingface_hub import snapshot_download
+from labels import lookup_140
+
+
+def get_triplet_model_beit(input_shape = (600, 600, 3),
+                           embedding_units = 256,
+                           embedding_depth = 2,
+                           n_classes = 19,backbone_name ='Beit'):
+
+    backbone_class = beit.BeitBasePatch16(input_shape=input_shape, pretrained="imagenet21k-ft1k")
+
+    backbone_class = tf.keras.Model(backbone_class.input, backbone_class.layers[-2].output)
+    #features = GlobalAveragePooling2D()(backbone_class.output)
+    embedding_head = backbone_class.output
+
+    for embed_i in range(embedding_depth):
+        embedding_head = Dense(embedding_units, activation="relu" if embed_i < embedding_depth-1 else "linear")(embedding_head)
+    embedding_head = tf.nn.l2_normalize(embedding_head, -1, epsilon=1e-5)
+
+    logits_head = Dense(n_classes)(backbone_class.output)
+
+    model = tf.keras.Model(backbone_class.input, [embedding_head, logits_head])
+    model.compile(loss='cce',metrics=['accuracy'])
+    #model.summary()
+
+    return model
+
+
+
+
+load_size = 600
+crop_size = 600
+def _clever_crop(img: tf.Tensor,
+                 target_size: Tuple[int]=(128,128),
+                 grayscale: bool=False
+                 ) -> tf.Tensor:
+    """[summary]
+    Args:
+        img (tf.Tensor): [description]
+        target_size (Tuple[int], optional): [description]. Defaults to (128,128).
+        grayscale (bool, optional): [description]. Defaults to False.
+    Returns:
+        tf.Tensor: [description]
+    """
+    maxside = tf.math.maximum(tf.shape(img)[0],tf.shape(img)[1])
+    minside = tf.math.minimum(tf.shape(img)[0],tf.shape(img)[1])
+    new_img = img
+
+    if tf.math.divide(maxside,minside) > 1.2:
+        repeating = tf.math.floor(tf.math.divide(maxside,minside))
+        new_img = img
+        if tf.math.equal(tf.shape(img)[1],minside):
+            for _ in range(int(repeating)):
+                new_img = tf.concat((new_img, img), axis=1)
+
+        if tf.math.equal(tf.shape(img)[0],minside):
+            for _ in range(int(repeating)):
+                new_img = tf.concat((new_img, img), axis=0)
+            new_img = tf.image.rot90(new_img)
+    else:
+        new_img = img
+        repeating = 0
+    img = tf.image.resize(new_img, target_size)
+    if grayscale:
         img = tf.image.rgb_to_grayscale(img)
         img = tf.image.grayscale_to_rgb(img)
-    img = tf.image.resize(img,(wsize,hsize))
-    return img
-
-LR = 1e-3
-
-optimizer = tf.keras.optimizers.Adam(LR)
-cce = tf.keras.losses.categorical_crossentropy
-
-model_path = '/content/drive/MyDrive/Gg_Fossils_data_shared_copy/Fossils/models/model-13.h5'
-model = keras.models.load_model(model_path, custom_objects = {'cce': cce})
-
-outputs = model.predict(images)
-
-predictions = tf.math.top_k(outputs[1], k = 5)
-cid = 1
-dataset = np.array(dataset)
-final_predictions = []
-for ele in predictions[1]:
-    if cid in ele:
-        final_predictions.append(cid)
-    else:
-        final_predictions.append(cid+10)
-final_predictions = np.array(final_predictions)
-images2 = images[final_predictions == cid]
-image2_paths = dataset[final_predictions == cid][:,0]
-print(images2.shape)
-
-def get_beit_model(input_shape, num_labels, load_weights=False, ...):
-    pass
-
-def inference_dino(input_image, model_name):
-    pass
 
-
-
+    return img,repeating
+
+def preprocess(img,size=384):
+    img = np.array(img, np.float32) / 255.0
+    img = tf.image.resize(img, (size, size))
+    return np.array(img, np.float32)
+
+def select_top_n(preds,n=10):
+    top_n = np.argsort(preds)[-n:][::-1]
+    return top_n
+
+def parse_results(top_n,logits):
+    results = {}
+    for n in top_n:
+        label = lookup_140[n]
+        results[label] = float(logits[n])
+    return results
+
+def inference_resnet_embedding_beit(x,model,size=576,n_classes=142,n_top=10):
+    cropped = _clever_crop(x,(size,size))[0]
+    prep = preprocess(cropped,size=size)
+    embedding = model.predict(np.array([prep]))[0][0]
+
+
+    return embedding
+
+def inference_resnet_finer_beit(x,model,size=576,n_classes=142,n_top=10):
+    cropped = _clever_crop(x,(size,size))[0]
+    prep = preprocess(cropped,size=size)
+    logits = tf.nn.softmax(model.predict(np.array([prep]))[1][0]).cpu().numpy()
+    top_n = select_top_n(logits,n=n_top)
+
+    return parse_results(top_n,logits)
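For context, get_triplet_model_beit wraps the BeitBasePatch16 backbone from keras_cv_attention_models with an embedding head and an n_classes-way logits head; the following is a rough usage sketch, not repo code (the weights path is the one referenced in app.py above, the input array is a placeholder, and building the model downloads the pretrained BEiT weights):

    import numpy as np
    from inference_beit import get_triplet_model_beit, inference_resnet_finer_beit

    model = get_triplet_model_beit(input_shape=(384, 384, 3),
                                   embedding_units=256,
                                   embedding_depth=2,
                                   n_classes=142)
    model.load_weights('model_classification/fossil-142.h5')

    img = (np.random.rand(512, 700, 3) * 255).astype(np.uint8)  # placeholder RGB image
    results = inference_resnet_finer_beit(img, model, size=384, n_classes=142)
    # results: {'family name': softmax score, ...} for the 10 highest-scoring classes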
labels.py
CHANGED
@@ -173,3 +173,147 @@ lookup_170 = {0: 'Anacardiaceae',
 dict_lu ={}
 for i in range(171):
     dict_lu[i] = lookup_170[i]
+
+
+lookup_140 = {0: 'Anacardiaceae',
+              1: 'Berberidaceae',
+              2: 'Betulaceae',
+              3: 'Cupressaceae',
+              4: 'Dryopteridaceae',
+              5: 'Fabaceae',
+              6: 'Fagaceae',
+              7: 'Juglandaceae',
+              8: 'Lauraceae',
+              9: 'Meliaceae',
+              10: 'Myrtaceae',
+              11: 'Pinaceae',
+              12: 'Rhamnaceae',
+              13: 'Rosaceae',
+              14: 'Salicaceae',
+              15: 'Sapindaceae',
+              16: 'Ulmaceae',
+              17: 'Viburnaceae',
+              18: 'Vitaceae',
+              19: 'Araceae',
+              20: 'Grossulariaceae',
+              21: 'Hydrangeaceae',
+              22: 'Taxaceae',
+              23: 'Achariaceae',
+              24: 'Actinidiaceae',
+              25: 'Altingiaceae',
+              26: 'Amaranthaceae',
+              27: 'Annonaceae',
+              28: 'Apiaceae',
+              29: 'Apocynaceae',
+              30: 'Aquifoliaceae',
+              31: 'Araliaceae',
+              32: 'Aristolochiaceae',
+              33: 'Asteraceae',
+              34: 'Bignoniaceae',
+              35: 'Boraginaceae',
+              36: 'Burseraceae',
+              37: 'Buxaceae',
+              38: 'Calophyllaceae',
+              39: 'Calycanthaceae',
+              40: 'Campanulaceae',
+              41: 'Canellaceae',
+              42: 'Cannabaceae',
+              43: 'Capparaceae',
+              44: 'Caprifoliaceae',
+              45: 'Cardiopteridaceae',
+              46: 'Celastraceae',
+              47: 'Chloranthaceae',
+              48: 'Chrysobalanaceae',
+              49: 'Clusiaceae',
+              50: 'Combretaceae',
+              51: 'Connaraceae',
+              52: 'Coriariaceae',
+              53: 'Cornaceae',
+              54: 'Crassulaceae',
+              55: 'Cucurbitaceae',
+              56: 'Cunoniaceae',
+              57: 'Dilleniaceae',
+              58: 'Dipterocarpaceae',
+              59: 'Ebenaceae',
+              60: 'Elaeagnaceae',
+              61: 'Elaeocarpaceae',
+              62: 'Ericaceae',
+              63: 'Escalloniaceae',
+              64: 'Euphorbiaceae',
+              65: 'Garryaceae',
+              66: 'Geraniaceae',
+              67: 'Gesneriaceae',
+              68: 'Gnetaceae',
+              69: 'Hamamelidaceae',
+              70: 'Humiriaceae',
+              71: 'Hypericaceae',
+              72: 'Icacinaceae',
+              73: 'Iteaceae',
+              74: 'Ixonanthaceae',
+              75: 'Lamiaceae',
+              76: 'Lardizabalaceae',
+              77: 'Lecythidaceae',
+              78: 'Linaceae',
+              79: 'Loganiaceae',
+              80: 'Loranthaceae',
+              81: 'Lythraceae',
+              82: 'Magnoliaceae',
+              83: 'Malpighiaceae',
+              84: 'Malvaceae',
+              85: 'Marantaceae',
+              86: 'Melastomataceae',
+              87: 'Menispermaceae',
+              88: 'Monimiaceae',
+              89: 'Moraceae',
+              90: 'Myricaceae',
+              91: 'Myristicaceae',
+              92: 'Nothofagaceae',
+              93: 'Nyctaginaceae',
+              94: 'Nyssaceae',
+              95: 'Ochnaceae',
+              96: 'Olacaceae',
+              97: 'Oleaceae',
+              98: 'Onagraceae',
+              99: 'Opiliaceae',
+              100: 'Oxalidaceae',
+              101: 'Paracryphiaceae',
+              102: 'Passifloraceae',
+              103: 'Penaeaceae',
+              104: 'Pentaphylacaceae',
+              105: 'Phyllanthaceae',
+              106: 'Phytolaccaceae',
+              107: 'Piperaceae',
+              108: 'Pittosporaceae',
+              109: 'Platanaceae',
+              110: 'Polemoniaceae',
+              111: 'Polygalaceae',
+              112: 'Polygonaceae',
+              113: 'Primulaceae',
+              114: 'Proteaceae',
+              115: 'Ranunculaceae',
+              116: 'Rhizophoraceae',
+              117: 'Rubiaceae',
+              118: 'Rutaceae',
+              119: 'Sabiaceae',
+              120: 'Santalaceae',
+              121: 'Sapotaceae',
+              122: 'Sarcolaenaceae',
+              123: 'Saxifragaceae',
+              124: 'Schisandraceae',
+              125: 'Scrophulariaceae',
+              126: 'Simaroubaceae',
+              127: 'Smilacaceae',
+              128: 'Solanaceae',
+              129: 'Staphyleaceae',
+              130: 'Stemonuraceae',
+              131: 'Styracaceae',
+              132: 'Symplocaceae',
+              133: 'Theaceae',
+              134: 'Thymelaeaceae',
+              135: 'Urticaceae',
+              136: 'Verbenaceae',
+              137: 'Violaceae',
+              138: 'Vochysiaceae',
+              139: 'Winteraceae',
+              140: 'Zygophyllaceae',
+              141:'Uncertain'}
update_csv.py
ADDED
@@ -0,0 +1,10 @@
+# import pandas as pd
+
+# # Load the CSV file into a DataFrame
+# fossils_pd = pd.read_csv('fossils_paths.csv')
+
+# # Replace '. ' with '' (effectively removing it) in the 'file_name' column
+# fossils_pd['file_name'] = fossils_pd['file_name'].str.replace('. ', '', regex=False)
+
+# # Optional: Save the updated DataFrame back to a CSV file if needed
+# fossils_pd.to_csv('fossils_paths.csv', index=False)
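The new script is committed fully commented out; if enabled, it would amount to the following small pandas cleanup pass (a sketch of the uncommented logic):

    import pandas as pd

    # Load the CSV, strip the stray '. ' prefix from file names, and write it back.
    fossils_pd = pd.read_csv('fossils_paths.csv')
    fossils_pd['file_name'] = fossils_pd['file_name'].str.replace('. ', '', regex=False)
    fossils_pd.to_csv('fossils_paths.csv', index=False)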