"""Gradio demo that overlays a torchgeo semantic-segmentation prediction on an image."""

import functools

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchgeo.trainers import SemanticSegmentationTask

# Checkpoint used by the demo; resolved relative to the current working directory.
CHECKPOINT_PATH = "./unet_resnet50.ckpt"

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model(checkpoint_path):
    """Load a ``SemanticSegmentationTask`` from *checkpoint_path*, ready for inference.

    The module is switched to eval mode so that batch-norm layers use their
    running statistics and dropout is disabled. In train mode a single input
    image would be normalised with its own batch statistics.
    """
    model = SemanticSegmentationTask.load_from_checkpoint(checkpoint_path)
    model.eval()
    return model


def preprocess_image(inp, size=(2048, 2048)):
    """Resize a PIL image to *size* and return it as a ``(1, C, H, W)`` float tensor.

    ``T.ToTensor`` scales pixel values into ``[0, 1]``. ``unsqueeze(0)`` adds the
    batch dimension the model is called with.
    """
    compose = T.Compose([T.Resize(size), T.ToTensor()])
    return compose(inp).unsqueeze(0)


def predict_segmentation(model, inp):
    """Run *model* on the batch *inp* and return per-class probabilities.

    Softmax is taken over the channel dimension of the model output. Only the
    batch dimension is removed, so a ``(1, C, H, W)`` output becomes
    ``(C, H, W)``. The assumption that the model output is
    ``(1, num_classes, H, W)`` should be confirmed against the task config.
    """
    with torch.no_grad():
        probabilities = torch.nn.Softmax2d()(model(inp))
    # squeeze(0) rather than squeeze(): a bare squeeze would also drop a
    # size-1 class/spatial dimension and break the [0] indexing downstream.
    return probabilities.squeeze(0)


def overlay_prediction(input_image, prediction_tensor, alpha=0.5, threshold=0.25):
    """Blend a colour-mapped probability map for class channel 0 onto *input_image*.

    Args:
        input_image: PIL image the prediction is drawn on; the result has its size.
        prediction_tensor: ``(num_classes, H, W)`` probabilities in ``[0, 1]``.
        alpha: Opacity of the overlay where it is drawn (0 = invisible, 1 = opaque).
        threshold: Pixels whose channel-0 probability is at or below this value
            keep the original image unchanged.

    Returns:
        An RGB PIL image the same size as *input_image*.
    """
    # NOTE(review): channel 0 is visualised — confirm it is the class of
    # interest rather than background.
    prediction_image = T.ToPILImage()(prediction_tensor[0].detach().cpu())
    prediction_image = prediction_image.resize(input_image.size, resample=Image.NEAREST)
    # (H, W) uint8 array holding probability * 255.
    probs_u8 = np.array(prediction_image)

    # applyColorMap produces BGR; convert so PIL does not swap red and blue.
    colored = cv2.applyColorMap(probs_u8, cv2.COLORMAP_INFERNO)
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

    # Threshold the probability value itself. The alpha channel of an image
    # converted from RGB to RGBA is always 255, so it cannot serve as a mask.
    mask = probs_u8 / 255.0 > threshold

    # Vectorised overlay: transparent everywhere except masked pixels.
    overlay = np.zeros((*probs_u8.shape, 4), dtype=np.uint8)
    overlay[mask, :3] = colored[mask]
    overlay[mask, 3] = int(255 * alpha)

    combined_image = Image.alpha_composite(
        input_image.convert("RGBA"), Image.fromarray(overlay)
    )
    return combined_image.convert("RGB")


@functools.lru_cache(maxsize=1)
def _get_model():
    """Load the demo checkpoint once, move it to ``DEVICE``, and reuse it across requests."""
    model = load_model(CHECKPOINT_PATH)
    model.to(DEVICE)
    return model


def predict(inp):
    """Gradio callback: segment the PIL image *inp* and return it with the prediction overlaid."""
    model = _get_model()
    batch = preprocess_image(inp).to(DEVICE)
    # PIL/OpenCV post-processing needs a CPU tensor.
    probabilities = predict_segmentation(model, batch).cpu()
    return overlay_prediction(inp, probabilities)


# gr.Image replaces gr.inputs.Image, which was deprecated in Gradio 3 and removed in Gradio 4.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="image",
    examples=["./example1.jpg", "./example2.jpg", "./example3.jpg"],
)

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch()