import cv2
import json
import gradio as gr
import numpy as np
import tensorflow as tf
from backbone import create_name_vit
from backbone import ClassificationModel
# Configuration for the ViT-L/16 backbone at 384x384 input resolution, its
# classification head, and the location of the pretrained checkpoint.
vit_l16_384 = dict(
    backbone_name="vit-l/16",
    # Keyword arguments forwarded verbatim to create_name_vit().
    backbone_params=dict(
        image_size=384,
        representation_size=0,
        attention_dropout_rate=0.0,
        dropout_rate=0.0,
        channels=3,
    ),
    # Dropout used by the ClassificationModel head.
    dropout_rate=0.0,
    pretrained="./weights/vit_l16_384/model-weights",
)
# Build the ViT backbone from its registry name and hyper-parameters, then
# wrap it with a 1000-way classification head (ImageNet-1k classes).
backbone = create_name_vit(
    vit_l16_384["backbone_name"],
    **vit_l16_384["backbone_params"],
)
model = ClassificationModel(
    num_classes=1000,
    dropout_rate=vit_l16_384["dropout_rate"],
    backbone=backbone,
)

# Restore the pretrained checkpoint, then freeze the model: the demo only
# runs inference and never trains.
model.load_weights(vit_l16_384["pretrained"])
model.trainable = False
# Load the ImageNet-1k "class index -> human-readable label" mapping. JSON
# object keys are always strings, so indices must be looked up as str(idx).
# The encoding is explicit because open() otherwise uses the platform's locale
# encoding, while JSON files are UTF-8.
with open("assets/imagenet_1000_idx2labels.json", encoding="utf-8") as f:
    idx_to_label = json.load(f)
def resize_with_normalization(image, size=(384, 384)):
    """Resize an image, scale its pixels to [-1, 1], and add a batch axis.

    Args:
        image: A single image as a tensor or array-like of shape
            (height, width, channels). Presumably uint8 RGB in [0, 255] as
            delivered by ``gr.Image()`` -- TODO confirm for other input modes.
        size: Target ``(height, width)``. The default of 384x384 matches the
            ``image_size`` configured for the ViT-L/16 backbone. A tuple is used
            instead of a list to avoid a shared mutable default argument.

    Returns:
        A float32 tensor of shape ``(1, size[0], size[1], channels)``.
    """
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, size)
    # Map [0, 255] -> [-1, 1]. The original code used (1, 1, 3) constants; a
    # scalar gives identical values and broadcasts over any channel count.
    image = (image - 127.5) / 127.5
    return tf.expand_dims(image, axis=0)
def softmax_stable(x, axis=None):
    """Return the softmax of ``x``, computed in a numerically stable way.

    The maximum is subtracted before exponentiating so that large logits do not
    overflow. Softmax is invariant to adding a constant, so the result is
    unchanged.

    Args:
        x: Array-like of logits.
        axis: Axis (or axes) to normalize over. The default ``None`` normalizes
            over the whole array, which matches the previous behavior. Pass
            ``-1`` to apply softmax row-wise to a batch of logits.

    Returns:
        An ndarray with the same shape as ``x`` whose entries along ``axis``
        sum to 1.
    """
    x = np.asarray(x)
    # Compute the exponentials once and reuse them for the normalizer.
    exps = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exps / exps.sum(axis=axis, keepdims=True)
def classify_image(img, top_k):
    """Classify an image and return its ``top_k`` most likely ImageNet labels.

    Args:
        img: Image array from ``gr.Image()``. It is ``None`` when the user
            submits without uploading anything.
        top_k: Number of predictions to return. Slider values may arrive as
            floats, so the value is truncated to ``int``. Values below 1 are
            treated as 1: with k == 0, ``argsort()[-0:]`` would select every
            class.

    Returns:
        A dict mapping label -> probability rounded to 4 decimals, built from
        the most to the least likely class.

    Raises:
        ValueError: If no image was provided.
    """
    if img is None:
        raise ValueError("No image provided; please upload an image.")
    k = max(1, int(top_k))

    batch = resize_with_normalization(tf.convert_to_tensor(img))
    # `workers` only applies to generator/Sequence inputs, so it is not passed.
    pred_logits = model.predict(batch, batch_size=1)[0]
    pred_probs = softmax_stable(pred_logits)

    # Class indices sorted by descending probability, truncated to k.
    top_k_labels = pred_probs.argsort()[::-1][:k]
    return {
        idx_to_label[str(idx)]: round(float(pred_probs[idx]), 4)
        for idx in top_k_labels
    }
# Markdown description for the Gradio interface. The previous text was copied
# from an OWL-ViT object-detection demo and described controls (text queries,
# score threshold) that this classifier does not have.
description = """
Gradio demo for a Vision Transformer (ViT-L/16, 384x384 input) image classifier
released by Kakao Brain. The ViT architecture was introduced in
*An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale*.

Upload an image and use the slider to choose how many predictions to show.
The model returns the most likely of the 1000 ImageNet classes together with
their probabilities.
"""
# Gradio UI: an image upload plus a slider that picks how many predictions
# (top-k) to display.
demo = gr.Interface(
    classify_image,
    inputs=[
        gr.Image(),
        # Minimum is 1 because k == 0 is meaningless. step=1 makes every integer
        # selectable; without it Gradio may derive a coarser step from the
        # range (confirm for the pinned Gradio version).
        gr.Slider(1, 1000, value=5, step=1, label="Top-k predictions"),
    ],
    # gr.Label is the current component; the gr.outputs namespace is deprecated
    # and the inputs above already use top-level components.
    outputs=gr.Label(),
    title="Image Classification with Kakao Brain ViT",
    # NOTE(review): `description` is defined in this file but not passed here;
    # enable once its text matches this demo.
    # description=description,
)
demo.launch()