"""Gradio demo: interactive cloud segmentation for remote sensing images."""

import os
from glob import glob
from typing import List

import gradio as gr
import numpy as np
import torch
from PIL import Image

from models.cdnetv1 import CDnetV1
from models.cdnetv2 import CDnetV2
from models.dbnet import DBNet
from models.hrcloudnet import HRCloudNet
from models.kappamask import KappaMask
from models.mcdnet import MCDNet
from models.scnn import SCNN
from models.unetmobv2 import UNetMobV2


class CloudAdapterGradio:

    def __init__(self, device="cpu", example_inputs=None, num_classes=2,
                 palette=None, other_model_weight_path=None):
        self.device = device
        self.example_inputs = example_inputs
        # 2-class demos run at 256x256; 4-class demos at 512x512.
        self.img_size = 256 if num_classes == 2 else 512
        self.palette = palette
        self.legend = self.html_legend(num_classes=num_classes)

        self.other_models = {
            "cdnetv1": CDnetV1(num_classes=num_classes).to(self.device),
            "cdnetv2": CDnetV2(num_classes=num_classes).to(self.device),
            "hrcloudnet": HRCloudNet(num_classes=num_classes).to(self.device),
            "mcdnet": MCDNet(in_channels=3, num_classes=num_classes).to(self.device),
            "scnn": SCNN(num_classes=num_classes).to(self.device),
            "dbnet": DBNet(img_size=self.img_size, in_channels=3,
                           num_classes=num_classes).to(self.device),
            "unetmobv2": UNetMobV2(num_classes=num_classes).to(self.device),
            "kappamask": KappaMask(num_classes=num_classes, in_channels=3).to(self.device),
        }
        # Dropdown display names -> keys of ``other_models``.
        self.name_mapping = {
            "KappaMask": "kappamask",
            "CDNetv1": "cdnetv1",
            "CDNetv2": "cdnetv2",
            "HRCloudNet": "hrcloudnet",
            "MCDNet": "mcdnet",
            "SCNN": "scnn",
            "DBNet": "dbnet",
            "UNetMobv2": "unetmobv2",
            # Has no entry in ``other_models`` and is not offered in the dropdown.
            "Cloud-Adapter": "cloud-adapter",
        }

        self.load_weights(other_model_weight_path)
        self.create_ui()
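
    # Extending the model zoo (sketch; nothing here is wired up automatically):
    # register the network in ``other_models`` under a lowercase key, map a
    # display name to that key in ``name_mapping``, add the display name to the
    # dropdown choices in ``create_ui``, and provide a matching ``<key>.bin``
    # file in every checkpoint directory.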

    def load_weights(self, checkpoint_path: str) -> None:
        """Load ``<model_name>.bin`` weights for every model and switch to eval mode."""
        for model_name, model in self.other_models.items():
            weight_path = os.path.join(checkpoint_path, model_name + ".bin")
            weight = torch.load(weight_path, map_location=self.device)
            model.load_state_dict(weight)
            model.eval()
            print(f"Loaded {model_name} weights from {weight_path}")
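
    # Checkpoint layout assumed by ``load_weights`` (illustrative; the actual
    # directories are the ones passed in from ``__main__``):
    #
    #   checkpoints/
    #       hrc_whu/
    #           cdnetv1.bin
    #           cdnetv2.bin
    #           ...
    #       gf12ms_whu_gf1/
    #           ...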

    def html_legend(self, num_classes=2):
        """Return an HTML legend whose swatch colors match the mask palette."""
        if num_classes == 2:
            return """
            <div style="margin-top: 10px; text-align: left; display: flex; align-items: center; gap: 20px; justify-content: center;">
                <div style="display: flex; align-items: center;">
                    <div style="width: 20px; height: 20px; background-color: rgb(79, 253, 199); margin-right: 10px;"></div>
                    <span>Clear</span>
                </div>
                <div style="display: flex; align-items: center;">
                    <div style="width: 20px; height: 20px; background-color: rgb(77, 2, 115); margin-right: 10px;"></div>
                    <span>Cloud</span>
                </div>
            </div>
            """
        return """
        <div style="margin-top: 10px; text-align: left; display: flex; align-items: center; gap: 20px; justify-content: center;">
            <div style="display: flex; align-items: center;">
                <div style="width: 20px; height: 20px; background-color: rgb(79, 253, 199); margin-right: 10px;"></div>
                <span>Clear Sky</span>
            </div>
            <div style="display: flex; align-items: center;">
                <div style="width: 20px; height: 20px; background-color: rgb(77, 2, 115); margin-right: 10px;"></div>
                <span>Thick Cloud</span>
            </div>
            <div style="display: flex; align-items: center;">
                <div style="width: 20px; height: 20px; background-color: rgb(251, 255, 41); margin-right: 10px;"></div>
                <span>Thin Cloud</span>
            </div>
            <div style="display: flex; align-items: center;">
                <div style="width: 20px; height: 20px; background-color: rgb(221, 53, 223); margin-right: 10px;"></div>
                <span>Cloud Shadow</span>
            </div>
        </div>
        """

    def create_ui(self):
        with gr.Row():
            with gr.Column(scale=1):
                in_image = gr.Image(
                    label='Input Image',
                    sources=['upload'],
                    elem_classes='input_image',
                    interactive=True,
                    type="pil",
                )
                with gr.Row():
                    model_choice = gr.Dropdown(
                        choices=[
                            "DBNet",
                            "HRCloudNet",
                            "CDNetv2",
                            "UNetMobv2",
                            "CDNetv1",
                            "MCDNet",
                            "KappaMask",
                            "SCNN",
                        ],
                        value="DBNet",
                        label="Model",
                        elem_classes='model_type',
                    )
                    run_button = gr.Button(
                        'Run',
                        variant="primary",
                    )
                gr.Examples(
                    examples=self.example_inputs,
                    inputs=in_image,
                    label="Example Inputs",
                )

            with gr.Column(scale=1):
                out_image = gr.Image(
                    label='Output Image',
                    elem_classes='output_image',
                    interactive=False,
                )
                gr.HTML(
                    value=self.legend,
                    elem_classes="output_legend",
                )

        run_button.click(
            self.inference,
            inputs=[in_image, model_choice],
            outputs=out_image,
        )

    @torch.no_grad()
    def inference(self, image: Image.Image, model_choice: str) -> Image.Image:
        return self.simple_model_forward(image, self.name_mapping[model_choice])

    @torch.no_grad()
    def simple_model_forward(self, image: Image.Image, model_choice: str) -> Image.Image:
        """Resize, normalize, and segment ``image``, returning a colorized mask."""
        ori_size = image.size
        image = image.resize((self.img_size, self.img_size),
                             resample=Image.Resampling.BILINEAR)
        image = np.array(image)
        # Min-max normalize to [0, 1]; guard against a constant image.
        lo, hi = np.min(image), np.max(image)
        image = (image - lo) / max(hi - lo, 1e-8)

        image = torch.from_numpy(image).unsqueeze(0).to(self.device)
        image = image.permute(0, 3, 1, 2).float()

        logits: torch.Tensor = self.other_models[model_choice](image)
        pred_mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)

        del image
        del logits
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        im = Image.fromarray(pred_mask).convert("P")
        im.putpalette(self.palette)
        # NEAREST keeps palette indices valid (Pillow forces it for "P" images anyway).
        return im.resize(ori_size, resample=Image.Resampling.NEAREST)
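
# Programmatic use (sketch; "demo.png" is a hypothetical input): the constructor
# builds Gradio widgets, so it must run inside a ``gr.Blocks`` context even if
# the UI is never launched:
#
#   with gr.Blocks():
#       app = CloudAdapterGradio(device="cpu", num_classes=2,
#                                palette=get_palette("hrc_whu"),
#                                other_model_weight_path="checkpoints/hrc_whu")
#   mask = app.simple_model_forward(Image.open("demo.png").convert("RGB"), "dbnet")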

def get_palette(dataset_name: str) -> List[int]:
    """Return a flat RGB palette (three ints per class) for ``dataset_name``."""
    if dataset_name in ["cloudsen12_high_l1c", "cloudsen12_high_l2a"]:
        # clear, thick cloud, thin cloud, cloud shadow
        return [79, 253, 199, 77, 2, 115, 251, 255, 41, 221, 53, 223]
    if dataset_name == "l8_biome":
        # clear, cloud shadow, thin cloud, thick cloud
        return [79, 253, 199, 221, 53, 223, 251, 255, 41, 77, 2, 115]
    if dataset_name in ["gf12ms_whu_gf1", "gf12ms_whu_gf2", "hrc_whu"]:
        # clear, cloud
        return [79, 253, 199, 77, 2, 115]
    raise ValueError(f"Unsupported dataset: {dataset_name}")
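
# Example: get_palette("hrc_whu") returns [79, 253, 199, 77, 2, 115], i.e.
# class 0 ("clear") -> RGB(79, 253, 199), class 1 ("cloud") -> RGB(77, 2, 115),
# matching the swatch colors used in ``html_legend``.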

if __name__ == '__main__':
    title = 'Cloud Segmentation for Remote Sensing Images'
    custom_css = """
    h1 {
        text-align: center;
        font-size: 24px;
        font-weight: bold;
        margin-bottom: 20px;
    }
    """
    hrc_whu_examples = glob("example_inputs/hrc_whu/*")
    gf1_examples = glob("example_inputs/gf1/*")
    gf2_examples = glob("example_inputs/gf2/*")
    l1c_examples = glob("example_inputs/l1c/*")
    l2a_examples = glob("example_inputs/l2a/*")
    l8_examples = glob("example_inputs/l8/*")

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    with gr.Blocks(analytics_enabled=False, title=title, css=custom_css) as demo:
        gr.Markdown(f'# {title}')
        with gr.Tabs():
            with gr.TabItem('Google Earth'):
                CloudAdapterGradio(
                    device=device,
                    example_inputs=hrc_whu_examples,
                    num_classes=2,
                    palette=get_palette("hrc_whu"),
                    other_model_weight_path="checkpoints/hrc_whu",
                )
            with gr.TabItem('Gaofen-1'):
                CloudAdapterGradio(
                    device=device,
                    example_inputs=gf1_examples,
                    num_classes=2,
                    palette=get_palette("gf12ms_whu_gf1"),
                    other_model_weight_path="checkpoints/gf12ms_whu_gf1",
                )
            with gr.TabItem('Gaofen-2'):
                CloudAdapterGradio(
                    device=device,
                    example_inputs=gf2_examples,
                    num_classes=2,
                    palette=get_palette("gf12ms_whu_gf2"),
                    other_model_weight_path="checkpoints/gf12ms_whu_gf2",
                )
            with gr.TabItem('Sentinel-2 (L1C)'):
                CloudAdapterGradio(
                    device=device,
                    example_inputs=l1c_examples,
                    num_classes=4,
                    palette=get_palette("cloudsen12_high_l1c"),
                    other_model_weight_path="checkpoints/cloudsen12_high_l1c",
                )
            with gr.TabItem('Sentinel-2 (L2A)'):
                CloudAdapterGradio(
                    device=device,
                    example_inputs=l2a_examples,
                    num_classes=4,
                    palette=get_palette("cloudsen12_high_l2a"),
                    other_model_weight_path="checkpoints/cloudsen12_high_l2a",
                )
            with gr.TabItem('Landsat-8'):
                CloudAdapterGradio(
                    device=device,
                    example_inputs=l8_examples,
                    num_classes=4,
                    palette=get_palette("l8_biome"),
                    other_model_weight_path="checkpoints/l8_biome",
                )

    demo.launch(share=True, debug=True)