Spaces:
Sleeping
Sleeping
File size: 7,488 Bytes
1d7c63d 679611d af9c1e6 86104a0 1d7c63d 679611d 1d7c63d 1662a5d 1d7c63d 1662a5d af9c1e6 1d7c63d 2a1e3a6 1330097 b65b9a8 60841b3 1d7c63d af9c1e6 1d7c63d 5a566ad 1662a5d 5a566ad 1d7c63d 86104a0 1d7c63d af9c1e6 c5343e6 af9c1e6 c5343e6 0c61c42 c5343e6 af9c1e6 c5343e6 86104a0 5a566ad 1662a5d 5a566ad 86104a0 503ec98 86104a0 7a2259a 86104a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
from sklearn.decomposition import PCA
import pickle as pk
import numpy as np
import pandas as pd
import os
from huggingface_hub import snapshot_download
import requests
import matplotlib.pyplot as plt
from collections import Counter
# Download the fossil dataset snapshot from the Hugging Face Hub on first run.
# An optional READ_TOKEN environment variable is used for authentication.
if not os.path.exists('dataset'):
    REPO_ID='Serrelab/Fossils'
    token = os.environ.get('READ_TOKEN')
    # NOTE(review): this prints the raw token to logs — consider removing.
    print(f"Read token:{token}")
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    snapshot_download(repo_id=REPO_ID, token=token,repo_type='dataset',local_dir='dataset')
# Metadata table whose 'file_name' column maps each embedding row to its
# original image path on the cluster filesystem.
fossils_pd= pd.read_csv('all_fossils_filtered_100.csv')
def pca_distance(pca, sample, embedding, top_k):
    """Return indices of the Florissant embeddings closest to ``sample``.

    Both the query sample and the dataset embeddings are projected into the
    fitted PCA space, where Euclidean distance is measured.

    Args:
        pca: fitted PCA model (any object with a ``transform`` method).
        sample: 1-D feature vector of the query sample.
        embedding: dataset embeddings; ``embedding[:, -1]`` is expected to
            yield the (n_samples, n_features) matrix fed to ``pca.transform``.
        top_k: number of neighbours requested. Note the slice below keeps
            ``top_k + 1`` indices (preserved from the original behavior,
            which reserved one slot for the sample's own match).

    Returns:
        Array of row indices of the closest embeddings, nearest first.
    """
    projected_sample = pca.transform(sample.reshape(1, -1))
    # Renamed from `all`, which shadowed the builtin of the same name.
    projected_all = pca.transform(embedding[:, -1])
    distances = np.linalg.norm(projected_all - projected_sample, axis=1)
    sorted_indices = np.argsort(distances)
    # Rows after index 2852 are general fossils; keep Florissant only.
    filtered_indices = sorted_indices[sorted_indices <= 2852]
    return filtered_indices[:top_k + 1]
def return_paths(argsorted, files):
    """Map a sequence of indices to the corresponding entries of ``files``."""
    return [files[index] for index in argsorted]
def download_public_image(url, destination_path):
    """Fetch an image over HTTP and write it to ``destination_path``.

    Prints a success or failure message; non-200 responses are reported
    and nothing is written.
    """
    response = requests.get(url)
    if response.status_code != 200:
        print(f"Failed to download image from bucket. Status code: {response.status_code}")
        return
    with open(destination_path, 'wb') as out_file:
        out_file.write(response.content)
    print(f"Downloaded image to {destination_path}")
def get_images(embedding, model_name):
    """Download the five fossil images closest to ``embedding``.

    Args:
        embedding: feature vector of the query sample.
        model_name: one of 'Rock 170', 'Mummified 170' or 'Fossils 142';
            selects which fitted PCA and precomputed fossil embeddings to use.

    Returns:
        (classes, local_paths): plant-family names (second-to-last component
        of each public URL) and local file paths of the downloaded images.

    Raises:
        ValueError: if ``model_name`` is not recognized.
    """
    # `with` closes the pickle files (previously leaked via pk.load(open(...))).
    # The unused pca_leaves load was dropped.
    if model_name in ['Rock 170', 'Mummified 170']:
        with open('pca_fossils_170_finer.pkl', 'rb') as pca_file:
            pca_fossils = pk.load(pca_file)
        embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
    elif model_name in ['Fossils 142']:
        with open('pca_fossils_142_resnet.pkl', 'rb') as pca_file:
            pca_fossils = pk.load(pca_file)
        embedding_fossils = np.load('dataset/embedding_fossils_142_finer.npy')
    else:
        # Previously this printed a warning and then crashed later with a
        # NameError on pca_fossils; fail fast with a clear error instead.
        raise ValueError(f'{model_name} not recognized')

    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=5)
    fossils_paths = fossils_pd['file_name'].values
    paths = return_paths(pca_d, fossils_paths)
    print(paths)

    folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
    folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'
    cluster_prefix = '/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/'

    local_paths = []
    classes = []
    for i, path in enumerate(paths):
        local_file_path = f'image_{i}.jpg'
        # Rewrite the cluster filesystem path into its public bucket URL.
        if 'Florissant_Fossil/512/full/jpg/' in path:
            public_path = path.replace(cluster_prefix + 'Florissant_Fossil/512/full/jpg/', folder_florissant)
        elif 'General_Fossil/512/full/jpg/' in path:
            public_path = path.replace(cluster_prefix + 'General_Fossil/512/full/jpg/', folder_general)
        else:
            # Previously fell through and reused the prior iteration's
            # public_path (NameError on the first iteration); skip instead.
            print("no match found")
            continue
        print(public_path)
        download_public_image(public_path, local_file_path)
        # Family name is the second-to-last component of the public URL.
        parts = [part for part in public_path.split('/') if part]
        classes.append(parts[-2])
        local_paths.append(local_file_path)
    return classes, local_paths
def get_diagram(embedding, top_k, model_name):
    """Plot the plant-family distribution of the ``top_k`` closest fossils.

    Args:
        embedding: feature vector of the query sample.
        top_k: number of closest samples to include in the histogram.
        model_name: one of 'Rock 170', 'Mummified 170' or 'Fossils 142';
            selects which fitted PCA and precomputed fossil embeddings to use.

    Returns:
        Path to the saved bar-chart PNG ('class_distribution_chart.png').

    Raises:
        ValueError: if ``model_name`` is not recognized.
    """
    # `with` closes the pickle files (previously leaked via pk.load(open(...))).
    # The unused pca_leaves load was dropped.
    if model_name in ['Rock 170', 'Mummified 170']:
        with open('pca_fossils_170_finer.pkl', 'rb') as pca_file:
            pca_fossils = pk.load(pca_file)
        embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
    elif model_name in ['Fossils 142']:
        with open('pca_fossils_142_resnet.pkl', 'rb') as pca_file:
            pca_fossils = pk.load(pca_file)
        embedding_fossils = np.load('dataset/embedding_fossils_142_finer.npy')
    else:
        # Previously this printed a warning and then crashed later with a
        # NameError on pca_fossils; fail fast with a clear error instead.
        raise ValueError(f'{model_name} not recognized')

    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=top_k)
    fossils_paths = fossils_pd['file_name'].values
    paths = return_paths(pca_d, fossils_paths)

    folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
    folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'
    cluster_prefix = '/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/'

    classes = []
    for path in paths:
        # Rewrite the cluster filesystem path into its public bucket URL.
        if 'Florissant_Fossil/512/full/jpg/' in path:
            public_path = path.replace(cluster_prefix + 'Florissant_Fossil/512/full/jpg/', folder_florissant)
        elif 'General_Fossil/512/full/jpg/' in path:
            public_path = path.replace(cluster_prefix + 'General_Fossil/512/full/jpg/', folder_general)
        else:
            # Previously fell through and reused the prior iteration's
            # public_path (NameError on the first iteration); skip instead.
            print("no match found")
            continue
        print(public_path)
        # Family name is the second-to-last component of the public URL.
        parts = [part for part in public_path.split('/') if part]
        classes.append(parts[-2])

    class_counts = Counter(classes)
    sorted_class_counts = sorted(class_counts.items(), key=lambda item: item[1], reverse=True)
    if sorted_class_counts:
        sorted_classes, sorted_frequencies = zip(*sorted_class_counts)
    else:
        # zip(*[]) would raise on unpack; produce an empty chart instead.
        sorted_classes, sorted_frequencies = (), ()
    colors = plt.cm.viridis(np.linspace(0, 1, len(sorted_classes)))

    fig, ax = plt.subplots()
    ax.bar(sorted_classes, sorted_frequencies, color=colors)
    ax.set_xlabel('Plant Family')
    ax.set_ylabel('Frequency')
    ax.set_title('Distribution of Plant Family of ' + str(top_k) + ' Closest Samples')
    # Bug fix: labels previously came from class_counts.keys() (insertion
    # order) while the bars are drawn in descending-frequency order, so bars
    # could be mislabeled. Use the sorted order with explicit tick positions.
    ax.set_xticks(range(len(sorted_classes)))
    ax.set_xticklabels(sorted_classes, rotation=45, ha='right')

    # Save the diagram to a file.
    diagram_path = 'class_distribution_chart.png'
    plt.tight_layout()  # make room for the rotated x-axis labels
    plt.savefig(diagram_path)
    plt.close(fig)  # free the figure's memory
    return diagram_path