sergiopaniego commited on
Commit
b42e4a2
·
1 Parent(s): 1d3ac1a

Updated Space

Browse files
Files changed (2) hide show
  1. app.py +107 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
4
+
5
+ import numpy as np
6
+ from PIL import Image
7
+ from transformers import pipeline
8
+ import matplotlib.pyplot as plt
9
+ import io
10
+
11
+ model_pipeline = pipeline("image-segmentation", model="sergiopaniego/segformer-b0-segments-sidewalk-finetuned")
12
+
13
+ id2label = {0: 'unlabeled', 1: 'flat-road', 2: 'flat-sidewalk', 3: 'flat-crosswalk', 4: 'flat-cyclinglane', 5: 'flat-parkingdriveway', 6: 'flat-railtrack', 7: 'flat-curb', 8: 'human-person', 9: 'human-rider', 10: 'vehicle-car', 11: 'vehicle-truck', 12: 'vehicle-bus', 13: 'vehicle-tramtrain', 14: 'vehicle-motorcycle', 15: 'vehicle-bicycle', 16: 'vehicle-caravan', 17: 'vehicle-cartrailer', 18: 'construction-building', 19: 'construction-door', 20: 'construction-wall', 21: 'construction-fenceguardrail', 22: 'construction-bridge', 23: 'construction-tunnel', 24: 'construction-stairs', 25: 'object-pole', 26: 'object-trafficsign', 27: 'object-trafficlight', 28: 'nature-vegetation', 29: 'nature-terrain', 30: 'sky', 31: 'void-ground', 32: 'void-dynamic', 33: 'void-static', 34: 'void-unclear'}
14
+ sidewalk_palette = [
15
+ [0, 0, 0], # unlabeled
16
+ [216, 82, 24], # flat-road
17
+ [255, 255, 0], # flat-sidewalk
18
+ [125, 46, 141], # flat-crosswalk
19
+ [118, 171, 47], # flat-cyclinglane
20
+ [161, 19, 46], # flat-parkingdriveway
21
+ [255, 0, 0], # flat-railtrack
22
+ [0, 128, 128], # flat-curb
23
+ [190, 190, 0], # human-person
24
+ [0, 255, 0], # human-rider
25
+ [0, 0, 255], # vehicle-car
26
+ [170, 0, 255], # vehicle-truck
27
+ [84, 84, 0], # vehicle-bus
28
+ [84, 170, 0], # vehicle-tramtrain
29
+ [84, 255, 0], # vehicle-motorcycle
30
+ [170, 84, 0], # vehicle-bicycle
31
+ [170, 170, 0], # vehicle-caravan
32
+ [170, 255, 0], # vehicle-cartrailer
33
+ [255, 84, 0], # construction-building
34
+ [255, 170, 0], # construction-door
35
+ [255, 255, 0], # construction-wall
36
+ [33, 138, 200], # construction-fenceguardrail
37
+ [0, 170, 127], # construction-bridge
38
+ [0, 255, 127], # construction-tunnel
39
+ [84, 0, 127], # construction-stairs
40
+ [84, 84, 127], # object-pole
41
+ [84, 170, 127], # object-trafficsign
42
+ [84, 255, 127], # object-trafficlight
43
+ [170, 0, 127], # nature-vegetation
44
+ [170, 84, 127], # nature-terrain
45
+ [170, 170, 127], # sky
46
+ [170, 255, 127], # void-ground
47
+ [255, 0, 127], # void-dynamic
48
+ [255, 84, 127], # void-static
49
+ [255, 170, 127], # void-unclear
50
+ ]
51
+
52
+ def get_output_figure(pil_img, results):
53
+ plt.figure(figsize=(16, 10))
54
+ plt.imshow(pil_img)
55
+ image_array = np.array(pil_img)
56
+
57
+ segmentation_map = np.zeros_like(image_array)
58
+
59
+ for result in results:
60
+ mask = np.array(result['mask'])
61
+ label = result['label']
62
+
63
+ label_index = list(id2label.values()).index(label)
64
+
65
+ color = sidewalk_palette[label_index]
66
+
67
+ for c in range(3):
68
+ segmentation_map[:, :, c] = np.where(mask, color[c], segmentation_map[:, :, c])
69
+
70
+ plt.imshow(segmentation_map, alpha=0.5)
71
+ plt.axis('off')
72
+
73
+ return plt.gcf()
74
+
75
+ @spaces.GPU
76
+ def detect(image):
77
+ results = model_pipeline(image)
78
+ print(results)
79
+
80
+ output_figure = get_output_figure(image, results)
81
+
82
+ buf = io.BytesIO()
83
+ output_figure.savefig(buf, bbox_inches='tight')
84
+ buf.seek(0)
85
+ output_pil_img = Image.open(buf)
86
+
87
+ return output_pil_img
88
+
89
+ with gr.Blocks() as demo:
90
+ gr.Markdown("# Semantic segmentation with SegFormer fine tuned on segments/sidewalk")
91
+ gr.Markdown(
92
+ """
93
+ This application uses a fine tuned SegFormer for sematic segmenation over an input image.
94
+ This version was trained using segments/sidewalk dataset.
95
+ You can load an image and see the predicted segmentation.
96
+ """
97
+ )
98
+
99
+ gr.Interface(
100
+ fn=detect,
101
+ inputs=gr.Image(label="Input image", type="pil"),
102
+ outputs=[
103
+ gr.Image(label="Output prediction", type="pil")
104
+ ]
105
+ )
106
+
107
+ demo.launch(show_error=True)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers
2
+ torch