# Duplicated from gainforest/tree-crown-delineation (by gainforest, commit adaa21e).
import functools

import cv2
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchgeo.trainers import SemanticSegmentationTask
def load_model(checkpoint_path):
    """Restore the trained segmentation task from a Lightning checkpoint.

    The restored module is switched to evaluation mode before it is returned,
    as the Lightning documentation recommends after ``load_from_checkpoint``,
    so layers such as BatchNorm/Dropout (if the network contains them) use
    their inference behaviour.

    Args:
        checkpoint_path: Path to the ``.ckpt`` file.

    Returns:
        The restored ``SemanticSegmentationTask`` in eval mode.
    """
    model = SemanticSegmentationTask.load_from_checkpoint(checkpoint_path)
    # torch.no_grad() alone does not change train/eval behaviour; without this,
    # BatchNorm would normalise with per-request batch statistics.
    model.eval()
    return model
def preprocess_image(inp, size=(2048, 2048)):
    """Convert a PIL image into a batched float tensor for the model.

    Args:
        inp: Input ``PIL.Image.Image``.
        size: ``(height, width)`` the image is resized to before inference.
            Defaults to ``(2048, 2048)``.

    Returns:
        A ``(1, 3, height, width)`` float tensor with values in ``[0, 1]``.
    """
    # ToTensor keeps however many bands the image has; force 3-band RGB so
    # RGBA / greyscale / palette inputs don't yield a tensor with a channel
    # count the model (presumably trained on RGB) cannot accept.
    compose = T.Compose([T.Resize(size), T.ToTensor()])
    return compose(inp.convert("RGB")).unsqueeze(0)
def predict_segmentation(model, inp):
    """Run ``model`` on a single-image batch and return class probabilities.

    Args:
        model: Callable expected to map a ``(1, C_in, H, W)`` tensor to
            ``(1, K, H, W)`` logits.
        inp: Input batch of shape ``(1, C_in, H, W)``.

    Returns:
        A ``(K, H, W)`` tensor of softmax probabilities over the class axis.
    """
    with torch.no_grad():
        logits = model(inp)
        probs = torch.nn.Softmax2d()(logits)
    # Remove only the batch axis. A bare squeeze() would also drop a
    # single-class axis or a 1-pixel-high/wide axis, breaking callers that
    # index the class axis first (e.g. ``result[0]``).
    return probs.squeeze(0)
def overlay_prediction(input_image, prediction_tensor, alpha=0.5, threshold=0.25):
    """Blend a colour-mapped probability map over the input image.

    Channel 0 of ``prediction_tensor`` is rendered with the Inferno colormap
    and composited onto ``input_image`` only where its probability exceeds
    ``threshold``; everywhere else the input image shows through unchanged.

    Args:
        input_image: The original ``PIL.Image.Image``; the overlay is resized
            to match its size.
        prediction_tensor: Presumably a ``(K, H, W)`` float tensor of
            probabilities in ``[0, 1]``; only channel 0 is drawn.
        alpha: Opacity of the overlay where it is drawn, in ``[0, 1]``.
        threshold: Minimum channel-0 probability for a pixel to be overlaid.

    Returns:
        An RGB ``PIL.Image.Image`` the same size as ``input_image``.
    """
    # ToPILImage scales a float tensor by 255 into an 8-bit greyscale image.
    prob_image = T.ToPILImage()(prediction_tensor[0])
    prob_image = prob_image.resize(input_image.size, resample=Image.NEAREST)
    prob = np.array(prob_image)  # (H, W) uint8

    # applyColorMap returns BGR; PIL interprets arrays as RGB, so swap.
    colored = cv2.applyColorMap(prob, cv2.COLORMAP_INFERNO)
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

    # Threshold the probability itself. The previous loop compared the alpha
    # channel of an RGBA conversion, which is always 255, so every pixel was
    # painted regardless of ``threshold``. Vectorised to avoid one Python-level
    # getpixel/putpixel call per pixel.
    mask = prob / 255 > threshold
    rgba = np.zeros((*prob.shape, 4), dtype=np.uint8)
    rgba[mask, :3] = colored[mask]
    rgba[mask, 3] = int(255 * alpha)
    overlay = Image.fromarray(rgba)

    combined_image = Image.alpha_composite(input_image.convert("RGBA"), overlay)
    return combined_image.convert("RGB")
# Checkpoint loaded (once) by predict().
CHECKPOINT_PATH = "./unet_resnet50.ckpt"


@functools.lru_cache(maxsize=1)
def _load_inference_model():
    """Load the checkpoint once and place it on the best available device.

    Cached so repeated requests reuse one model instead of re-reading the
    checkpoint from disk and re-copying the weights to the GPU every call.
    A failed load raises and is not cached, so it is retried next call.

    Returns:
        A ``(model, device)`` tuple.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(CHECKPOINT_PATH)
    model.to(device)
    return model, device


def predict(inp):
    """Segment ``inp`` and return it with the prediction overlaid.

    Args:
        inp: The uploaded ``PIL.Image.Image``.

    Returns:
        An RGB ``PIL.Image.Image`` showing the segmentation overlay.
    """
    model, device = _load_inference_model()
    batch = preprocess_image(inp).to(device)
    probabilities = predict_segmentation(model, batch)
    # Post-processing goes through PIL/NumPy/OpenCV, which need CPU tensors.
    return overlay_prediction(inp, probabilities.cpu())
# Build the web UI and start the server.
demo = gr.Interface(
    fn=predict,
    # gr.inputs.* was deprecated in Gradio 3 and removed in Gradio 4;
    # gr.Image is the supported component in both.
    inputs=gr.Image(type="pil"),
    outputs="image",
    examples=["./example1.jpg", "./example2.jpg", "./example3.jpg"],
)
demo.launch()