# cloudseg-models / app.py
# Uploaded by XavierJiezou: "Add files using upload-large-folder tool" (commit 63a9590, verified)
from typing import List
from glob import glob
import numpy as np
from PIL import Image
from mmseg.models.segmentors.encoder_decoder import EncoderDecoder
import gradio as gr
import torch
import os
from models.cdnetv1 import CDnetV1
from models.cdnetv2 import CDnetV2
from models.dbnet import DBNet
from models.hrcloudnet import HRCloudNet
from models.kappamask import KappaMask
from models.mcdnet import MCDNet
from models.scnn import SCNN
from models.unetmobv2 import UNetMobV2
class CloudAdapterGradio:
    """One Gradio tab that runs a set of cloud-segmentation models on an image.

    On construction it instantiates every network in ``other_models`` and loads
    its weights from ``other_model_weight_path``. It then renders the
    input/output widgets inside whatever Gradio ``Blocks``/``TabItem`` context
    is currently open.

    Args:
        device: Torch device string the models run on (e.g. ``"cpu"``, ``"cuda:0"``).
        example_inputs: Paths of example images listed under the input widget.
        num_classes: Number of segmentation classes. ``2`` selects the
            clear/cloud legend and a 256x256 model input; any other value
            selects the four-class legend and a 512x512 input.
        palette: Flat ``[r, g, b, r, g, b, ...]`` list applied with
            ``putpalette`` to the predicted index mask.
        other_model_weight_path: Directory (glob wildcards allowed) containing
            one ``<model_key>.bin`` state dict per entry of ``other_models``.
    """

    def __init__(self, device="cpu", example_inputs=None, num_classes=2, palette=None, other_model_weight_path=None):
        self.device = device
        self.example_inputs = example_inputs
        # NOTE(review): presumably the resolution the checkpoints were trained at — confirm.
        self.img_size = 256 if num_classes == 2 else 512
        self.palette = palette
        self.legend = self.html_legend(num_classes=num_classes)
        # Internal model key -> network already moved to the target device.
        self.other_models = {
            "cdnetv1": CDnetV1(num_classes=num_classes).to(self.device),
            "cdnetv2": CDnetV2(num_classes=num_classes).to(self.device),
            "hrcloudnet": HRCloudNet(num_classes=num_classes).to(self.device),
            "mcdnet": MCDNet(in_channels=3, num_classes=num_classes).to(self.device),
            "scnn": SCNN(num_classes=num_classes).to(self.device),
            "dbnet": DBNet(img_size=self.img_size, in_channels=3, num_classes=num_classes).to(
                self.device
            ),
            "unetmobv2": UNetMobV2(num_classes=num_classes).to(self.device),
            "kappamask": KappaMask(num_classes=num_classes, in_channels=3).to(self.device),
        }
        # Dropdown display name -> key in self.other_models.
        # "Cloud-Adapter" has no entry in other_models; inference() rejects it explicitly.
        self.name_mapping = {
            "KappaMask": "kappamask",
            "CDNetv1": "cdnetv1",
            "CDNetv2": "cdnetv2",
            "HRCloudNet": "hrcloudnet",
            "MCDNet": "mcdnet",
            "SCNN": "scnn",
            "DBNet": "dbnet",
            "UNetMobv2": "unetmobv2",
            "Cloud-Adapter": "cloud-adapter",
        }
        self.load_weights(other_model_weight_path)
        self.create_ui()

    def load_weights(self, checkpoint_path: str):
        """Load ``<checkpoint_path>/<model_key>.bin`` into every model and switch it to eval mode.

        ``checkpoint_path`` may contain glob wildcards. When several files
        match, the lexicographically first one is used so the choice is
        deterministic.

        Raises:
            ValueError: if ``checkpoint_path`` is ``None``.
            FileNotFoundError: if no checkpoint file matches for some model.
        """
        if checkpoint_path is None:
            raise ValueError("other_model_weight_path must point to the checkpoint directory")
        for model_name, model in self.other_models.items():
            pattern = os.path.join(checkpoint_path, model_name + ".bin")
            matches = sorted(glob(pattern))
            if not matches:
                raise FileNotFoundError(
                    f"No checkpoint found for model {model_name!r} (looked for {pattern!r})"
                )
            weight_path = matches[0]
            # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
            weight = torch.load(weight_path, map_location=self.device)
            model.load_state_dict(weight)
            model.eval()
            print(f"Loaded {model_name} weights from {weight_path}")

    def html_legend(self, num_classes=2):
        """Return the HTML colour legend: clear/cloud for 2 classes, else the 4-class legend.

        The colours match the entries produced by ``get_palette`` for the
        corresponding datasets.
        """
        if num_classes == 2:
            return """
            <div style="margin-top: 10px; text-align: left; display: flex; align-items: center; gap: 20px;justify-content: center;">
                <div style="display: flex; align-items: center;">
                    <div style="width: 20px; height: 20px; background-color: rgb(79, 253, 199); margin-right: 10px; "></div>
                    <span>Clear</span>
                </div>
                <div style="display: flex; align-items: center;">
                    <div style="width: 20px; height: 20px; background-color: rgb(77, 2, 115); margin-right: 10px; "></div>
                    <span>Cloud</span>
                </div>
            </div>
            """
        return """
        <div style="margin-top: 10px; text-align: left; display: flex; align-items: center; gap: 20px;justify-content: center;">
            <div style="display: flex; align-items: center;">
                <div style="width: 20px; height: 20px; background-color: rgb(79, 253, 199); margin-right: 10px; "></div>
                <span>Clear Sky</span>
            </div>
            <div style="display: flex; align-items: center;">
                <div style="width: 20px; height: 20px; background-color: rgb(77, 2, 115); margin-right: 10px; "></div>
                <span>Thick Cloud</span>
            </div>
            <div style="display: flex; align-items: center;">
                <div style="width: 20px; height: 20px; background-color: rgb(251, 255, 41); margin-right: 10px; "></div>
                <span>Thin Cloud</span>
            </div>
            <div style="display: flex; align-items: center;">
                <div style="width: 20px; height: 20px; background-color: rgb(221, 53, 223); margin-right: 10px; "></div>
                <span>Cloud Shadow</span>
            </div>
        </div>
        """

    def create_ui(self):
        """Build the input column and the output column, and wire Run to ``inference``.

        The input column holds the image, model dropdown, Run button and
        examples. The output column holds the mask and the legend.
        """
        with gr.Row():
            # Left column: input image and controls.
            with gr.Column(scale=1):
                in_image = gr.Image(
                    label='Input Image',
                    sources='upload',
                    elem_classes='input_image',
                    interactive=True,
                    type="pil",
                )
                with gr.Row():
                    # Model selector; values are keys of self.name_mapping.
                    model_choice = gr.Dropdown(
                        choices=[
                            "DBNet",
                            "HRCloudNet",
                            "CDNetv2",
                            "UNetMobv2",
                            "CDNetv1",
                            "MCDNet",
                            "KappaMask",
                            "SCNN",
                        ],
                        value="DBNet",
                        label="Model",
                        elem_classes='model_type',
                    )
                    run_button = gr.Button(
                        'Run',
                        variant="primary",
                    )
                # Clickable example inputs. Skipped when there are none (the
                # default None, or an empty glob), since there is nothing to list.
                if self.example_inputs:
                    gr.Examples(
                        examples=self.example_inputs,
                        inputs=in_image,
                        label="Example Inputs",
                    )
            # Right column: predicted mask and its colour legend.
            with gr.Column(scale=1):
                with gr.Column():
                    out_image = gr.Image(
                        label='Output Image',
                        elem_classes='output_image',
                        interactive=False,
                    )
                    gr.HTML(
                        value=self.legend,
                        elem_classes="output_legend",
                    )
        # Clicking Run segments the input image with the selected model.
        run_button.click(
            self.inference,
            inputs=[in_image, model_choice],
            outputs=out_image,
        )

    @torch.no_grad()
    def inference(self, image: "Image.Image", model_choice: str) -> "Image.Image":
        """Gradio callback: validate the inputs, then run the chosen model on ``image``.

        Raises:
            gr.Error: if no image was provided or the selected model is not
                available, so the message is shown in the UI instead of a traceback.
        """
        if image is None:
            raise gr.Error("Please upload an image or pick an example first.")
        model_key = self.name_mapping.get(model_choice)
        if model_key not in self.other_models:
            raise gr.Error(f"Model {model_choice!r} is not available in this tab.")
        return self.simple_model_forward(image, model_key)

    @staticmethod
    def _normalize(array: np.ndarray) -> np.ndarray:
        """Min-max scale ``array`` to ``[0, 1]`` as float64.

        A constant array (max == min) would divide by zero and yield NaNs; it
        is mapped to all zeros instead.
        """
        array = array.astype(np.float64)
        lo = array.min()
        value_range = array.max() - lo
        if value_range == 0:
            return np.zeros_like(array)
        return (array - lo) / value_range

    @torch.no_grad()
    def simple_model_forward(self, image: "Image.Image", model_choice: str) -> "Image.Image":
        """Segment ``image`` with ``self.other_models[model_choice]``.

        The image is converted to RGB, resized to ``img_size`` squared and
        min-max normalised before the forward pass. The per-pixel argmax is
        coloured with ``self.palette`` and resized back to the input size.

        Returns:
            A palette ("P" mode) PIL image with the same size as ``image``.
        """
        ori_size = image.size
        # Force 3 channels: RGBA, grayscale or palette uploads would otherwise
        # produce arrays the 3-channel models cannot consume.
        resized = image.convert("RGB").resize(
            (self.img_size, self.img_size), resample=Image.Resampling.BILINEAR
        )
        array = self._normalize(np.asarray(resized))
        # HWC -> NCHW; cast to float32 on the CPU before moving to the device.
        tensor = torch.from_numpy(array).unsqueeze(0).permute(0, 3, 1, 2).float().to(self.device)
        logits: torch.Tensor = self.other_models[model_choice](tensor)
        pred_mask = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
        # Release device memory promptly between requests.
        del tensor
        del logits
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        im = Image.fromarray(pred_mask).convert("P")
        im.putpalette(self.palette)
        # Class indices must not be interpolated, so resize the label mask with NEAREST.
        return im.resize(ori_size, resample=Image.Resampling.NEAREST)
def get_palette(dataset_name: str) -> List[int]:
    """Return the flat RGB palette ``[r0, g0, b0, r1, g1, b1, ...]`` for a dataset.

    Entry ``i`` (three values) is the colour of predicted class index ``i``.
    The two-class datasets use the first two colours. ``l8_biome`` uses the
    same four colours as CloudSEN12 in a different index order; presumably its
    label IDs are ordered differently (confirm against the dataset definition).

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported datasets.
    """
    cloudsen12 = [79, 253, 199, 77, 2, 115, 251, 255, 41, 221, 53, 223]
    binary = [79, 253, 199, 77, 2, 115]
    palettes = {
        "cloudsen12_high_l1c": cloudsen12,
        "cloudsen12_high_l2a": cloudsen12,
        "l8_biome": [79, 253, 199, 221, 53, 223, 251, 255, 41, 77, 2, 115],
        "gf12ms_whu_gf1": binary,
        "gf12ms_whu_gf2": binary,
        "hrc_whu": binary,
    }
    try:
        # Copy so callers can mutate their palette without affecting anyone else.
        return list(palettes[dataset_name])
    except KeyError:
        raise ValueError(f"dataset_name not supported: {dataset_name!r}") from None
if __name__ == '__main__':
    title = 'Cloud Segmentation for Remote Sensing Images'
    custom_css = """
    h1 {
        text-align: center;
        font-size: 24px;
        font-weight: bold;
        margin-bottom: 20px;
    }
    """
    # One tab per dataset: (tab label, example_inputs sub-folder, number of classes,
    # dataset key). The dataset key selects both the palette and the checkpoint folder.
    tab_specs = [
        ("Google Earth", "hrc_whu", 2, "hrc_whu"),
        ("Gaofen-1", "gf1", 2, "gf12ms_whu_gf1"),
        ("Gaofen-2", "gf2", 2, "gf12ms_whu_gf2"),
        ("Sentinel-2 (L1C)", "l1c", 4, "cloudsen12_high_l1c"),
        ("Sentinel-2 (L2A)", "l2a", 4, "cloudsen12_high_l2a"),
        ("Landsat-8", "l8", 4, "l8_biome"),
    ]
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    with gr.Blocks(analytics_enabled=False, title=title, css=custom_css) as demo:
        gr.Markdown(f'# {title}')
        with gr.Tabs():
            for tab_label, example_dir, num_classes, dataset_name in tab_specs:
                with gr.TabItem(tab_label):
                    CloudAdapterGradio(
                        device=device,
                        # glob() order is filesystem-dependent; sort for a stable example list.
                        example_inputs=sorted(glob(f"example_inputs/{example_dir}/*")),
                        num_classes=num_classes,
                        palette=get_palette(dataset_name),
                        other_model_weight_path=f"checkpoints/{dataset_name}",
                    )
    demo.launch(share=True, debug=True)