# fossil_app / inference_resnet.py
# Scraped from a Hugging Face file view: last commit 397c08f
# ("config access token") by andy-wyx; file size 6.38 kB.
import tensorflow as tf
# Force CPU-only execution by hiding every GPU from TensorFlow. This has to
# run before TensorFlow initialises its devices (i.e. before any op executes).
# To use a GPU instead, drop this call and consider enabling memory growth
# with tf.config.experimental.set_memory_growth on the first GPU device.
tf.config.set_visible_devices(devices=[], device_type='GPU')
from keras.applications import resnet
import tensorflow.keras.layers as L
import os
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
import matplotlib.pyplot as plt
from typing import Tuple
from huggingface_hub import snapshot_download
from labels import lookup_170
import numpy as np
# Fetch the classification model snapshot from the Hub only when the local
# 'model_classification' directory is absent.
if not os.path.exists('model_classification'):
    REPO_ID = 'Serrelab/fossil_classification_models'
    # Read token taken from the environment; may be None.
    token = os.getenv('READ_TOKEN')
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    snapshot_download(
        repo_id=REPO_ID,
        repo_type='model',
        local_dir='model_classification',
        token=token,
    )
def get_model(base_arch='Nasnet', weights='imagenet', input_shape=(600, 600, 3), classes=64500):
    """Build and compile a single-head softmax classifier on a Keras backbone.

    Args:
        base_arch: One of ``'Nasnet'``, ``'Resnet50v2'`` or
            ``'Resnet50v2_finer'``.
        weights: Forwarded to the Keras application constructor
            (e.g. ``'imagenet'`` or ``None``).
        input_shape: Input image shape forwarded to the backbone.
        classes: Number of units in the final softmax ``Dense`` layer.

    Returns:
        A ``tf.keras.Sequential`` of ``[backbone, Dense(classes, softmax)]``
        compiled with the Adam optimizer and categorical cross-entropy.

    Raises:
        ValueError: If ``base_arch`` is not one of the supported names.
    """
    supported = ('Nasnet', 'Resnet50v2', 'Resnet50v2_finer')
    if base_arch not in supported:
        # Previously an unknown name fell through and hit an unbound local.
        raise ValueError(
            f"Unsupported base_arch {base_arch!r}; expected one of {supported}"
        )

    if base_arch == 'Nasnet':
        # NOTE(review): pooling=None leaves a 4-D feature map, so the Dense
        # head below is applied along the last axis at every spatial position
        # rather than once per image. Confirm this is intended.
        base_model = tf.keras.applications.NASNetLarge(
            input_shape=input_shape,
            include_top=False,
            weights=weights,
            input_tensor=None,
            pooling=None,
        )
    elif base_arch == 'Resnet50v2':
        base_model = tf.keras.applications.ResNet50V2(weights=weights,
                                                      include_top=False,
                                                      pooling='avg',
                                                      input_shape=input_shape)
    else:  # 'Resnet50v2_finer'
        # The extra residual stacks need the 4-D feature map, so the backbone
        # is built without pooling and pooled after conv7 (same layout as
        # get_triplet_model(finer_model=True)). Keep a handle on the backbone:
        # the stacked output tensor has no `.input` of its own.
        backbone = tf.keras.applications.ResNet50V2(weights=weights,
                                                    include_top=False,
                                                    pooling=None,
                                                    input_shape=input_shape)
        x = resnet.stack2(backbone.output, 512, 2, name="conv6")
        x = resnet.stack2(x, 512, 2, name="conv7")
        x = GlobalAveragePooling2D()(x)
        base_model = tf.keras.Model(backbone.input, x)

    model = tf.keras.Sequential([
        base_model,
        L.Dense(classes, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  )
    return model
def get_triplet_model(input_shape=(600, 600, 3),
                      embedding_units=256,
                      embedding_depth=2,
                      backbone_class=tf.keras.applications.ResNet50V2,
                      nb_classes=19, load_weights=False, finer_model=False,
                      backbone_name='Resnet50v2',
                      pretrained_weights_path='/users/irodri15/data/irodri15/Fossils/Models/pretrained-herbarium/Resnet50v2_NO_imagenet_None_best_1600.h5'):
    """Build a two-headed model: an L2-normalised embedding plus class logits.

    Args:
        input_shape: Input image shape for the backbone.
        embedding_units: Width of every Dense layer in the embedding head.
        embedding_depth: Number of Dense layers in the embedding head. All
            but the last use ReLU; the last is linear.
        backbone_class: Keras application constructor, called with
            ``input_shape`` and ``include_top=False``.
        nb_classes: Number of units in the linear logits head.
        load_weights: If True, build ``get_model(backbone_name)``, load
            ``pretrained_weights_path`` into it and copy the weights of its
            first layer into the backbone.
        finer_model: If True, append two extra ``resnet.stack2`` blocks
            (``conv6``, ``conv7``) after the backbone.
        backbone_name: Architecture name forwarded to ``get_model`` when
            ``load_weights`` is True.
        pretrained_weights_path: Weights file loaded when ``load_weights`` is
            True. Defaults to the path that used to be hard-coded here; pass
            your own path on any other machine.

    Returns:
        A compiled ``tf.keras.Model`` with outputs ``[embedding, logits]``.
    """
    backbone = backbone_class(input_shape=input_shape, include_top=False)
    if load_weights:
        # Transfer backbone weights from a classifier built by get_model().
        pretrained = get_model(backbone_name, input_shape=input_shape)
        pretrained.load_weights(pretrained_weights_path)
        backbone.set_weights(pretrained.layers[0].get_weights())
    if finer_model:
        x = resnet.stack2(backbone.output, 512, 2, name="conv6")
        x = resnet.stack2(x, 512, 2, name="conv7")
        backbone = tf.keras.Model(backbone.input, x)

    features = GlobalAveragePooling2D()(backbone.output)

    embedding_head = features
    for embed_i in range(embedding_depth):
        embedding_head = Dense(
            embedding_units,
            activation="relu" if embed_i < embedding_depth - 1 else "linear",
        )(embedding_head)
    embedding_head = tf.nn.l2_normalize(embedding_head, -1, epsilon=1e-5)

    # Logits are left un-normalised; callers apply softmax themselves.
    logits_head = Dense(nb_classes)(features)

    model = tf.keras.Model(backbone.input, [embedding_head, logits_head])
    # NOTE(review): 'cce' is kept as-is; confirm the installed Keras accepts
    # this loss alias if the model is ever trained.
    model.compile(loss='cce', metrics=['accuracy'])
    return model
# Module-level size constants (both 600). They are not referenced by the
# functions in this module; presumably kept for external callers — verify.
load_size = crop_size = 600
def _clever_crop(img: tf.Tensor,
                 target_size: Tuple[int, int] = (128, 128),
                 grayscale: bool = False
                 ) -> "Tuple[tf.Tensor, tf.Tensor | int]":
    """Tile an elongated image towards a square shape, then resize it.

    When the longer side exceeds 1.2x the shorter side, the image is
    concatenated with ``floor(long / short)`` extra copies of itself along its
    short dimension:

    * tall images (width is the short side) get the copies side by side;
    * wide images (height is the short side) get the copies stacked
      vertically, and the result is rotated by 90 degrees.

    The (possibly tiled) image is then resized to ``target_size``.

    Expects eager execution: tensor values drive Python ``if`` and ``range``.

    Args:
        img: Image indexed as (height, width, ...) — inferred from the use of
            ``tf.shape(img)[0]`` / ``[1]``; TODO confirm channel layout.
        target_size: Output size forwarded to ``tf.image.resize``.
        grayscale: If True, convert to grayscale and back to 3-channel RGB
            (assumes an RGB input — TODO confirm).

    Returns:
        ``(resized_image, repeating)`` where ``repeating`` is the number of
        extra copies concatenated: a float scalar tensor from
        ``tf.math.floor`` when tiling happened, otherwise the int ``0``.

    Raises:
        ValueError: If either spatial side of ``img`` is zero.
    """
    height = tf.shape(img)[0]
    width = tf.shape(img)[1]
    maxside = tf.math.maximum(height, width)
    minside = tf.math.minimum(height, width)
    if int(minside) == 0:
        # The ratio below would be inf/NaN and int(inf) raises OverflowError.
        raise ValueError(
            f"_clever_crop got an empty image side: "
            f"height={int(height)}, width={int(width)}"
        )

    new_img = img
    repeating = 0
    ratio = tf.math.divide(maxside, minside)
    if ratio > 1.2:
        repeating = tf.math.floor(ratio)
        if tf.math.equal(width, minside):
            # Tall image: place the copies side by side.
            for _ in range(int(repeating)):
                new_img = tf.concat((new_img, img), axis=1)
        if tf.math.equal(height, minside):
            # Wide image: stack the copies vertically, then rotate.
            for _ in range(int(repeating)):
                new_img = tf.concat((new_img, img), axis=0)
            new_img = tf.image.rot90(new_img)

    img = tf.image.resize(new_img, target_size)
    if grayscale:
        img = tf.image.rgb_to_grayscale(img)
        img = tf.image.grayscale_to_rgb(img)
    return img, repeating
def preprocess(img, size=600):
    """Divide pixel values by 255 and resize to a ``size`` x ``size`` image.

    Args:
        img: Image convertible to a NumPy array.
        size: Side length forwarded to ``tf.image.resize``.

    Returns:
        A ``float32`` NumPy array holding the resized image.
    """
    scaled = np.asarray(img, dtype=np.float32) / 255.0
    resized = tf.image.resize(scaled, (size, size))
    return np.array(resized, dtype=np.float32)
def select_top_n(preds, n=10):
    """Return the indices of the ``n`` largest scores, largest first.

    Args:
        preds: Array-like of scores (assumed 1-D).
        n: Number of indices to return. Values ``<= 0`` yield an empty array;
            values larger than ``len(preds)`` return every index.

    Returns:
        A NumPy integer array of indices ordered by descending score.
    """
    # Slice the descending order instead of using ``[-n:]``: with n == 0 the
    # old ``[-0:]`` slice selected *every* index rather than none.
    descending = np.argsort(preds)[::-1]
    return descending[:max(n, 0)]
def parse_results(top_n, logits):
    """Map each selected class index to its label and score.

    Args:
        top_n: Iterable of class indices.
        logits: Indexable scores; ``logits[i]`` is converted to ``float``.

    Returns:
        A dict ``{lookup_170[i]: float(logits[i])}`` in ``top_n`` order.
        Indices that map to the same label keep the last score seen.
    """
    return {lookup_170[idx]: float(logits[idx]) for idx in top_n}
def inference_resnet_embedding(x, model, size=576, n_classes=170, n_top=10):
    """Return the first model output for a single image.

    For a model built by ``get_triplet_model`` the first output is the
    embedding head. ``n_classes`` and ``n_top`` are accepted for signature
    parity with ``inference_resnet_finer`` and are not used.

    Args:
        x: Input image passed to ``_clever_crop``.
        model: Keras model whose ``predict`` output is indexed ``[0][0]``.
        size: Side length used for cropping and preprocessing.

    Returns:
        The first output for the single image in the batch.
    """
    square_img = _clever_crop(x, (size, size))[0]
    batch = np.expand_dims(preprocess(square_img, size=size), axis=0)
    outputs = model.predict(batch)
    first_output_batch = outputs[0]
    return first_output_batch[0]
def inference_resnet_finer(x, model, size=576, n_classes=170, n_top=10):
    """Classify one image and return its ``n_top`` most probable labels.

    ``n_classes`` is accepted for signature compatibility and is not used.

    Args:
        x: Input image passed to ``_clever_crop``.
        model: Keras model whose ``predict`` output has the class logits at
            index 1 (as built by ``get_triplet_model``).
        size: Side length used for cropping and preprocessing.
        n_top: Number of top classes to return.

    Returns:
        A dict mapping label to softmax probability (see ``parse_results``).
    """
    cropped = _clever_crop(x, (size, size))[0]
    prep = preprocess(cropped, size=size)
    # Index 1 selects the logits output; the trailing [0] the single image.
    raw_logits = model.predict(np.array([prep]))[1][0]
    # `.numpy()` already returns a host-side NumPy array; the previous
    # `.cpu()` call was a deprecated, redundant EagerTensor method.
    probs = tf.nn.softmax(raw_logits).numpy()
    top_n = select_top_n(probs, n=n_top)
    return parse_results(top_n, probs)