|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import json |
|
|
|
import gradio as gr |
|
import numpy as np |
|
import onnxruntime |
|
from huggingface_hub import hf_hub_download |
|
from PIL import Image |
|
|
|
|
|
# Hugging Face Hub repository hosting both the model config and its ONNX export
REPO = "frgfm/rexnet1_0x"

# Model configuration; this script reads its "input_shape", "mean", "std" and "classes" keys
with open(hf_hub_download(REPO, filename="config.json"), "rb") as f:
    cfg = json.load(f)

# Pin the execution provider explicitly: since onnxruntime 1.9, builds that expose several
# providers (e.g. onnxruntime-gpu) raise a ValueError when `providers` is omitted.
ort_session = onnxruntime.InferenceSession(
    hf_hub_download(REPO, filename="model.onnx"),
    providers=["CPUExecutionProvider"],
)
|
|
|
def preprocess_image(pil_img: Image.Image) -> np.ndarray:
    """Preprocess an image for inference

    Args:
        pil_img: a valid pillow image, in any mode that pillow can convert to RGB

    Returns:
        the resized and normalized float32 image of shape (1, C, H, W)
    """
    # Force 3 channels: grayscale/palette images would yield a 2D array (breaking the
    # HWC -> CHW transpose below) and RGBA images a 4-channel one (mismatching the
    # per-channel mean/std). Assumes the model expects RGB input, i.e. cfg["mean"] and
    # cfg["std"] hold 3 values -- TODO confirm against config.json.
    img = pil_img.convert("RGB")

    # cfg["input_shape"] ends with (H, W); PIL expects a (width, height) tuple, hence the reversal
    target_size = tuple(cfg["input_shape"][-2:][::-1])
    img = img.resize(target_size, Image.BILINEAR)

    # HWC uint8 -> CHW float32 scaled to [0, 1]
    arr = np.asarray(img).transpose((2, 0, 1)).astype(np.float32) / 255

    # Per-channel normalization: (C,) statistics broadcast over (C, H, W)
    arr -= np.array(cfg["mean"])[:, None, None]
    arr /= np.array(cfg["std"])[:, None, None]

    # Prepend the batch dimension
    return arr[None, ...]
|
|
|
def predict(image):
    """Classify an image with the ONNX model.

    Args:
        image: a pillow image (as delivered by the Gradio image input)

    Returns:
        a dict mapping each class name of cfg["classes"] to its softmax probability
    """
    np_img = preprocess_image(image)
    # Feed the preprocessed batch through the model's first declared input
    ort_input = {ort_session.get_inputs()[0].name: np_img}

    # `None` as output names requests every model output
    ort_out = ort_session.run(None, ort_input)

    # Softmax over the scores of the single batch item (first output, first row).
    # Subtracting the max before exponentiating keeps np.exp from overflowing to inf
    # (which would make every probability nan); the result is mathematically unchanged.
    logits = ort_out[0][0]
    out_exp = np.exp(logits - logits.max())
    probs = out_exp / out_exp.sum()

    return {class_name: float(conf) for class_name, conf in zip(cfg["classes"], probs)}
|
|
|
# NOTE(review): gr.inputs / gr.outputs are deprecated in Gradio 3.x and removed in 4.x;
# migrate to gr.Image / gr.Label once the pinned Gradio version allows it.

# Input widget: the uploaded picture is passed to `predict` as a PIL image
img = gr.inputs.Image(type="pil")
# Output widget: display only the 3 highest-scoring classes
outputs = gr.outputs.Label(num_top_classes=3)

# Assemble the demo UI around `predict` and start serving it
gr.Interface(
    predict,
    [img],
    outputs,
    live=True,
    title="Holocron: image classification demo",
    article=(
        "<p style='text-align: center'><a href='https://github.com/frgfm/Holocron'>"
        "Github Repo</a> | "
        "<a href='https://frgfm.github.io/Holocron/'>Documentation</a></p>"
    ),
).launch()
|
|