Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files- app.py +181 -0
- requirements.txt +19 -0
app.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
from glob import glob
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import matplotlib
|
8 |
+
import tensorflow as tf
|
9 |
+
from tensorflow import keras
|
10 |
+
from tensorflow.keras import backend as K
|
11 |
+
import pandas as pd
|
12 |
+
import gc
|
13 |
+
import random
|
14 |
+
import math
|
15 |
+
import glob
|
16 |
+
import torch
|
17 |
+
import gradio as gr
|
18 |
+
from PIL import Image
|
19 |
+
import cv2
|
20 |
+
|
21 |
+
|
22 |
+
classes = ['None','building','pervious surface','impervious surface','bare soil','water','coniferous','deciduous','brushwood','vineyard','herbaceous vegetation','agricultural land','plowed land']
|
23 |
+
id2label = pd.DataFrame(classes)[0].to_dict()
|
24 |
+
print(id2label)
|
25 |
+
label2id = {v: k for k, v in id2label.items()}
|
26 |
+
num_labels = len(id2label)
|
27 |
+
|
28 |
+
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor
|
29 |
+
|
30 |
+
segformer_b0_rgb_model = SegformerForSemanticSegmentation.from_pretrained("alanoix/segformer_b0_flair_one",
|
31 |
+
num_labels=len(id2label),
|
32 |
+
id2label=id2label,
|
33 |
+
label2id=label2id)
|
34 |
+
|
35 |
+
segformer_rgb_feature_extractor = SegformerFeatureExtractor(ignore_index=0, reduce_labels=False, do_resize=False, do_rescale=False, do_normalize=False)
|
36 |
+
segformer_b0_rgb_model= torch.quantization.quantize_dynamic(segformer_b0_rgb_model, {torch.nn.Linear}, dtype=torch.qint8)
|
37 |
+
|
38 |
+
|
39 |
+
import albumentations as aug
|
40 |
+
MEAN = np.array([0.44050665, 0.45704361, 0.42254708])
|
41 |
+
STD = np.array([0.20264351, 0.1782405 , 0.17575739])
|
42 |
+
|
43 |
+
test_transform = aug.Compose([
|
44 |
+
aug.Normalize(mean=MEAN, std=STD),
|
45 |
+
])
|
46 |
+
|
47 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
48 |
+
segformer_b0_rgb_model = segformer_b0_rgb_model.to(device)
|
49 |
+
|
50 |
+
class_colors = [(random.randint(0, 255), random.randint(
|
51 |
+
0, 255), random.randint(0, 255)) for _ in range(5000)]
|
52 |
+
|
53 |
+
|
54 |
+
# Default IMAGE_ORDERING = channels_last
|
55 |
+
IMAGE_ORDERING = "channels_last"
|
56 |
+
|
57 |
+
|
58 |
+
def get_colored_segmentation_image(seg_arr, n_classes, colors=class_colors):
|
59 |
+
output_height = seg_arr.shape[0]
|
60 |
+
output_width = seg_arr.shape[1]
|
61 |
+
|
62 |
+
seg_img = np.zeros((output_height, output_width, 3))
|
63 |
+
|
64 |
+
for c in range(n_classes):
|
65 |
+
seg_arr_c = seg_arr[:, :] == c
|
66 |
+
seg_img[:, :, 0] += ((seg_arr_c)*(colors[c][0])).astype('uint8')
|
67 |
+
seg_img[:, :, 1] += ((seg_arr_c)*(colors[c][1])).astype('uint8')
|
68 |
+
seg_img[:, :, 2] += ((seg_arr_c)*(colors[c][2])).astype('uint8')
|
69 |
+
|
70 |
+
return seg_img
|
71 |
+
|
72 |
+
|
73 |
+
def get_legends(class_names, colors=class_colors):
|
74 |
+
|
75 |
+
n_classes = len(class_names)
|
76 |
+
legend = np.zeros(((len(class_names) * 25) + 25, 125, 3),
|
77 |
+
dtype="uint8") + 255
|
78 |
+
|
79 |
+
class_names_colors = enumerate(zip(class_names[:n_classes],
|
80 |
+
colors[:n_classes]))
|
81 |
+
|
82 |
+
for (i, (class_name, color)) in class_names_colors:
|
83 |
+
color = [int(c) for c in color]
|
84 |
+
cv2.putText(legend, class_name, (5, (i * 25) + 17),
|
85 |
+
cv2.FONT_HERSHEY_COMPLEX, 0.5, (0, 0, 0), 1)
|
86 |
+
cv2.rectangle(legend, (100, (i * 25)), (125, (i * 25) + 25),
|
87 |
+
tuple(color), -1)
|
88 |
+
|
89 |
+
return legend
|
90 |
+
|
91 |
+
|
92 |
+
def overlay_seg_image(inp_img, seg_img):
|
93 |
+
orininal_h = inp_img.shape[0]
|
94 |
+
orininal_w = inp_img.shape[1]
|
95 |
+
seg_img = cv2.resize(seg_img, (orininal_w, orininal_h), interpolation=cv2.INTER_NEAREST)
|
96 |
+
|
97 |
+
fused_img = (inp_img/2 + seg_img/2).astype('uint8')
|
98 |
+
return fused_img
|
99 |
+
|
100 |
+
|
101 |
+
def concat_lenends(seg_img, legend_img):
|
102 |
+
|
103 |
+
new_h = np.maximum(seg_img.shape[0], legend_img.shape[0])
|
104 |
+
new_w = seg_img.shape[1] + legend_img.shape[1]
|
105 |
+
|
106 |
+
out_img = np.zeros((new_h, new_w, 3)).astype('uint8') + legend_img[0, 0, 0]
|
107 |
+
|
108 |
+
out_img[:legend_img.shape[0], : legend_img.shape[1]] = np.copy(legend_img)
|
109 |
+
out_img[:seg_img.shape[0], legend_img.shape[1]:] = np.copy(seg_img)
|
110 |
+
|
111 |
+
return out_img
|
112 |
+
|
113 |
+
|
114 |
+
def visualize_segmentation(seg_arr, inp_img=None, n_classes=None,
|
115 |
+
colors=class_colors, class_names=None,
|
116 |
+
overlay_img=False, show_legends=False,
|
117 |
+
prediction_width=None, prediction_height=None):
|
118 |
+
|
119 |
+
if n_classes is None:
|
120 |
+
n_classes = np.max(seg_arr)
|
121 |
+
|
122 |
+
seg_img = get_colored_segmentation_image(seg_arr, n_classes, colors=colors)
|
123 |
+
|
124 |
+
if inp_img is not None:
|
125 |
+
original_h = inp_img.shape[0]
|
126 |
+
original_w = inp_img.shape[1]
|
127 |
+
seg_img = cv2.resize(seg_img, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
|
128 |
+
|
129 |
+
if (prediction_height is not None) and (prediction_width is not None):
|
130 |
+
seg_img = cv2.resize(seg_img, (prediction_width, prediction_height), interpolation=cv2.INTER_NEAREST)
|
131 |
+
if inp_img is not None:
|
132 |
+
inp_img = cv2.resize(inp_img,
|
133 |
+
(prediction_width, prediction_height))
|
134 |
+
|
135 |
+
if overlay_img:
|
136 |
+
assert inp_img is not None
|
137 |
+
seg_img = overlay_seg_image(inp_img, seg_img)
|
138 |
+
|
139 |
+
if show_legends:
|
140 |
+
assert class_names is not None
|
141 |
+
legend_img = get_legends(class_names, colors=colors)
|
142 |
+
|
143 |
+
seg_img = concat_lenends(seg_img, legend_img)
|
144 |
+
|
145 |
+
return seg_img
|
146 |
+
|
147 |
+
def query_image(img):
|
148 |
+
image_to_pred = test_transform(image=img)['image']
|
149 |
+
|
150 |
+
pixel_values = segformer_rgb_feature_extractor(image_to_pred, return_tensors="pt").pixel_values.to(device)
|
151 |
+
|
152 |
+
outputs_segformer_b0_rgb = segformer_b0_rgb_model(pixel_values=pixel_values)
|
153 |
+
pred_segformer_b0_rgb = outputs_segformer_b0_rgb.logits.cpu().detach().numpy()
|
154 |
+
|
155 |
+
pred = np.mean(np.array([K.softmax(pred_segformer_b0_rgb, axis = 1)]), axis = 0)
|
156 |
+
pred = tf.image.resize(tf.transpose(pred, perm=[0,2,3,1]), size = [512,512], method="bilinear") # resize to 512*512
|
157 |
+
pred = np.argmax(pred, axis = -1)
|
158 |
+
pred =np.squeeze(pred)
|
159 |
+
result = pred.astype(np.uint8)
|
160 |
+
|
161 |
+
class_names = [ 'None', 'building', 'pervious surface', 'impervious surface', 'bare soil','water','coniferous','deciduous','brushwood','vineyard', 'herbaceous vegetation', 'agricultural land', 'plowed land']
|
162 |
+
seg_img = visualize_segmentation(result, img, n_classes=13,
|
163 |
+
colors=class_colors , overlay_img=True,
|
164 |
+
show_legends=True,
|
165 |
+
class_names=class_names,
|
166 |
+
prediction_width=512,
|
167 |
+
prediction_height=512)
|
168 |
+
|
169 |
+
return seg_img
|
170 |
+
|
171 |
+
demo = gr.Interface(
|
172 |
+
|
173 |
+
query_image,
|
174 |
+
inputs=[gr.Image()],
|
175 |
+
outputs="image",
|
176 |
+
title="Image Segmentation Demo",
|
177 |
+
description = "Please upload an image to see segmentation capabilities of this model",
|
178 |
+
examples=["examples/IMG_011942.jpeg","examples/IMG_005339.jpeg","examples/IMG_004753.jpeg","examples/IMG_011617.jpeg","examples/IMG_003022.jpeg"]
|
179 |
+
)
|
180 |
+
|
181 |
+
demo.launch() #debug=True
|
requirements.txt
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
albumentations==1.2.1
|
2 |
+
evaluate==0.4.0
|
3 |
+
numpy==1.22.4
|
4 |
+
opencv_python==4.7.0.72
|
5 |
+
pandas==1.4.4
|
6 |
+
Pillow==9.4.0
|
7 |
+
rasterio==1.3.6
|
8 |
+
scikit_learn==1.2.2
|
9 |
+
torch==1.13.1+cu116
|
10 |
+
tqdm==4.65.0
|
11 |
+
transformers==4.27.3
|
12 |
+
GDAL==3.3.2
|
13 |
+
matplotlib==3.7.1
|
14 |
+
osgeo==0.0.1
|
15 |
+
scikit_image==0.19.3
|
16 |
+
scipy==1.10.1
|
17 |
+
skimage==0.0
|
18 |
+
tensorflow==2.11.0
|
19 |
+
|