"""Gradio demo that runs several cloud-segmentation models on uploaded images."""

import os
from glob import glob
from typing import List, Optional

import numpy as np
from PIL import Image
# NOTE(review): EncoderDecoder is not referenced in this file; the import is
# kept in case it is needed for its side effects - confirm before removing.
from mmseg.models.segmentors.encoder_decoder import EncoderDecoder
import gradio as gr
import torch

from models.cdnetv1 import CDnetV1
from models.cdnetv2 import CDnetV2
from models.dbnet import DBNet
from models.hrcloudnet import HRCloudNet
from models.kappamask import KappaMask
from models.mcdnet import MCDNet
from models.scnn import SCNN
from models.unetmobv2 import UNetMobV2


class CloudAdapterGradio:
    """One demo panel: input image, model selector, run button and output mask.

    Instantiating the class builds every model listed in ``other_models``,
    loads their weights from ``other_model_weight_path`` and then creates the
    Gradio widgets. The script entry point creates one instance per tab inside
    a ``gr.Blocks`` context.
    """

    def __init__(self, device="cpu", example_inputs=None, num_classes=2,
                 palette=None, other_model_weight_path=None):
        """Build the models, load their checkpoints and create the widgets.

        Args:
            device: Torch device string the models and inputs are moved to.
            example_inputs: List of example image paths shown under the input;
                no examples widget is created when it is empty or ``None``.
            num_classes: Number of segmentation classes. It also selects the
                working resolution (256 for two classes, 512 otherwise) and
                the legend text.
            palette: Flat list of RGB values applied to the predicted mask.
                When ``None`` the mask keeps the default palette.
            other_model_weight_path: Directory holding ``<model_name>.bin``
                checkpoint files.
        """
        self.device = device
        self.example_inputs = example_inputs
        self.img_size = 256 if num_classes == 2 else 512
        self.palette = palette
        self.legend = self.html_legend(num_classes=num_classes)
        self.other_models = {
            "cdnetv1": CDnetV1(num_classes=num_classes).to(self.device),
            "cdnetv2": CDnetV2(num_classes=num_classes).to(self.device),
            "hrcloudnet": HRCloudNet(num_classes=num_classes).to(self.device),
            "mcdnet": MCDNet(in_channels=3, num_classes=num_classes).to(self.device),
            "scnn": SCNN(num_classes=num_classes).to(self.device),
            "dbnet": DBNet(
                img_size=self.img_size, in_channels=3, num_classes=num_classes
            ).to(self.device),
            "unetmobv2": UNetMobV2(num_classes=num_classes).to(self.device),
            "kappamask": KappaMask(num_classes=num_classes, in_channels=3).to(self.device),
        }
        # Maps the names shown in the dropdown to keys of ``other_models``.
        # NOTE(review): "cloud-adapter" has no entry in ``other_models``; it is
        # not offered in the dropdown, so it is unreachable from the UI.
        self.name_mapping = {
            "KappaMask": "kappamask",
            "CDNetv1": "cdnetv1",
            "CDNetv2": "cdnetv2",
            "HRCloudNet": "hrcloudnet",
            "MCDNet": "mcdnet",
            "SCNN": "scnn",
            "DBNet": "dbnet",
            "UNetMobv2": "unetmobv2",
            "Cloud-Adapter": "cloud-adapter",
        }
        self.load_weights(other_model_weight_path)
        self.create_ui()

    def load_weights(self, checkpoint_path: str):
        """Load ``<checkpoint_path>/<model_name>.bin`` into every model.

        Each model is switched to eval mode after its state dict is loaded.

        Raises:
            ValueError: If ``checkpoint_path`` is ``None``.
            FileNotFoundError: If no checkpoint file matches a model name.
        """
        if checkpoint_path is None:
            raise ValueError(
                "other_model_weight_path must point to a checkpoint directory"
            )
        for model_name, model in self.other_models.items():
            pattern = os.path.join(checkpoint_path, model_name + ".bin")
            # glob is kept so that a wildcard in checkpoint_path still works.
            matches = glob(pattern)
            if not matches:
                raise FileNotFoundError(
                    f"No checkpoint found for {model_name} at {pattern}"
                )
            weight_path = matches[0]
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            weight = torch.load(weight_path, map_location=self.device)
            model.load_state_dict(weight)
            model.eval()
            print(f"Loaded {model_name} weights from {weight_path}")

    def html_legend(self, num_classes=2):
        """Return the legend text for the two-class or the multi-class case."""
        # NOTE(review): the legend strings below are plain text lines without
        # any markup - confirm this is what should be rendered by gr.HTML.
        if num_classes == 2:
            return """
Clear
Cloud
"""
        return """
Clear Sky
Thick Cloud
Thin Cloud
Cloud Shadow
"""

    def create_ui(self):
        """Create the widgets of the panel and wire the run button."""
        with gr.Row():
            # Left side: input image, model selector and run button.
            with gr.Column(scale=1):
                in_image = gr.Image(
                    label='Input Image',
                    sources='upload',
                    elem_classes='input_image',
                    interactive=True,
                    type="pil",
                )
                with gr.Row():
                    # Drop-down menu for choosing the model.
                    model_choice = gr.Dropdown(
                        choices=[
                            "DBNet",
                            "HRCloudNet",
                            "CDNetv2",
                            "UNetMobv2",
                            "CDNetv1",
                            "MCDNet",
                            "KappaMask",
                            "SCNN",
                        ],
                        value="DBNet",
                        label="Model",
                        elem_classes='model_type',
                    )
                    run_button = gr.Button(
                        'Run',
                        variant="primary",
                    )
                # List of example inputs; skipped when there is nothing to show.
                if self.example_inputs:
                    gr.Examples(
                        examples=self.example_inputs,
                        inputs=in_image,
                        label="Example Inputs"
                    )

            # Right side: output image and legend.
            with gr.Column(scale=1):
                with gr.Column():
                    # Output image.
                    out_image = gr.Image(
                        label='Output Image',
                        elem_classes='output_image',
                        interactive=False
                    )
                    # Legend.
                    gr.HTML(
                        value=self.legend,
                        elem_classes="output_legend",
                    )

        # Button click logic: run the segmentation on the input image.
        run_button.click(
            self.inference,
            inputs=[in_image, model_choice],
            outputs=out_image,
        )

    @torch.no_grad()
    def inference(self, image: Image.Image, model_choice: str) -> Image.Image:
        """Run the model selected in the dropdown on ``image``.

        ``model_choice`` is the display name; it is translated to the internal
        model key through ``name_mapping`` (``KeyError`` for unknown names).
        """
        return self.simple_model_forward(image, self.name_mapping[model_choice])

    @torch.no_grad()
    def simple_model_forward(self, image: Image.Image, model_choice: str) -> Image.Image:
        """Segment ``image`` with the model stored under ``model_choice``.

        The image is resized to ``img_size`` x ``img_size``, min-max scaled to
        [0, 1], passed through the model, and the per-pixel argmax is returned
        as a palette image resized back to the original size.

        Raises:
            gr.Error: If no image was provided.
            KeyError: If ``model_choice`` is not a key of ``other_models``.
        """
        if image is None:
            raise gr.Error("Please upload an input image before running the model.")

        ori_size = image.size
        # Several models are built with in_channels=3, and the permute below
        # needs an HxWxC array, so grayscale / RGBA uploads are forced to RGB.
        rgb = image.convert("RGB").resize(
            (self.img_size, self.img_size), resample=Image.Resampling.BILINEAR
        )
        pixels = self._min_max_scale(np.asarray(rgb, dtype=np.float64))

        batch = torch.from_numpy(pixels).unsqueeze(0).to(self.device)
        batch = batch.permute(0, 3, 1, 2).float()  # NHWC -> NCHW

        logits: torch.Tensor = self.other_models[model_choice](batch)
        pred_mask = (
            torch.argmax(logits, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
        )

        # Free the tensors before asking CUDA to release its cached memory.
        del batch
        del logits
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        im = Image.fromarray(pred_mask).convert("P")
        if self.palette is not None:
            im.putpalette(self.palette)
        # Class indices must not be interpolated; Pillow uses NEAREST for "P"
        # images regardless of the requested filter, so request it explicitly.
        return im.resize(ori_size, resample=Image.Resampling.NEAREST)

    @staticmethod
    def _min_max_scale(pixels: np.ndarray) -> np.ndarray:
        """Scale ``pixels`` to [0, 1]; a constant image maps to all zeros."""
        low = pixels.min()
        span = pixels.max() - low
        if span == 0:
            # Avoid the division by zero that would fill the input with NaN.
            return np.zeros_like(pixels)
        return (pixels - low) / span


def get_palette(dataset_name: str) -> List[int]:
    """Return the flat RGB palette used to colour masks of ``dataset_name``.

    Raises:
        ValueError: If the dataset name is not one of the supported ones.
    """
    if dataset_name in ["cloudsen12_high_l1c", "cloudsen12_high_l2a"]:
        return [79, 253, 199, 77, 2, 115, 251, 255, 41, 221, 53, 223]
    if dataset_name == "l8_biome":
        return [79, 253, 199, 221, 53, 223, 251, 255, 41, 77, 2, 115]
    if dataset_name in ["gf12ms_whu_gf1", "gf12ms_whu_gf2", "hrc_whu"]:
        return [79, 253, 199, 77, 2, 115]
    raise ValueError("dataset_name not supported")


if __name__ == '__main__':
    title = 'Cloud Segmentation for Remote Sensing Images'
    custom_css = (
        " h1 { text-align: center; font-size: 24px; font-weight: bold;"
        " margin-bottom: 20px; } "
    )
    # (tab label, folder under example_inputs/, number of classes, dataset name)
    # The dataset name selects both the palette and the checkpoint directory.
    tab_specs = [
        ('Google Earth', 'hrc_whu', 2, 'hrc_whu'),
        ('Gaofen-1', 'gf1', 2, 'gf12ms_whu_gf1'),
        ('Gaofen-2', 'gf2', 2, 'gf12ms_whu_gf2'),
        ('Sentinel-2 (L1C)', 'l1c', 4, 'cloudsen12_high_l1c'),
        ('Sentinel-2 (L2A)', 'l2a', 4, 'cloudsen12_high_l2a'),
        ('Landsat-8', 'l8', 4, 'l8_biome'),
    ]
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    with gr.Blocks(analytics_enabled=False, title=title, css=custom_css) as demo:
        gr.Markdown(f'# {title}')
        with gr.Tabs():
            for tab_label, example_dir, class_count, dataset_name in tab_specs:
                with gr.TabItem(tab_label):
                    CloudAdapterGradio(
                        device=device,
                        example_inputs=glob(f"example_inputs/{example_dir}/*"),
                        num_classes=class_count,
                        palette=get_palette(dataset_name),
                        other_model_weight_path=f"checkpoints/{dataset_name}",
                    )

    demo.launch(share=True, debug=True)