andy-wyx commited on
Commit
5a566ad
·
1 Parent(s): 0b77991

use corresponding embeddings for each model

Browse files
Files changed (3) hide show
  1. app.py +2 -2
  2. closest_sample.py +28 -8
  3. test.py +1 -1
app.py CHANGED
@@ -196,13 +196,13 @@ def get_embeddings(input_image,model_name):
196
 
197
  def find_closest(input_image,model_name):
198
  embedding = get_embeddings(input_image,model_name)
199
- classes, paths = get_images(embedding)
200
  #outputs = classes+paths
201
  return classes,paths
202
 
203
  def generate_diagram_closest(input_image,model_name,top_k):
204
  embedding = get_embeddings(input_image,model_name)
205
- diagram_path = get_diagram(embedding,top_k)
206
  return diagram_path
207
 
208
  def explain_image(input_image,model_name,explain_method,nb_samples):
 
196
 
197
  def find_closest(input_image,model_name):
198
  embedding = get_embeddings(input_image,model_name)
199
+ classes, paths = get_images(embedding,model_name)
200
  #outputs = classes+paths
201
  return classes,paths
202
 
203
  def generate_diagram_closest(input_image,model_name,top_k):
204
  embedding = get_embeddings(input_image,model_name)
205
+ diagram_path = get_diagram(embedding,top_k,model_name)
206
  return diagram_path
207
 
208
  def explain_image(input_image,model_name,explain_method,nb_samples):
closest_sample.py CHANGED
@@ -9,9 +9,6 @@ import matplotlib.pyplot as plt
9
  from collections import Counter
10
 
11
 
12
- pca_fossils = pk.load(open('pca_fossils_142_resnet.pkl','rb'))
13
- pca_leaves = pk.load(open('pca_leaves_142_resnet.pkl','rb'))
14
-
15
  if not os.path.exists('dataset'):
16
  REPO_ID='Serrelab/Fossils'
17
  token = os.environ.get('READ_TOKEN')
@@ -20,8 +17,6 @@ if not os.path.exists('dataset'):
20
  print("warning! A read token in env variables is needed for authentication.")
21
  snapshot_download(repo_id=REPO_ID, token=token,repo_type='dataset',local_dir='dataset')
22
 
23
- embedding_fossils = np.load('dataset/embedding_leaves_142_finer.npy')
24
- #embedding_leaves = np.load('embedding_leaves.npy')
25
 
26
  fossils_pd= pd.read_csv('fossils_paths.csv')
27
 
@@ -57,8 +52,20 @@ def download_public_image(url, destination_path):
57
  else:
58
  print(f"Failed to download image from bucket. Status code: {response.status_code}")
59
 
60
- def get_images(embedding):
61
-
 
 
 
 
 
 
 
 
 
 
 
 
62
  #pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])
63
 
64
  pca_d =pca_distance(pca_fossils,embedding,embedding_fossils,top_k=5)
@@ -93,7 +100,20 @@ def get_images(embedding):
93
 
94
  return classes, local_paths
95
 
96
- def get_diagram(embedding,top_k):
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  #pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])
99
 
 
9
  from collections import Counter
10
 
11
 
 
 
 
12
  if not os.path.exists('dataset'):
13
  REPO_ID='Serrelab/Fossils'
14
  token = os.environ.get('READ_TOKEN')
 
17
  print("warning! A read token in env variables is needed for authentication.")
18
  snapshot_download(repo_id=REPO_ID, token=token,repo_type='dataset',local_dir='dataset')
19
 
 
 
20
 
21
  fossils_pd= pd.read_csv('fossils_paths.csv')
22
 
 
52
  else:
53
  print(f"Failed to download image from bucket. Status code: {response.status_code}")
54
 
55
+ def get_images(embedding,model_name):
56
+
57
+ if model_name in ['Rock 170','Mummified 170']:
58
+ pca_fossils = pk.load(open('pca_fossils_170_finer.pkl','rb'))
59
+ pca_leaves = pk.load(open('pca_leaves_170_finer.pkl','rb'))
60
+ embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
61
+ #embedding_leaves = np.load('embedding_leaves.npy')
62
+ elif model_name in ['Fossils 142']:
63
+ pca_fossils = pk.load(open('pca_fossils_142_resnet.pkl','rb'))
64
+ pca_leaves = pk.load(open('pca_leaves_142_resnet.pkl','rb'))
65
+ embedding_fossils = np.load('dataset/embedding_leaves_142_finer.npy')
66
+ #embedding_leaves = np.load('embedding_leaves.npy')
67
+ else:
68
+ print(f'{model_name} not recognized')
69
  #pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])
70
 
71
  pca_d =pca_distance(pca_fossils,embedding,embedding_fossils,top_k=5)
 
100
 
101
  return classes, local_paths
102
 
103
+ def get_diagram(embedding,top_k,model_name):
104
+
105
+ if model_name in ['Rock 170','Mummified 170']:
106
+ pca_fossils = pk.load(open('pca_fossils_170_finer.pkl','rb'))
107
+ pca_leaves = pk.load(open('pca_leaves_170_finer.pkl','rb'))
108
+ embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
109
+ #embedding_leaves = np.load('embedding_leaves.npy')
110
+ elif model_name in ['Fossils 142']:
111
+ pca_fossils = pk.load(open('pca_fossils_142_resnet.pkl','rb'))
112
+ pca_leaves = pk.load(open('pca_leaves_142_resnet.pkl','rb'))
113
+ embedding_fossils = np.load('dataset/embedding_leaves_142_finer.npy')
114
+ #embedding_leaves = np.load('embedding_leaves.npy')
115
+ else:
116
+ print(f'{model_name} not recognized')
117
 
118
  #pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])
119
 
test.py CHANGED
@@ -23,7 +23,7 @@
23
  import numpy as np
24
 
25
  # Load the .npy file
26
- embedding = np.load('embedding.npy')
27
 
28
  # Check the shape of the array
29
  print(embedding.shape)
 
23
  import numpy as np
24
 
25
  # Load the .npy file
26
+ embedding = np.load('dataset/embedding_leaves_142_finer.npy')
27
 
28
  # Check the shape of the array
29
  print(embedding.shape)