File size: 1,905 Bytes
832aa0d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import argparse
import gradio as gr
import torch
from PIL import Image
from torchvision.transforms import Compose, ConvertImageDtype, Normalize, PILToTensor, Resize
from torchvision.transforms.functional import InterpolationMode
from holocron import models
def main(args):
    """Build the image-classification demo and serve it with Gradio.

    A pretrained ``rexnet1_3x`` model from ``holocron.models`` is loaded and
    switched to evaluation mode, a preprocessing pipeline is derived from the
    model's ``default_cfg``, and a live Gradio interface is launched.

    Args:
        args: parsed command-line namespace; only ``args.port`` is read and
            forwarded to ``launch`` as the server port.
    """
    classifier = models.rexnet1_3x(pretrained=True).eval()
    cfg = classifier.default_cfg

    # Preprocessing: resize to the spatial part of the configured input shape,
    # convert the PIL image to a float32 tensor, then normalize with the
    # mean/std stored in the model configuration.
    steps = [
        Resize(cfg['input_shape'][1:], interpolation=InterpolationMode.BILINEAR),
        PILToTensor(),
        ConvertImageDtype(torch.float32),
        Normalize(cfg['mean'], cfg['std']),
    ]
    preprocessor = Compose(steps)

    # NOTE: the parameter is deliberately still called ``input`` (even though
    # it shadows the builtin) -- Gradio may derive the component label from
    # the function's parameter name, so renaming it could alter the UI.
    def predict(input):
        """Map each class name to its softmax confidence for one image.

        ``input`` must expose ``.astype`` (presumably the NumPy array that
        Gradio passes for an image input -- confirm against the Gradio
        version in use).
        """
        pil_img = Image.fromarray(input.astype('uint8'), 'RGB')
        # Add a leading batch dimension before feeding the model.
        batch = preprocessor(pil_img).unsqueeze(0)
        with torch.inference_mode():
            scores = classifier(batch)[0]
            probs = torch.nn.functional.softmax(scores, dim=0)
        return dict(zip(cfg['classes'], map(float, probs)))

    demo = gr.Interface(
        fn=predict,
        inputs=[gr.inputs.Image()],
        outputs=gr.outputs.Label(num_top_classes=3),
        title="Holocron: image classification demo",
        article=(
            "<p style='text-align: center'><a href='https://github.com/frgfm/Holocron'>"
            "Github Repo</a> | "
            "<a href='https://frgfm.github.io/Holocron/'>Documentation</a></p>"
        ),
        live=True,
        theme="huggingface",
        layout="horizontal",
    )
    demo.launch(server_port=args.port, show_error=True)
if __name__ == '__main__':
    # Script entry point: read the CLI options and start the demo server.
    # ArgumentDefaultsHelpFormatter makes ``--help`` display default values.
    parser = argparse.ArgumentParser(
        description='Holocron image classification demo',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8001,
        help="Port on which the webserver will be run",
    )
    args = parser.parse_args()
    main(args)
|