Sijuade commited on
Commit
f67b8d5
·
1 Parent(s): fee9635

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +27 -0
  2. requirements.txt +5 -0
  3. utils.py +177 -0
app.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from utils import (
4
+ predict,
5
+ get_html,
6
+ get_examples
7
+ )
8
+
9
+ examples = get_examples()
10
+ placeholder = 'Enter a word/phrase or multiple words/phrases separated by commas...'
11
+
12
+
13
+ with gr.Blocks() as interface:
14
+ gr.HTML(value=get_html, show_label=True)
15
+ with gr.Row():
16
+ inputs = [gr.Image(type="pil"),
17
+ gr.Textbox(label='Text Prompts', placeholder=placeholder, lines=3)]
18
+
19
+ with gr.Row():
20
+ outputs = gr.AnnotatedImage(label="Segmentation Masks")
21
+
22
+ with gr.Row():
23
+ button = gr.Button("Visualize Segments")
24
+ button.click(predict, inputs=inputs, outputs=outputs)
25
+
26
+ with gr.Row():
27
+ gr.Examples(examples=examples, inputs=inputs, outputs=outputs, fn=predict, cache_examples=True)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ pillow
3
+ gradio
4
+ torchvision
5
+ git+https://github.com/openai/CLIP.git
utils.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ from clipseg import CLIPDensePredT
6
+
7
+
8
+ transform = transforms.Compose([
9
+ transforms.ToTensor(),
10
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
11
+ transforms.Resize((352, 352)),
12
+ ])
13
+
14
+
15
+ model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
16
+ model.eval()
17
+ model.load_state_dict(torch.load('weights/rd64-uni.pth',
18
+ map_location=torch.device('cpu')), strict=False)
19
+
20
+
21
+ def predict(image, prompts):
22
+ """
23
+ Predict segmentation masks for the given image based on the provided prompts.
24
+
25
+ Parameters:
26
+ - image (PIL.Image): The input image.
27
+ - prompts (str): A comma-separated string of prompts.
28
+ - Model (torch.nn): Segmentation Model.
29
+
30
+ Returns:
31
+ - tuple: A tuple containing the resized input image and a list of segmentation masks.
32
+ """
33
+
34
+ img = transform(image).unsqueeze(0)
35
+
36
+ # Split the prompts string into a list of individual prompts
37
+ prompts = prompts.split(',')
38
+ num_prompts = len(prompts)
39
+
40
+ # Ensure no gradient computation during prediction for performance
41
+ with torch.no_grad():
42
+ # Get model predictions for each prompt
43
+ preds = model(img.repeat(len(prompts), 1, 1, 1), prompts)[0]
44
+
45
+ # Convert model predictions to segmentation masks
46
+ masks = [torch.sigmoid(preds[i][0]) for i in range(num_prompts)]
47
+ masks = [(m.squeeze(0).numpy(), prompts[i]) for i, m in enumerate(masks)]
48
+
49
+ # Return the resized input image and the list of segmentation masks
50
+ return (image.resize((352, 352), Image.LANCZOS), masks)
51
+
52
+ def get_examples():
53
+ examples = [
54
+ ['images/000010.jpg', 'deer, tree, grass'],
55
+ ['images/000002.jpg', 'train, tracks, electric pole, house'],
56
+ ['images/00125.jpg', 'dog, flowers'],
57
+ ['images/000010.jpg', 'horse, man, fence, buildings, hill'],
58
+ ['images/000004.jpg', 'car, truck, building, sky, traffic light, tree, clouds']
59
+ ]
60
+ return(examples)
61
+
62
+
63
+ def get_html():
64
+ html_string = """
65
+ <!DOCTYPE html>
66
+ <html lang="en">
67
+
68
+ <head>
69
+ <meta charset="UTF-8">
70
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
71
+ <title>Multi-Prompt Image Segmentation</title>
72
+ <link href="https://fonts.googleapis.com/css2?family=Roboto+Slab:wght@400;700&display=swap" rel="stylesheet">
73
+
74
+ <style>
75
+ /* General styling */
76
+ body {
77
+ font-family: 'Roboto Slab', serif;
78
+ margin: 0;
79
+ padding: 0;
80
+ background-color: #f4f4f4;
81
+ }
82
+
83
+ .app-header {
84
+ background: linear-gradient(135deg, #4a90e2, #50e3c2);
85
+ color: #fff;
86
+ text-align: center;
87
+ padding: 40px 0;
88
+ border-radius: 20px;
89
+ position: relative;
90
+ overflow: hidden;
91
+ box-shadow: 0px 10px 20px rgba(0, 0, 0, 0.1);
92
+ }
93
+
94
+ /* Ellipse Overlay */
95
+ .app-header::before {
96
+ content: "";
97
+ position: absolute;
98
+ top: -50%;
99
+ left: -50%;
100
+ width: 200%;
101
+ height: 200%;
102
+ background: rgba(255, 255, 255, 0.1);
103
+ transform: rotate(45deg);
104
+ border-radius: 50%;
105
+ }
106
+
107
+ /* Floating Shapes */
108
+ .app-header::after {
109
+ content: "";
110
+ position: absolute;
111
+ top: 20%;
112
+ right: 10%;
113
+ width: 70px;
114
+ height: 70px;
115
+ background: rgba(255, 255, 255, 0.2);
116
+ border-radius: 50%;
117
+ }
118
+
119
+ .floating-shape {
120
+ content: "";
121
+ position: absolute;
122
+ top: 10%;
123
+ left: 5%;
124
+ width: 50px;
125
+ height: 50px;
126
+ background: rgba(255, 255, 255, 0.2);
127
+ border-radius: 50%;
128
+ }
129
+
130
+ /* Text Styling */
131
+ .app-title {
132
+ font-size: 28px;
133
+ margin: 0;
134
+ font-weight: 700;
135
+ text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.2);
136
+ }
137
+
138
+ .app-description {
139
+ font-size: 18px;
140
+ margin-top: 15px;
141
+ opacity: 0.9;
142
+ text-shadow: 1px 1px 3px rgba(0, 0, 0, 0.1);
143
+ }
144
+
145
+ /* Wavy Bottom */
146
+ .wavy-bottom {
147
+ position: absolute;
148
+ bottom: -10px;
149
+ left: 0;
150
+ width: 100%;
151
+ height: 20px;
152
+ background: #f4f4f4;
153
+ border-radius: 100% 100% 0 0;
154
+ }
155
+ </style>
156
+ </head>
157
+
158
+ <body>
159
+
160
+ <!-- App Header -->
161
+ <div class="app-header">
162
+ <h1 class="app-title">Multi-Prompt Image Segmentation</h1>
163
+ <p class="app-description">Upload an image and provide multiple text prompts separated by commas. Get segmented masks for each prompt.</p>
164
+ <div class="floating-shape"></div>
165
+ <div class="wavy-bottom"></div>
166
+ </div>
167
+
168
+ <!-- Rest of the app content will go here -->
169
+
170
+ </body>
171
+
172
+ </html>
173
+
174
+
175
+ """
176
+
177
+ return(html_string)