mnist / app.py
gaviego's picture
Intial
0d94b00
raw
history blame
661 Bytes
import gradio as gr
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import model
net = torch.load('mnist.pth')
net.eval()
def predict(img):
arr = np.array(img) / 255 # Assuming img is in the range [0, 255]
arr = np.expand_dims(arr, axis=0) # Add batch dimension
arr = torch.from_numpy(arr).float() # Convert to PyTorch tensor
output = net(arr)
topk_values, topk_indices = torch.topk(output, 2) # Get the top 2 classes
return [str(k) for k in topk_indices[0].tolist()]
sp = gr.Sketchpad(shape=(28, 28))
gr.Interface(fn=predict,
inputs=sp,
outputs=['label','label']).launch()