# Hugging Face Spaces status (captured with this file): Build error
import json

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf

from backbone import ClassificationModel
from backbone import create_name_vit
# Settings for the ViT-L/16 backbone at 512x512 input resolution, plus the
# path of the pretrained checkpoint. Every dropout rate is 0.0; this file
# only runs inference with the resulting model.
vit_l16_512 = dict(
    backbone_name="vit-l/16",
    backbone_params=dict(
        image_size=512,
        representation_size=0,
        attention_dropout_rate=0.0,
        dropout_rate=0.0,
        channels=3,
    ),
    dropout_rate=0.0,
    pretrained="./weights/vit_l16_512/model-weights",
)
# Assemble the inference model from the config above:
#   1. build the ViT backbone named in the config with its keyword settings,
#   2. wrap it in a ClassificationModel with 1000 output classes (the label
#      file loaded below is assets/imagenet_1000_idx2labels.json),
#   3. restore the pretrained weights, then mark the model non-trainable.
# Keep the construct -> load -> freeze order when editing.
backbone = create_name_vit(
    vit_l16_512["backbone_name"],
    **vit_l16_512["backbone_params"]
)
model = ClassificationModel(
    num_classes=1000,
    backbone=backbone,
    dropout_rate=vit_l16_512["dropout_rate"],
)
model.load_weights(vit_l16_512["pretrained"])
model.trainable = False
# Class-index -> label-name mapping used to turn model outputs into text.
# JSON object keys are always strings, so lookups must use str(index)
# (presumably "0" .. "999" for the 1000 classes -- confirm against the file).
# The encoding is explicit so decoding does not depend on the host locale.
with open("assets/imagenet_1000_idx2labels.json", encoding="utf-8") as f:
    idx_to_label = json.load(f)
def resize_with_normalization(image, size=(512, 512)):
    """Resize an image and map its pixel values into the model's input range.

    Args:
        image: Image tensor or array, expected as (height, width, channels)
            with values in [0, 255] (assumed to be the uint8 array that
            ``gr.Image`` passes in -- confirm for the Gradio version in use).
        size: Target (height, width). Defaults to (512, 512), matching the
            ``image_size`` of the ViT-L/16 configuration in this file.

    Returns:
        A float32 tensor of shape (1, size[0], size[1], channels) whose values
        are mapped from [0, 255] to [-1, 1].
    """
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, size)
    # (x - 127.5) / 127.5 maps [0, 255] onto [-1, 1]. Scalars broadcast over
    # any channel count, so the input is not restricted to exactly 3 channels.
    image = (image - 127.5) / 127.5
    # Prepend a batch dimension of 1 for model.predict.
    return tf.expand_dims(image, axis=0)
def softmax_stable(x):
    """Return the softmax of ``x``, computed in a numerically stable way.

    The maximum is subtracted before exponentiating so that large logits do
    not overflow ``np.exp``. This leaves the result unchanged, because softmax
    is invariant to adding the same constant to every input.

    Args:
        x: Array-like of logits. The softmax is taken over all elements
            together (no per-axis reduction).

    Returns:
        A NumPy float array with the same shape as ``x`` whose entries sum to 1.
    """
    x = np.asarray(x)
    # Compute the shifted exponentials once and reuse them for the normalizer.
    exps = np.exp(x - np.max(x))
    return exps / exps.sum()
def classify_image(img, top_k):
    """Classify an image and return its ``top_k`` most probable labels.

    Args:
        img: Input image as passed by ``gr.Image`` (presumably an (H, W, 3)
            uint8 NumPy array -- confirm against the Gradio version in use),
            or ``None`` when the user submits without an image.
        top_k: Number of highest-probability classes to return. It may
            arrive from the Gradio slider as a float, so it is cast to int.

    Returns:
        Dict mapping label name -> probability (rounded to 4 decimals) for the
        ``top_k`` most probable classes, in descending probability order.
        Empty when there is no image or when ``top_k`` < 1.
    """
    top_k = int(top_k)
    # Without this guard, top_k == 0 would slice argsort()[-0:], which is the
    # whole array and returns every class. A missing image would also fail
    # inside tf.convert_to_tensor.
    if img is None or top_k < 1:
        return {}

    batch = resize_with_normalization(tf.convert_to_tensor(img))
    # `workers` only matters for generator/Sequence inputs, so it is not passed.
    pred_logits = model.predict(batch, batch_size=1)[0]
    pred_probs = softmax_stable(pred_logits)

    # Class indices sorted by descending probability, truncated to top_k.
    # Slicing after reversing keeps top_k > number of classes well-defined.
    top_k_indices = pred_probs.argsort()[::-1][:top_k]

    results = {}
    for idx in top_k_indices:
        # setdefault keeps the first (highest-probability) entry if the label
        # file maps two class indices to the same name.
        results.setdefault(idx_to_label[str(idx)], round(float(pred_probs[idx]), 4))
    return results
# Gradio UI: an image input plus a top-k slider; the output shows the
# predicted labels with their probabilities.
demo = gr.Interface(
    classify_image,
    inputs=[
        gr.Image(),
        # Integer steps starting at 1, so top_k is always a positive count.
        gr.Slider(minimum=1, maximum=1000, value=5, step=1, label="Top-k"),
    ],
    # gr.outputs.* is the deprecated legacy namespace (removed in Gradio 4).
    # gr.Label is the component-style equivalent, matching gr.Image / gr.Slider.
    outputs=gr.Label(),
    title="Image Classification with Kakao Brain ViT",
    examples=[
        ["assets/halloween-gaf8ad7ebc_1920.jpeg", 5],
        ["assets/IMG_4484.jpeg", 5],
        ["assets/IMG_4737.jpeg", 5],
        ["assets/IMG_4740.jpeg", 5],
    ],
)

# Start the server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()