import os
import numpy as np
import codecs
import torch
import torchvision.transforms as transforms
import gradio as gr
from PIL import Image

from unetplusplus import NestedUNet

torch.manual_seed(0)
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

# Device
DEVICE = "cpu"
print(DEVICE)
# Load color map
cmap = np.load("cmap.npy")

# Make directories
os.system("mkdir ./models")

# Get model weights
if not os.path.exists("./models/masksupnyu39.31d.pth"):
    os.system(
        "wget -O ./models/masksupnyu39.31d.pth https://github.com/hasibzunair/masksup-segmentation/releases/download/v0.1/masksupnyu39.31iou.pth"
    )
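# Note: the release asset is named masksupnyu39.31iou.pth; the wget -O flag saves it
# locally as masksupnyu39.31d.pth, which is the path checked above and loaded below.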
# Load model
model = NestedUNet(num_classes=40)
checkpoint = torch.load(
    "./models/masksupnyu39.31d.pth", map_location=torch.device("cpu")
)
model.load_state_dict(checkpoint)
model = model.to(DEVICE)
model.eval()
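# NestedUNet (imported from unetplusplus) is the UNet++ architecture; the 40 output
# classes presumably correspond to the NYUv2 40-class label set, as suggested by the
# "masksupnyu" checkpoint name.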
# Main inference function
def inference(img_path):
    image = Image.open(img_path).convert("RGB")
    transforms_image = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.CenterCrop((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    image = transforms_image(image)
    image = image[None, :]  # add batch dimension

    # Predict
    with torch.no_grad():
        output = torch.sigmoid(model(image.to(DEVICE).float()))
        output = (
            torch.softmax(output, dim=1)
            .argmax(dim=1)[0]
            .float()
            .cpu()
            .numpy()
            .astype(np.uint8)
        )
    # Map each predicted class index to its RGB color
    pred = cmap[output]
    return pred
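# Example usage (a minimal sketch, assuming the sample images listed in the Gradio
# examples below are present):
#   mask = inference("./sample_images/a.png")  # 224x224 array of cmap colors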
# App
title = "Masked Supervised Learning for Semantic Segmentation"
description = codecs.open("description.html", "r", "utf-8").read()
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2210.00923' target='_blank'>Masked Supervised Learning for Semantic Segmentation</a> | <a href='https://github.com/hasibzunair/masksup-segmentation' target='_blank'>Github</a></p>"

gr.Interface(
    inference,
    gr.inputs.Image(type="filepath", label="Input Image"),
    gr.outputs.Image(type="numpy", label="Predicted Output"),
    examples=[
        "./sample_images/a.png",
        "./sample_images/b.png",
        "./sample_images/c.png",
        "./sample_images/d.png",
    ],
    title=title,
    description=description,
    article=article,
    allow_flagging=False,
    analytics_enabled=False,
).launch(debug=True, enable_queue=True)
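# Assuming this file is the Space's app.py, running `python app.py` starts the demo
# locally. gr.inputs/gr.outputs and the enable_queue launch flag are legacy Gradio
# APIs, so an older Gradio release is assumed.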