fossil_app / closest_sample.py
andy-wyx's picture
use florissant only for closest images
2a1e3a6
raw
history blame
6.15 kB
from sklearn.decomposition import PCA
import pickle as pk
import numpy as np
import pandas as pd
import os
from huggingface_hub import snapshot_download
import requests
import matplotlib.pyplot as plt
from collections import Counter
# Module-level setup: load the fitted PCA models, make sure the dataset
# snapshot is available locally, then load the fossil embeddings and the table
# of file paths used by the lookup functions below.

# NOTE(review): pickle can execute arbitrary code while loading. These files
# are presumably bundled with the app; never point these paths at untrusted data.
with open('pca_fossils_170_finer.pkl', 'rb') as _pkl_file:
    pca_fossils = pk.load(_pkl_file)
with open('pca_leaves_170_finer.pkl', 'rb') as _pkl_file:
    pca_leaves = pk.load(_pkl_file)

if not os.path.exists('dataset'):
    REPO_ID = 'Serrelab/Fossils'
    token = os.environ.get('READ_TOKEN')
    # Report only whether a token is configured; echoing the value would leak
    # the secret into the application logs.
    print(f"Read token present: {token is not None}")
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='dataset', local_dir='dataset')

embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
#embedding_leaves = np.load('embedding_leaves.npy')
fossils_pd = pd.read_csv('fossils_paths.csv')
def pca_distance(pca, sample, embedding, top_k, max_index=2852):
    """Return the indices of the ``top_k`` dataset embeddings closest to ``sample``.

    The sample and the dataset embeddings are both projected with ``pca`` and
    ranked by Euclidean distance in the projected space.

    Args:
        pca: fitted PCA model (any object exposing ``transform``).
        sample: embedding of the query; it is reshaped to a single row.
        embedding: embeddings of the dataset. Only ``embedding[:, -1]`` is
            used, i.e. the last entry along axis 1 of every row.
            NOTE(review): presumably the array is (n_samples, k, dim) --
            confirm against the stored ``.npy`` file.
        top_k: maximum number of indices to return.
        max_index: largest dataset index allowed in the result. Defaults to
            2852, the cutoff that used to be hard-coded here.
            NOTE(review): presumably the last Florissant row of the paths
            table -- confirm. Pass ``None`` to disable the filter.

    Returns:
        A 1-D integer array with at most ``top_k`` indices, nearest first.
    """
    projected_sample = pca.transform(sample.reshape(1, -1))
    projected_dataset = pca.transform(embedding[:, -1])
    distances = np.linalg.norm(projected_dataset - projected_sample, axis=1)
    sorted_indices = np.argsort(distances)
    if max_index is not None:
        # Boolean masking keeps the nearest-first ordering of the survivors.
        sorted_indices = sorted_indices[sorted_indices <= max_index]
    return sorted_indices[:top_k]
def return_paths(argsorted, files):
    """Look up the entry of ``files`` for every index in ``argsorted``.

    Args:
        argsorted: iterable of indices into ``files``.
        files: indexable collection of file paths.

    Returns:
        A list with ``files[i]`` for each index, in the order given.
    """
    return [files[index] for index in argsorted]
def download_public_image(url, destination_path, timeout=30):
    """Download ``url`` and write the response body to ``destination_path``.

    Nothing is written unless the server answers with HTTP 200; any other
    status is only reported on stdout.

    Args:
        url: address of the image to fetch.
        destination_path: local file the image bytes are written to.
        timeout: seconds to wait for the server before giving up. Without a
            timeout ``requests`` waits indefinitely, so a stalled connection
            would hang the caller.

    Raises:
        requests.exceptions.RequestException: on connection errors or when
            the timeout expires.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code == 200:
        with open(destination_path, 'wb') as f:
            f.write(response.content)
        print(f"Downloaded image to {destination_path}")
    else:
        print(f"Failed to download image from bucket. Status code: {response.status_code}")
def get_images(embedding):
    """Download the five fossil images closest to ``embedding``.

    The nearest dataset entries are found with ``pca_distance``. Each stored
    cluster path is rewritten to its public bucket URL, and the image is
    downloaded to ``image_<i>.jpg`` in the working directory.

    Args:
        embedding: embedding of the query sample.

    Returns:
        A tuple ``(classes, local_paths)``. ``classes`` holds the name of the
        directory that directly contains each image in its public URL, and
        ``local_paths`` holds the matching local file names. Paths that match
        neither known dataset folder are skipped, so both lists may contain
        fewer than five entries.
    """
    nearest = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=5)
    fossils_paths = fossils_pd['file_name'].values
    paths = return_paths(nearest, fossils_paths)
    print(paths)
    folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
    folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'
    local_paths = []
    classes = []
    for i, path in enumerate(paths):
        local_file_path = f'image_{i}.jpg'
        if 'Florissant_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/', folder_florissant)
        elif 'General_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/', folder_general)
        else:
            # Without a public URL there is nothing to download. Skip the
            # entry instead of reusing the previous iteration's URL (or
            # crashing on the first one because the name is unbound).
            print("no match found")
            continue
        print(public_path)
        download_public_image(public_path, local_file_path)
        # The class label is the directory directly above the file name.
        parts = [part for part in public_path.split('/') if part]
        classes.append(parts[-2])
        local_paths.append(local_file_path)
    return classes, local_paths
def get_diagram(embedding, top_k):
    """Plot how the ``top_k`` closest fossils are distributed over classes.

    The nearest dataset entries are found with ``pca_distance``. The class of
    each one is the directory that directly contains the image in its public
    URL. A bar chart of class frequencies, most frequent first, is saved to
    ``class_distribution_chart.png``.

    Args:
        embedding: embedding of the query sample.
        top_k: number of closest samples to include in the chart.

    Returns:
        The path of the saved chart image.
    """
    nearest = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=top_k)
    fossils_paths = fossils_pd['file_name'].values
    paths = return_paths(nearest, fossils_paths)
    folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
    folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'
    classes = []
    for path in paths:
        if 'Florissant_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/', folder_florissant)
        elif 'General_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/', folder_general)
        else:
            # Skip unknown paths instead of counting the previous iteration's
            # class again (or crashing because the name is unbound).
            print("no match found")
            continue
        print(public_path)
        # The class label is the directory directly above the file name.
        parts = [part for part in public_path.split('/') if part]
        classes.append(parts[-2])

    class_counts = Counter(classes)
    # most_common() sorts by descending count and keeps first-seen order on ties.
    ranked = class_counts.most_common()
    sorted_classes = [label for label, _ in ranked]
    sorted_frequencies = [count for _, count in ranked]

    fig, ax = plt.subplots()
    if sorted_classes:
        colors = plt.cm.viridis(np.linspace(0, 1, len(sorted_classes)))
        positions = np.arange(len(sorted_classes))
        ax.bar(positions, sorted_frequencies, color=colors)
        # Tick labels must follow the same order as the bars; labelling them
        # from class_counts.keys() (insertion order) mislabels the bars.
        ax.set_xticks(positions)
        ax.set_xticklabels(sorted_classes, rotation=45, ha='right')
    ax.set_xlabel('Class Label')
    ax.set_ylabel('Frequency')
    ax.set_title('Distribution of Plant Family of '+str(top_k) +' Closest Samples')

    # Save the diagram to a file
    diagram_path = 'class_distribution_chart.png'
    fig.tight_layout()  # Adjust layout to make room for rotated x-axis labels
    fig.savefig(diagram_path)
    plt.close(fig)  # Close the figure to free up memory
    return diagram_path