File size: 6,147 Bytes
1d7c63d
 
 
 
679611d
 
af9c1e6
86104a0
 
1d7c63d
 
 
 
679611d
 
 
 
 
 
 
 
 
 
1d7c63d
 
 
 
86104a0
af9c1e6
 
 
 
 
 
 
 
1d7c63d
 
 
2a1e3a6
 
 
1d7c63d
 
 
 
 
 
 
af9c1e6
 
 
 
 
 
 
 
1d7c63d
 
 
 
 
86104a0
1d7c63d
 
 
 
 
af9c1e6
 
 
 
c5343e6
 
af9c1e6
 
c5343e6
 
 
 
 
 
0c61c42
c5343e6
 
 
 
 
 
af9c1e6
 
c5343e6
 
86104a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a2259a
86104a0
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from sklearn.decomposition import PCA
import pickle as pk
import numpy as np
import pandas as pd
import os
from huggingface_hub import snapshot_download
import requests
import matplotlib.pyplot as plt
from collections import Counter


# Fitted PCA models used to project embeddings before nearest-neighbour search.
# NOTE: pickle.load can execute arbitrary code; these files must be trusted
# artifacts shipped with the app, never user-supplied input.
with open('pca_fossils_170_finer.pkl', 'rb') as _pkl_file:
    pca_fossils = pk.load(_pkl_file)
with open('pca_leaves_170_finer.pkl', 'rb') as _pkl_file:
    pca_leaves = pk.load(_pkl_file)

# Fetch the dataset snapshot from the Hugging Face Hub on first run only.
if not os.path.exists('dataset'):
    REPO_ID = 'Serrelab/Fossils'
    token = os.environ.get('READ_TOKEN')
    # Do not echo the token itself: it is a secret and would end up in logs.
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='dataset', local_dir='dataset')

# Precomputed fossil embeddings searched by pca_distance().
embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')

# Table mapping each embedding row to its source image path ('file_name' column).
fossils_pd = pd.read_csv('fossils_paths.csv')

def pca_distance(pca, sample, embedding, top_k, max_index=2852):
    """Return indices of the dataset embeddings closest to ``sample`` in PCA space.

    Args:
        pca: fitted PCA model (anything exposing ``transform``).
        sample: query embedding; flattened to a single row before projection.
        embedding: embeddings of the dataset. Only ``embedding[:, -1]`` is
            projected, so this is presumably a 3-D array (N, K, D) whose last
            slice along axis 1 is the feature vector -- TODO confirm.
        top_k: maximum number of indices to return.
        max_index: indices strictly greater than this are discarded before
            truncating to ``top_k``. Defaults to 2852, the original hard-coded
            cutoff (presumably the last usable fossil row -- confirm).

    Returns:
        Up to ``top_k`` indices into ``embedding``, nearest first by Euclidean
        distance in the projected space.
    """
    query = pca.transform(sample.reshape(1, -1))
    projected = pca.transform(embedding[:, -1])
    distances = np.linalg.norm(projected - query, axis=1)
    order = np.argsort(distances)
    return order[order <= max_index][:top_k]

def return_paths(argsorted, files):
    """Map each index in ``argsorted`` to ``files[index]``, keeping the order.

    Args:
        argsorted: iterable of integer indices (e.g. output of pca_distance).
        files: indexable collection of file paths.

    Returns:
        A new list with the selected entries of ``files``.
    """
    return [files[idx] for idx in argsorted]

def download_public_image(url, destination_path, timeout=30):
    """Download ``url`` and write its body to ``destination_path`` (best effort).

    A non-200 response is reported on stdout and nothing is written.
    Network errors (including a timeout) propagate from ``requests``.

    Args:
        url: public HTTP(S) URL of the image.
        destination_path: local file path to write the bytes to.
        timeout: seconds to wait for the server; without one, ``requests``
            can block indefinitely on an unresponsive host.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code == 200:
        with open(destination_path, 'wb') as f:
            f.write(response.content)
        print(f"Downloaded image to {destination_path}")
    else:
        print(f"Failed to download image from bucket. Status code: {response.status_code}")

def get_images(embedding):
    """Download the images of the 5 fossils nearest to ``embedding``.

    Each neighbour's cluster path is rewritten to its public Google Cloud
    Storage URL and downloaded to ``image_{i}.jpg`` in the working directory.
    Paths that match neither known fossil folder are reported and skipped.

    Args:
        embedding: query embedding, forwarded to ``pca_distance`` as the sample.

    Returns:
        (classes, local_paths): for each downloaded image, the name of its
        parent directory in the URL (presumably the plant family -- confirm)
        and the local file it was written to.
    """
    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=5)
    fossils_paths = fossils_pd['file_name'].values
    paths = return_paths(pca_d, fossils_paths)
    print(paths)

    folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
    folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'

    local_paths = []
    classes = []
    for i, path in enumerate(paths):
        local_file_path = f'image_{i}.jpg'
        if 'Florissant_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/', folder_florissant)
        elif 'General_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/', folder_general)
        else:
            # Skip: falling through would reuse the previous iteration's
            # public_path (wrong image/label) or raise UnboundLocalError.
            print("no match found")
            continue
        print(public_path)
        download_public_image(public_path, local_file_path)
        # The class label is the directory that contains the image file.
        parts = [part for part in public_path.split('/') if part]
        classes.append(parts[-2])
        local_paths.append(local_file_path)

    return classes, local_paths

def get_diagram(embedding, top_k):
    """Plot the class distribution of the ``top_k`` fossils nearest to ``embedding``.

    Each neighbour's path is mapped to its public URL, and the URL's parent
    directory name is used as the class label (presumably the plant family --
    confirm). Paths matching neither known fossil folder are reported and
    skipped. The bar chart is sorted by frequency, most common first.

    Args:
        embedding: query embedding, forwarded to ``pca_distance`` as the sample.
        top_k: number of nearest neighbours to include.

    Returns:
        Path of the saved PNG chart ('class_distribution_chart.png').
    """
    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=top_k)
    fossils_paths = fossils_pd['file_name'].values
    paths = return_paths(pca_d, fossils_paths)

    folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
    folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'

    classes = []
    for path in paths:
        if 'Florissant_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/', folder_florissant)
        elif 'General_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/', folder_general)
        else:
            # Skip: falling through would reuse the previous iteration's
            # public_path (wrong label) or raise UnboundLocalError.
            print("no match found")
            continue
        print(public_path)
        parts = [part for part in public_path.split('/') if part]
        classes.append(parts[-2])

    # most_common() sorts by count descending; ties keep first-seen order.
    sorted_class_counts = Counter(classes).most_common()

    fig, ax = plt.subplots()
    if sorted_class_counts:  # zip(*[]) would fail to unpack on an empty result
        sorted_classes, sorted_frequencies = zip(*sorted_class_counts)
        colors = plt.cm.viridis(np.linspace(0, 1, len(sorted_classes)))
        ax.bar(sorted_classes, sorted_frequencies, color=colors)
        # Categorical bars sit at positions 0..n-1 in sorted_classes order; label
        # them in that same order (the unsorted Counter keys mislabelled bars).
        ax.set_xticks(range(len(sorted_classes)))
        ax.set_xticklabels(sorted_classes, rotation=45, ha='right')
    ax.set_xlabel('Class Label')
    ax.set_ylabel('Frequency')
    ax.set_title('Distribution of Plant Family of ' + str(top_k) + ' Closest Samples')

    diagram_path = 'class_distribution_chart.png'
    fig.tight_layout()  # make room for the rotated x-axis labels
    fig.savefig(diagram_path)
    plt.close(fig)  # release this figure explicitly, not whichever is "current"

    return diagram_path