File size: 4,204 Bytes
8ae7071
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4def87
8ae7071
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242850e
8ae7071
 
242850e
cc6b299
8ae7071
 
 
 
 
 
242850e
8ae7071
 
 
008d1df
 
45a81a5
 
8ae7071
 
45a81a5
8ae7071
 
05b14d0
008d1df
 
45a81a5
008d1df
 
 
05b14d0
 
 
 
 
 
 
 
311bc12
8ae7071
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
tree-segmentation
Proof of concept showing the effectiveness of fine-tuned instance segmentation models for detecting trees and buildings.
"""
import os
import cv2
os.system("pip install 'git+https://github.com/facebookresearch/detectron2.git'")
from transformers import DetrFeatureExtractor, DetrForSegmentation
from PIL import Image
import gradio as gr
import numpy as np
import torch
import torchvision
import detectron2

# import some common detectron2 utilities
import itertools
import seaborn as sns
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.visualizer import ColorMode
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.utils.visualizer import ColorMode
from detectron2.structures import Instances


def _load_predictor(config_path, weights_path, num_classes):
    """Build a CPU-only DefaultPredictor from a YAML config and a weights file.

    Args:
        config_path: Path of the YAML config merged on top of detectron2's defaults.
        weights_path: Path of the model checkpoint to load.
        num_classes: Value written to ``MODEL.ROI_HEADS.NUM_CLASSES``.

    Returns:
        A ``(cfg, predictor)`` pair so callers keep a handle on the config node.
    """
    cfg = get_cfg()
    cfg.merge_from_file(config_path)
    cfg.MODEL.DEVICE = 'cpu'
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes
    return cfg, DefaultPredictor(cfg)


# Model for trees
tree_cfg, tree_predictor = _load_predictor(
    "tree_model_weights/tree_cfg.yml",
    "tree_model_weights/treev1_best.pth",
    2,
)

# Model for buildings
building_cfg, building_predictor = _load_predictor(
    "building_model_weight/buildings_poc_cfg.yml",
    "building_model_weight/model_final.pth",
    8,
)

# A function that runs the buildings model on an given image and confidence threshold
def segment_building(im, confidence_threshold):
    building_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold
    outputs = building_predictor(im)
    building_instances = outputs["instances"].to("cpu")

    return building_instances

# A function that runs the trees model on an given image and confidence threshold
def segment_tree(im, confidence_threshold):
    tree_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = confidence_threshold
    outputs = tree_predictor(im)
    tree_instances = outputs["instances"].to("cpu")

    return tree_instances

# Function to map strings to color mode
def map_color_mode(color_mode):
    if color_mode == "Black/white":
        return ColorMode.IMAGE_BW
    elif color_mode == "Random":
        return ColorMode.IMAGE
    elif color_mode == "Segmentation" or color_mode == None:
        return ColorMode.SEGMENTATION

def visualize_image(im, mode, tree_threshold:float, building_threshold:float, color_mode):
    im = np.array(im)
    color_mode = map_color_mode(color_mode)

    if mode == "Trees":
        instances = segment_tree(im, tree_threshold)
    elif mode == "Buildings":
        instances = segment_building(im, building_threshold)
    elif mode == "Both" or mode == None:
        tree_instances = segment_tree(im, tree_threshold)
        building_instances = segment_building(im, building_threshold)
        instances = Instances.cat([tree_instances, building_instances])

    metadata = MetadataCatalog.get("urban-trees-fdokv_train")
    print("metadata", type(metadata), metadata)
    print('metadata.get("thing_classes")', type(metadata.get("thing_classes")), metadata.get("thing_classes"))

    visualizer = Visualizer(im[:, :, ::-1],
                            metadata=metadata,
                            scale=0.5,
                            instance_mode=color_mode)

    dataset_names = MetadataCatalog.list()
    print(dataset_names)

    metadata = MetadataCatalog.get("urban-small_train")
    category_names = metadata.get("thing_classes")
    print(category_names)
    # visualizer = Visualizer(im[:, :, ::-1],
    #                         metadata=metadata,
    #                         scale=0.5,
    #                         instance_mode=color_mode)
    # # in the visualizer, add category label names to detected instances
    # for instance in instances:
    #     label = category_names[instance["category_id"]]
    #     visualizer.draw_text(label, instance["bbox"][:2])

    output_image = visualizer.draw_instance_predictions(instances)
    
    return Image.fromarray(output_image.get_image()[:, :, ::-1])