File size: 7,488 Bytes
1d7c63d
 
 
 
679611d
 
af9c1e6
86104a0
 
1d7c63d
 
679611d
 
 
 
 
 
 
 
1d7c63d
1662a5d
1d7c63d
1662a5d
af9c1e6
 
 
 
 
 
 
 
1d7c63d
 
 
2a1e3a6
1330097
b65b9a8
60841b3
1d7c63d
 
 
 
 
 
 
af9c1e6
 
 
 
 
 
 
 
1d7c63d
5a566ad
 
 
 
 
 
 
 
 
 
1662a5d
5a566ad
 
 
1d7c63d
 
86104a0
1d7c63d
 
 
 
 
af9c1e6
 
 
 
c5343e6
 
af9c1e6
 
c5343e6
 
 
 
 
 
0c61c42
c5343e6
 
 
 
 
 
af9c1e6
 
c5343e6
 
86104a0
5a566ad
 
 
 
 
 
 
 
 
 
1662a5d
5a566ad
 
 
86104a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503ec98
86104a0
7a2259a
86104a0
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from sklearn.decomposition import PCA
import pickle as pk
import numpy as np
import pandas as pd
import os
from huggingface_hub import snapshot_download
import requests
import matplotlib.pyplot as plt
from collections import Counter


# On first run, fetch the Serrelab/Fossils dataset repo from the Hugging Face Hub into
# ./dataset; later runs skip the download if that directory already exists.
if not os.path.exists('dataset'):
    REPO_ID = 'Serrelab/Fossils'
    token = os.environ.get('READ_TOKEN')
    # Never echo the secret itself: stdout usually ends up in shared logs.
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    else:
        print("Read token found in READ_TOKEN.")
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='dataset', local_dir='dataset')


# Fossil metadata; rows of its 'file_name' column are looked up by the indices pca_distance returns.
fossils_pd = pd.read_csv('all_fossils_filtered_100.csv')

def pca_distance(pca, sample, embedding, top_k, max_index=2852):
    """
    Rank dataset embeddings by Euclidean distance to ``sample`` in PCA space.

    Args:
        pca: fitted PCA model (anything exposing ``transform``).
        sample: embedding of the query; flattened to a single row before projection.
        embedding: embeddings of the dataset. Only ``embedding[:, -1]`` is projected,
            so this is presumably a 3-D array whose last entry along axis 1 is the
            vector to compare -- TODO confirm against the .npy files.
        top_k: number of neighbours requested. Note that up to ``top_k + 1`` indices
            are returned.
        max_index: largest row index that is kept; rows above it are dropped. The
            default 2852 is the cutoff used to exclude general fossils and keep
            Florissant fossils only.
    Returns:
        numpy array of up to ``top_k + 1`` row indices into ``embedding``, closest first.
    """
    query = pca.transform(sample.reshape(1, -1))
    projected = pca.transform(embedding[:, -1])
    distances = np.linalg.norm(projected - query, axis=1)
    order = np.argsort(distances)
    # Exclude general fossils, keep Florissant only.
    kept = order[order <= max_index]
    # NOTE(review): returns top_k + 1 entries -- possibly so the query itself may be
    # among the results; confirm intent before changing.
    return kept[:top_k + 1]

def return_paths(argsorted,files):
    """Return ``[files[i] for i in argsorted]``: the entries of ``files`` selected, in order, by ``argsorted``."""
    return [files[index] for index in argsorted]

def download_public_image(url, destination_path, timeout=30):
    """
    Download ``url`` and write the response body to ``destination_path``.

    A non-200 status is reported with a print and no file is written; exceptions
    raised by ``requests`` (connection errors, timeouts) propagate to the caller.

    Args:
        url: public URL of the image.
        destination_path: local file to write the bytes to (overwritten if present).
        timeout: seconds passed to ``requests.get``; without it the call can block
            indefinitely on an unresponsive server.
    """
    response = requests.get(url, timeout=timeout)
    if response.status_code == 200:
        with open(destination_path, 'wb') as f:
            f.write(response.content)
        print(f"Downloaded image to {destination_path}")
    else:
        print(f"Failed to download image from bucket. Status code: {response.status_code}")

def get_images(embedding,model_name,top_k=5):
    """
    Download the images of the fossils nearest to ``embedding`` and return their classes.

    Args:
        embedding: query embedding, passed to ``pca_distance`` as the sample.
        model_name: 'Rock 170', 'Mummified 170' or 'Fossils 142'; selects which PCA
            model and dataset embeddings are loaded.
        top_k: neighbour count forwarded to ``pca_distance`` (which returns up to
            ``top_k + 1`` indices).
    Returns:
        (classes, local_paths): for each neighbour whose path maps to a public URL,
        the second-to-last component of that URL (its class folder) and the local
        file it was downloaded to ('image_<i>.jpg'). Entries are appended even if
        the download itself reported a failure.
    Raises:
        ValueError: if ``model_name`` is not recognized.
    """
    if model_name in ['Rock 170','Mummified 170']:
        pca_path = 'pca_fossils_170_finer.pkl'
        embedding_path = 'dataset/embedding_fossils_170_finer.npy'
    elif model_name in ['Fossils 142']:
        pca_path = 'pca_fossils_142_resnet.pkl'
        embedding_path = 'dataset/embedding_fossils_142_finer.npy'
    else:
        raise ValueError(f'{model_name} not recognized')

    # NOTE(review): pickle.load can execute arbitrary code; these files must be trusted.
    with open(pca_path, 'rb') as f:
        pca_fossils = pk.load(f)
    embedding_fossils = np.load(embedding_path)

    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=top_k)

    fossils_paths = fossils_pd['file_name'].values
    paths = return_paths(pca_d, fossils_paths)
    print(paths)

    folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
    folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'

    local_paths = []
    classes = []
    for i, path in enumerate(paths):
        local_file_path = f'image_{i}.jpg'
        # Map the cluster-local path onto the public bucket URL.
        if 'Florissant_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/', folder_florissant)
        elif 'General_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/', folder_general)
        else:
            # Skip: otherwise the previous iteration's URL would be reused for this sample.
            print("no match found")
            continue
        print(public_path)
        download_public_image(public_path, local_file_path)
        # The class is the parent directory of the image file.
        parts = [part for part in public_path.split('/') if part]
        classes.append(parts[-2])
        local_paths.append(local_file_path)

    return classes, local_paths

def get_diagram(embedding,top_k,model_name):
    """
    Save a bar chart of the class distribution of the fossils nearest to ``embedding``.

    Args:
        embedding: query embedding, passed to ``pca_distance`` as the sample.
        top_k: neighbour count forwarded to ``pca_distance`` (which returns up to
            ``top_k + 1`` indices); also used in the chart title.
        model_name: 'Rock 170', 'Mummified 170' or 'Fossils 142'; selects which PCA
            model and dataset embeddings are loaded.
    Returns:
        Path of the saved PNG, 'class_distribution_chart.png'.
    Raises:
        ValueError: if ``model_name`` is not recognized.
    """
    if model_name in ['Rock 170','Mummified 170']:
        pca_path = 'pca_fossils_170_finer.pkl'
        embedding_path = 'dataset/embedding_fossils_170_finer.npy'
    elif model_name in ['Fossils 142']:
        pca_path = 'pca_fossils_142_resnet.pkl'
        embedding_path = 'dataset/embedding_fossils_142_finer.npy'
    else:
        raise ValueError(f'{model_name} not recognized')

    # NOTE(review): pickle.load can execute arbitrary code; these files must be trusted.
    with open(pca_path, 'rb') as f:
        pca_fossils = pk.load(f)
    embedding_fossils = np.load(embedding_path)

    pca_d = pca_distance(pca_fossils, embedding, embedding_fossils, top_k=top_k)

    fossils_paths = fossils_pd['file_name'].values
    paths = return_paths(pca_d, fossils_paths)

    folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
    folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'

    classes = []
    for path in paths:
        # Map the cluster-local path onto the public bucket URL.
        if 'Florissant_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/', folder_florissant)
        elif 'General_Fossil/512/full/jpg/' in path:
            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/', folder_general)
        else:
            # Skip: otherwise the previous iteration's URL would be counted again.
            print("no match found")
            continue
        print(public_path)
        # The class is the parent directory of the image file.
        parts = [part for part in public_path.split('/') if part]
        classes.append(parts[-2])

    # Descending count; ties keep first-seen order (same as a stable reverse sort).
    sorted_class_counts = Counter(classes).most_common()
    sorted_classes = [name for name, _ in sorted_class_counts]
    sorted_frequencies = [count for _, count in sorted_class_counts]

    fig, ax = plt.subplots()
    if sorted_classes:
        colors = plt.cm.viridis(np.linspace(0, 1, len(sorted_classes)))
        ax.bar(sorted_classes, sorted_frequencies, color=colors)
        # Tick labels must follow the same count-sorted order as the bars.
        ax.set_xticks(range(len(sorted_classes)))
        ax.set_xticklabels(sorted_classes, rotation=45, ha='right')
    ax.set_xlabel('Plant Family')
    ax.set_ylabel('Frequency')
    ax.set_title(f'Distribution of Plant Family of {top_k} Closest Samples')

    # Save the diagram to a file
    diagram_path = 'class_distribution_chart.png'
    fig.tight_layout()  # Make room for the rotated x-axis labels
    fig.savefig(diagram_path)
    plt.close(fig)  # Release the figure's memory
    return diagram_path