import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image

import model
# Load the whole pickled model object (not just a state dict): the result is
# used directly as a module below.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
# NOTE(review): newer PyTorch releases default to weights_only=True, which
# rejects whole-module pickles - confirm against the installed version.
# map_location keeps the model on the CPU, which is where predict() builds
# its input tensor (torch.from_numpy), even if the checkpoint came from a GPU.
net = torch.load('mnist.pth', map_location='cpu')
net.eval()  # switch layers such as dropout/batch-norm to inference behavior
def predict(img):
    """Classify a sketched image and return the two highest-scoring classes.

    Args:
        img: Image data accepted by ``np.array``. Pixel values are assumed to
            lie in [0, 255] - TODO confirm against the Sketchpad output.

    Returns:
        A list of two strings: the class indices of the top-2 model outputs
        for the first (only) batch element, best first.
    """
    # Scale to [0, 1] and prepend a batch axis before handing off to torch.
    pixels = np.expand_dims(np.array(img) / 255, axis=0)
    batch = torch.from_numpy(pixels).float()
    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        output = net(batch)
    _, topk_indices = torch.topk(output, 2)  # only the indices are used
    return [str(k) for k in topk_indices[0].tolist()]
# 28x28 drawing canvas as the model input; the two 'label' outputs receive
# the two strings returned by predict().
sp = gr.Sketchpad(shape=(28, 28))
gr.Interface(
    fn=predict,
    inputs=sp,
    outputs=['label', 'label'],
).launch()