# Spaces status at capture time: Sleeping
import functools

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchgeo.trainers import SemanticSegmentationTask
def load_model(checkpoint_path):
    """Restore a ``SemanticSegmentationTask`` from a checkpoint file.

    Args:
        checkpoint_path: Path passed to
            ``SemanticSegmentationTask.load_from_checkpoint``.

    Returns:
        The object returned by ``load_from_checkpoint``.
    """
    return SemanticSegmentationTask.load_from_checkpoint(checkpoint_path)
def preprocess_image(inp):
    """Convert an input image into a one-item batch tensor.

    The image is resized to 2048x2048, converted with ``ToTensor`` and
    given a leading batch dimension of size one.

    Args:
        inp: Image accepted by the torchvision transforms used here.

    Returns:
        The transformed tensor with an extra dimension inserted at index 0.
    """
    resize = T.Resize((2048, 2048))
    to_tensor = T.ToTensor()
    batch = to_tensor(resize(inp)).unsqueeze(0)
    return batch
def predict_segmentation(model, inp):
    """Run ``model`` on ``inp`` and return per-class probabilities.

    The model is switched to evaluation mode for the duration of the call
    so that layers such as dropout and batch normalisation behave
    deterministically; its previous train/eval mode is restored afterwards.

    Args:
        model: A ``torch.nn.Module`` that maps ``inp`` to class logits with
            the class dimension third from last (as ``Softmax2d`` expects).
        inp: Input tensor for the model; a batch of size one in this app.

    Returns:
        Softmax probabilities over the class dimension, with the leading
        batch dimension removed when it has size one. No gradients are
        tracked.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # dim=-3 is the channel dimension Softmax2d normalises over.
            probabilities = torch.softmax(model(inp), dim=-3)
    finally:
        model.train(was_training)
    # Drop only the batch dimension: a bare squeeze() would also remove a
    # size-one class or spatial dimension and break later indexing.
    return probabilities.squeeze(0)
def overlay_prediction(input_image, prediction_tensor, alpha=0.5, threshold=0.25):
    """Blend a colour-mapped prediction channel over the input image.

    Channel 0 of ``prediction_tensor`` is rendered with the INFERNO
    colormap and composited onto ``input_image`` wherever its value
    exceeds ``threshold``.

    Args:
        input_image: PIL image the overlay is drawn on; its size defines
            the output size.
        prediction_tensor: Tensor whose first entry along dim 0 is the 2-D
            map to visualise.
            NOTE(review): which class index 0 corresponds to is not
            visible here - confirm it is the class of interest.
        alpha: Opacity (0-1) of the overlay where it is drawn.
        threshold: Minimum score (0-1) a pixel needs for the overlay to be
            drawn on it.

    Returns:
        An RGB PIL image of the same size as ``input_image``.
    """
    # Render channel 0 as an 8-bit grayscale image at the input resolution.
    score_image = T.ToPILImage()(prediction_tensor[0])
    score_image = score_image.resize(input_image.size, resample=Image.NEAREST)
    scores = np.array(score_image)

    # applyColorMap returns BGR; PIL expects RGB, so swap the channels.
    heatmap_bgr = cv2.applyColorMap(scores, cv2.COLORMAP_INFERNO)
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)

    # Threshold on the prediction score itself. The alpha channel of a
    # freshly converted RGBA image is always 255 and cannot be used for this.
    visible = scores / 255 > threshold
    opacity = np.where(visible, int(255 * alpha), 0).astype(np.uint8)

    overlay = Image.fromarray(np.dstack([heatmap_rgb, opacity]))
    combined_image = Image.alpha_composite(input_image.convert("RGBA"), overlay)
    return combined_image.convert("RGB")
@functools.lru_cache(maxsize=1)
def _get_model(checkpoint_path, device):
    """Load the checkpoint once, move it to ``device`` and set eval mode."""
    model = load_model(checkpoint_path)
    model.to(device)
    # The model is only used for inference here.
    model.eval()
    return model


def predict(inp):
    """Segment ``inp`` and return the image with the prediction overlaid.

    The model is loaded from ``./unet_resnet50.ckpt`` on first use and
    reused for later calls instead of being re-read from disk on every
    request.

    Args:
        inp: Input image; passed to ``preprocess_image`` and used as the
            background for ``overlay_prediction``.

    Returns:
        The image returned by ``overlay_prediction``.
    """
    # Use the GPU when one is available, otherwise stay on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _get_model("./unet_resnet50.ckpt", device)

    batch = preprocess_image(inp).to(device)
    # Post-processing runs on the CPU.
    segmentation_result = predict_segmentation(model, batch).cpu()
    return overlay_prediction(inp, segmentation_result)
# ``gr.inputs.Image`` is the legacy spelling and is missing from newer
# Gradio releases; use the top-level component when this version has it.
if hasattr(gr, "Image"):
    image_input = gr.Image(type="pil")
else:
    image_input = gr.inputs.Image(type="pil")

# Web UI: one image in, the overlaid segmentation image out.
demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs="image",
    examples=["./example1.jpg", "./example2.jpg", "./example3.jpg"],
)
demo.launch()