File size: 2,368 Bytes
cedeb82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c160b1e
cedeb82
 
 
f71ea6c
cedeb82
 
 
 
 
c160b1e
cedeb82
 
5852c4b
cedeb82
 
 
 
e97c0ff
cedeb82
e97c0ff
cedeb82
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import tensorflow as tf

# Hide every GPU from TensorFlow so all subsequent ops (model loading and
# inference below) are placed on the CPU.
# NOTE(review): TF documents that visible devices can only be changed before the
# runtime is initialized — keep this call ahead of any other TF work; confirm.
tf.config.set_visible_devices([], 'GPU')
# Disabled alternative: keep the GPU but enable memory growth on the first one.
# gpu_devices = tf.config.experimental.list_physical_devices('GPU')
# if gpu_devices:
#     tf.config.experimental.set_memory_growth(gpu_devices[0], True)
# else:
#     print(f"TensorFlow device: {gpu_devices}")

from keras.applications import resnet
import tensorflow as tf
import keras
import os

import matplotlib.pyplot as plt
from typing import Tuple
from huggingface_hub import snapshot_download
from labels import lookup_140
import numpy as np 

# Download the classification checkpoints from the Hugging Face Hub into
# ./model_classification, but only when that path does not exist yet.
if not os.path.exists('model_classification'):

    REPO_ID = 'Serrelab/fossil_classification_models'
    # Optional auth token for the Hub, read from the environment.
    token = os.getenv('READ_TOKEN')
    # Never echo the secret itself (it would end up in logs); only report
    # whether one was supplied.
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    else:
        print('read token: provided')
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='model', local_dir='model_classification')

def get_resnet_model(model_path):
    """Load a saved Keras model and split it into two sub-models.

    The checkpoint is loaded with ``"cce"`` bound to
    ``tf.keras.losses.categorical_crossentropy`` so the custom loss name
    resolves during deserialization.

    Args:
        model_path: Path accepted by ``keras.models.load_model``.

    Returns:
        A tuple ``(model, g, h)``:
          * ``model`` – the full loaded model;
          * ``g`` – maps ``model.input`` to the output of ``model.layers[2]``;
          * ``h`` – maps the input of ``model.layers[3]`` to the output of the
            last layer.
        NOTE(review): presumably ``g`` is the feature extractor and ``h`` the
        classification head, which assumes ``layers[3]`` consumes
        ``layers[2]``'s output. Confirm against the saved architecture.
    """
    full_model = keras.models.load_model(
        model_path,
        custom_objects={"cce": tf.keras.losses.categorical_crossentropy},
    )
    layer_list = full_model.layers
    feature_model = keras.Model(full_model.input, layer_list[2].output)
    head_model = keras.Model(layer_list[3].input, layer_list[-1].output)
    return full_model, feature_model, head_model


def select_top_n(preds, n=10):
    """Return the indices of the ``n`` largest values in ``preds``, largest first.

    Args:
        preds: 1-D array-like of scores.
        n: How many indices to return. If ``n`` exceeds ``len(preds)``, every
            index is returned. A value ``<= 0`` yields an empty array.

    Returns:
        ``np.ndarray`` of integer indices, sorted by descending score.
    """
    order = np.argsort(preds)
    if n <= 0:
        # ``order[-0:]`` would slice the *whole* array, so guard explicitly.
        return order[:0]
    return order[::-1][:n]


def parse_results(top_n, logits):
    """Map each selected class index to its label and score.

    Args:
        top_n: Iterable of integer class indices (e.g. from ``select_top_n``).
        logits: Indexable scores; ``logits[i]`` is converted with ``float``.

    Returns:
        dict mapping ``lookup_140[i]`` to ``float(logits[i])`` for each ``i``
        in ``top_n``, in iteration order. If two indices share a label, the
        label keeps its first position but takes the later score.
    """
    return {lookup_140[idx]: float(logits[idx]) for idx in top_n}

def inference_resnet_embedding_v2(x, model, size=384, n_classes=140, n_top=10):
    """Preprocess a single image and return the model's embedding for it.

    The image is resized to ``size`` x ``size``, reshaped to
    ``(size, size, 3)``, divided by 255, wrapped in a batch of one and passed
    to ``model.predict``.

    Args:
        x: Image tensor/array. Assumed to be (H, W, 3) with 0-255 pixel
            values; TODO confirm with callers.
        model: Object exposing ``predict`` (e.g. a Keras model).
        size: Square side length the image is resized to.
        n_classes: Unused; kept for call-site compatibility.
        n_top: Unused; kept for call-site compatibility.

    Returns:
        ``model.predict(batch)[0][0]``. NOTE(review): presumably the first
        output's entry for the single batch item if the model has several
        outputs. Confirm.
    """
    x = tf.image.resize(x, (size, size))
    # Use ``size`` (not a hard-coded 384) so non-default sizes don't fail here.
    x = tf.reshape(x, (size, size, 3)) / 255
    embedding = model.predict(np.array([x]))[0][0]
    return embedding

def inference_resnet_finer_v2(x, model, size=384, n_classes=142, n_top=10):
    """Classify a single image and return its ``n_top`` most probable labels.

    The image is resized to ``size`` x ``size``, reshaped to
    ``(size, size, 3)``, divided by 255 and predicted as a batch of one.
    Softmax is applied to the second model output, and the highest-scoring
    classes are mapped to labels.

    Args:
        x: Image tensor/array. Assumed to be (H, W, 3) with 0-255 pixel
            values; TODO confirm with callers.
        model: Object exposing ``predict`` that returns at least two outputs.
        size: Square side length the image is resized to.
        n_classes: Unused; kept for call-site compatibility.
        n_top: Number of top classes to return.

    Returns:
        dict ``{label: probability}`` built by ``parse_results``, ordered by
        descending probability.
    """
    x = tf.image.resize(x, (size, size))
    # Use ``size`` (not a hard-coded 384) so non-default sizes don't fail here.
    x = tf.reshape(x, (size, size, 3)) / 255
    outputs = model.predict(np.array([x]))
    # Second model output, first (only) batch item; presumably the class
    # logits. Confirm against the model. ``.numpy()`` replaces the deprecated
    # ``.cpu().numpy()`` (GPUs are hidden in this module anyway).
    probs = tf.nn.softmax(outputs[1][0]).numpy()
    top_n = select_top_n(probs, n=n_top)
    return parse_results(top_n, probs)