# Scraped Hugging Face Spaces page header (status shown: "Build error"); not part of the program.
import json

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf

from backbone import create_name_vit
from backbone import ClassificationModel
# Settings for the "vit-l/16" backbone at 384x384 input resolution.
vit_l16_384 = {
    # Name handed to `create_name_vit` to select the backbone.
    "backbone_name": "vit-l/16",
    # Keyword arguments forwarded verbatim to `create_name_vit`.
    "backbone_params": {
        "image_size": 384,
        "representation_size": 0,
        "attention_dropout_rate": 0.0,
        "dropout_rate": 0.0,
        "channels": 3,
    },
    # Dropout rate given to `ClassificationModel`.
    "dropout_rate": 0.0,
    # Path passed to `model.load_weights`.
    "pretrained": "./weights/vit_l16_384/model-weights",
}
# Init backbone from the configured name and keyword parameters.
backbone = create_name_vit(
    vit_l16_384["backbone_name"],
    **vit_l16_384["backbone_params"],
)

# Init classification model on top of the backbone (1000 output classes).
model = ClassificationModel(
    backbone=backbone,
    dropout_rate=vit_l16_384["dropout_rate"],
    num_classes=1000,
)

# Load weights from the configured checkpoint path.
model.load_weights(vit_l16_384["pretrained"])
# The model is only used for prediction in this script, so mark it non-trainable.
model.trainable = False
# Load ImageNet idx to label mapping.
# The mapping is indexed with string keys (`str(idx)`) when predictions are
# turned into labels. The encoding is given explicitly so reading the JSON
# file does not depend on the platform's default locale.
with open("assets/imagenet_1000_idx2labels.json", encoding="utf-8") as f:
    idx_to_label = json.load(f)
def resize_with_normalization(image, size=(384, 384)):
    """Resize an image and normalize its pixel values for the model.

    The input is cast to float32, resized to ``size``, transformed with
    ``(x - 127.5) / 127.5`` (which maps values in [0, 255] to [-1, 1]) and
    given a leading batch axis.

    Args:
        image: Image tensor or array. The normalization constants are shaped
            ``(1, 1, 3)``, so this presumably expects a ``(height, width, 3)``
            input -- confirm against callers.
        size: Target size passed to ``tf.image.resize``. The default is an
            immutable tuple rather than a shared mutable list; callers may
            still pass a list.

    Returns:
        A float32 tensor with one extra leading dimension of size 1.
    """
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, size)
    image -= tf.constant(127.5, shape=(1, 1, 3), dtype=tf.float32)
    image /= tf.constant(127.5, shape=(1, 1, 3), dtype=tf.float32)
    # Add the batch axis expected by `model.predict`.
    image = tf.expand_dims(image, axis=0)
    return image
def softmax_stable(x):
    """Return the softmax of ``x`` computed over all of its elements.

    The global maximum is subtracted before exponentiation so that large
    logits do not overflow; this does not change the mathematical result.

    Args:
        x: Array-like of logits.

    Returns:
        A NumPy array of the same shape as ``x`` whose entries sum to 1.
    """
    # Compute the shifted exponentials once and reuse them for the normalizer.
    exps = np.exp(x - np.max(x))
    return exps / exps.sum()
def classify_image(img, top_k):
    """Classify an image and return its most probable labels.

    Args:
        img: Input image, converted with ``tf.convert_to_tensor`` and then
            preprocessed by ``resize_with_normalization``.
        top_k: Number of classes to report. The value is coerced to ``int``
            and clamped to ``[1, number_of_classes]``.

    Returns:
        A dict mapping label text to probability (rounded to 4 decimals),
        ordered from most to least probable.
    """
    img = tf.convert_to_tensor(img)
    img = resize_with_normalization(img)
    # NOTE(review): `workers` only applies to generator inputs in tf.keras and
    # is not accepted by Keras 3 -- confirm the installed version supports it.
    pred_logits = model.predict(img, batch_size=1, workers=8)[0]
    pred_probs = softmax_stable(pred_logits)

    # A UI slider may deliver a float, and a float slice bound raises
    # TypeError. A value of 0 would make the slice `[-0:]`, which selects
    # every class instead of none, so clamp to at least one class.
    top_k = max(1, min(int(top_k), len(pred_probs)))

    # argsort is ascending: take the last `top_k` indices and reverse them.
    top_k_labels = pred_probs.argsort()[-top_k:][::-1]
    return {
        idx_to_label[str(idx)]: round(float(pred_probs[idx]), 4)
        for idx in top_k_labels
    }
# NOTE(review): this text describes an OWL-ViT open-vocabulary detection demo
# (text queries, score threshold), not the image classifier built here. It is
# not passed to the interface; rewrite it before enabling `description=` below.
description = """
Gradio demo for <a href="https://huggingface.co/docs/transformers/main/en/model_doc/owlvit">ViT released by Kakao Lab</a>,
introduced in <a href="https://arxiv.org/abs/2205.06230">Simple Open-Vocabulary Object Detection
with Vision Transformers</a>.
\n\nYou can use OWL-ViT to query images with text descriptions of any object.
To use it, simply upload an image and enter comma separated text descriptions of objects you want to query the image for. You
can also use the score threshold slider to set a threshold to filter out low probability predictions.
\n\nOWL-ViT is trained on text templates,
hence you can get better predictions by querying the image with text templates used in training the original model: *"photo of a star-spangled banner"*,
*"image of a shoe"*. Refer to the <a href="https://arxiv.org/abs/2103.00020">CLIP</a> paper to see the full list of text templates used to augment the training data.
"""

# Web UI: an image plus a "top k" slider in, a label/confidence view out.
demo = gr.Interface(
    classify_image,
    # The slider is limited to whole numbers >= 1: a value of 0 or a
    # fractional value is not a meaningful number of classes to display.
    inputs=[gr.Image(), gr.Slider(1, 1000, value=5, step=1)],
    # Use the top-level component API, consistent with gr.Image / gr.Slider
    # above, instead of the legacy `gr.outputs` namespace.
    outputs=gr.Label(),
    title="Image Classification with Kakao Brain ViT",
    # description=description,
)

demo.launch()