from sklearn.decomposition import PCA  # needed so the pickled PCA models can be unpickled
import pickle as pk
import numpy as np
import pandas as pd
import os
from huggingface_hub import snapshot_download
import requests

# Load the PCA models fitted on the fossil and leaf embeddings.
pca_fossils = pk.load(open('pca_fossils_170_finer.pkl', 'rb'))
pca_leaves = pk.load(open('pca_leaves_170_finer.pkl', 'rb'))

# Download the dataset snapshot from the Hugging Face Hub on first run.
if not os.path.exists('dataset'):
    REPO_ID = 'Serrelab/Fossils'
    token = os.environ.get('READ_TOKEN')
    print(f"Read token found: {token is not None}")
    if token is None:
        print("Warning: a READ_TOKEN environment variable is needed for authentication.")
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='dataset', local_dir='dataset')

embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
#embedding_leaves = np.load('embedding_leaves.npy')

fossils_pd = pd.read_csv('fossils_paths.csv')


def pca_distance(pca, sample, embedding):
    """Find the dataset embeddings closest to a query sample in PCA space.

    Args:
        pca: fitted PCA model.
        sample: query embedding for which to find the closest matches.
        embedding: embeddings of the dataset.

    Returns:
        The indices of the five closest embeddings to the sample.
    """
    s = pca.transform(sample.reshape(1, -1))
    # Project the stored embeddings (last slice along axis 1) into PCA space.
    projected = pca.transform(embedding[:, -1])
    distances = np.linalg.norm(projected - s, axis=1)
    return np.argsort(distances)[:5]


def return_paths(argsorted, files):
    """Map a list of indices to their corresponding file paths."""
    return [files[i] for i in argsorted]


def download_public_image(url, destination_path):
    """Download an image from a public URL to a local file."""
    response = requests.get(url, timeout=30)
    if response.status_code == 200:
        with open(destination_path, 'wb') as f:
            f.write(response.content)
        print(f"Downloaded image to {destination_path}")
    else:
        print(f"Failed to download image from bucket. Status code: {response.status_code}")


def get_images(embedding):
    """Retrieve the five fossil images closest to a query embedding.

    Returns the class names (the parent directory of each image) and the
    local paths of the downloaded images.
    """
    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils)
    fossils_paths = fossils_pd['file_name'].values
    paths = return_paths(pca_d, fossils_paths)
    print(paths)

    # Cluster-internal paths are rewritten to their public GCS equivalents.
    folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
    folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'

    local_paths = []
    classes = []
    for i, path in enumerate(paths):
        local_file_path = f'image_{i}.jpg'
        if 'Florissant_Fossil/512/full/jpg/' in path:
            public_path = path.replace(
                '/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/',
                folder_florissant)
        elif 'General_Fossil/512/full/jpg/' in path:
            public_path = path.replace(
                '/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/',
                folder_general)
        else:
            print(f"No matching folder for path: {path}")
            continue  # skip paths that cannot be mapped to a public URL
        print(public_path)
        download_public_image(public_path, local_file_path)
        # The class name is the image's parent directory in the public path.
        parts = [part for part in public_path.split('/') if part]
        classes.append(parts[-2])
        local_paths.append(local_file_path)
    #paths = [path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/',
    #                      '/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/original/full/jpg/') for path in paths]
    return classes, local_paths
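

# --- Usage sketch (illustrative assumptions only) ---
# A minimal example of how get_images is expected to be called. The query
# embedding must have the feature dimensionality that pca_fossils was fitted
# on; the 768-dim random vector below is a hypothetical placeholder, not a
# value taken from this repository.
if __name__ == '__main__':
    query_embedding = np.random.rand(768).astype(np.float32)  # hypothetical query
    top_classes, top_paths = get_images(query_embedding)
    for cls, p in zip(top_classes, top_paths):
        print(f"{cls}: {p}")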