Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -9,7 +9,7 @@ import json
|
|
9 |
import numpy as np
|
10 |
import cv2
|
11 |
import chromadb
|
12 |
-
from transformers import
|
13 |
import torch.nn as nn
|
14 |
import matplotlib.pyplot as plt
|
15 |
|
@@ -24,14 +24,12 @@ def load_clip_model():
|
|
24 |
|
25 |
clip_model, preprocess_val, tokenizer, device = load_clip_model()
|
26 |
|
27 |
-
# Load
|
28 |
@st.cache_resource
|
29 |
-
def
|
30 |
-
|
31 |
-
model = AutoModelForSemanticSegmentation.from_pretrained("sayeed99/segformer-b3-fashion")
|
32 |
-
return model, processor
|
33 |
|
34 |
-
|
35 |
|
36 |
# Load ChromaDB
|
37 |
@st.cache_resource
|
@@ -83,55 +81,24 @@ def find_similar_images(query_embedding, collection, top_k=5):
|
|
83 |
]
|
84 |
return results
|
85 |
|
86 |
-
def segment_clothing(
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
logits,
|
93 |
-
size=image.size[::-1],
|
94 |
-
mode="bilinear",
|
95 |
-
align_corners=False,
|
96 |
-
)
|
97 |
-
|
98 |
-
pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy()
|
99 |
-
|
100 |
-
categories = segformer_model.config.id2label
|
101 |
-
segmented_items = []
|
102 |
-
|
103 |
-
for category_id, category_name in categories.items():
|
104 |
-
if category_id in pred_seg:
|
105 |
-
mask = (pred_seg == category_id).astype(np.uint8)
|
106 |
-
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
107 |
-
if contours:
|
108 |
-
x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
|
109 |
-
segmented_items.append({
|
110 |
-
'category': category_name,
|
111 |
-
'bbox': [x, y, x+w, y+h],
|
112 |
-
'mask': mask
|
113 |
-
})
|
114 |
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
|
|
123 |
|
124 |
-
|
125 |
-
unique_classes = np.unique(pred_seg)
|
126 |
-
legend_elements = [plt.Rectangle((0,0),1,1, color=plt.cm.jet(category_id/len(categories)),
|
127 |
-
label=f"{category_id}: {categories[category_id]}")
|
128 |
-
for category_id in unique_classes if category_id in categories]
|
129 |
-
plt.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(1, 0.5))
|
130 |
-
|
131 |
-
return plt
|
132 |
-
|
133 |
-
def crop_image(image, bbox):
|
134 |
-
return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
|
135 |
|
136 |
# Streamlit app
|
137 |
st.title("Advanced Fashion Search App")
|
@@ -154,7 +121,8 @@ if st.session_state.step == 'input':
|
|
154 |
query_image = load_image_from_url(st.session_state.query_image_url)
|
155 |
if query_image is not None:
|
156 |
st.session_state.query_image = query_image
|
157 |
-
|
|
|
158 |
if st.session_state.segmentations:
|
159 |
st.session_state.step = 'select_category'
|
160 |
else:
|
@@ -169,21 +137,11 @@ elif st.session_state.step == 'select_category':
|
|
169 |
with col1:
|
170 |
st.image(st.session_state.query_image, caption="Original Image", use_column_width=True)
|
171 |
with col2:
|
172 |
-
|
173 |
-
st.pyplot(seg_fig)
|
174 |
-
plt.close(seg_fig) # Prevent memory leaks
|
175 |
|
176 |
st.subheader("Segmented Clothing Items:")
|
177 |
|
178 |
-
for
|
179 |
-
col1, col2 = st.columns([1, 3])
|
180 |
-
with col1:
|
181 |
-
st.write(f"{segmentation['category']}")
|
182 |
-
with col2:
|
183 |
-
cropped_image = crop_image(st.session_state.query_image, segmentation['bbox'])
|
184 |
-
st.image(cropped_image, caption=segmentation['category'], use_column_width=True)
|
185 |
-
|
186 |
-
options = [s['category'] for s in st.session_state.segmentations]
|
187 |
selected_option = st.selectbox("Select a category to search:", options)
|
188 |
|
189 |
if st.button("Search Similar Items"):
|
@@ -192,11 +150,15 @@ elif st.session_state.step == 'select_category':
|
|
192 |
|
193 |
elif st.session_state.step == 'show_results':
|
194 |
st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
|
|
|
|
|
|
|
|
200 |
similar_images = find_similar_images(query_embedding, collection)
|
201 |
|
202 |
st.subheader("Similar Items:")
|
@@ -220,28 +182,26 @@ elif st.session_state.step == 'show_results':
|
|
220 |
st.session_state.segmentations = []
|
221 |
st.session_state.selected_category = None
|
222 |
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
else:
|
244 |
-
st.warning("Please enter a search text.")
|
245 |
|
246 |
# Display ChromaDB vacuum message
|
247 |
-
st.sidebar.warning("If you've upgraded ChromaDB from a version below 0.6, you may benefit from vacuuming your database
|
|
|
9 |
import numpy as np
|
10 |
import cv2
|
11 |
import chromadb
|
12 |
+
from transformers import pipeline
|
13 |
import torch.nn as nn
|
14 |
import matplotlib.pyplot as plt
|
15 |
|
|
|
24 |
|
25 |
clip_model, preprocess_val, tokenizer, device = load_clip_model()
|
26 |
|
27 |
+
# Load Clothing Segmentation model
|
28 |
@st.cache_resource
|
29 |
+
def load_segmentation_model():
|
30 |
+
return pipeline(model="mattmdjaga/segformer_b2_clothes")
|
|
|
|
|
31 |
|
32 |
+
segmenter = load_segmentation_model()
|
33 |
|
34 |
# Load ChromaDB
|
35 |
@st.cache_resource
|
|
|
81 |
]
|
82 |
return results
|
83 |
|
84 |
+
def segment_clothing(img, clothes=["Hat", "Upper-clothes", "Skirt", "Pants", "Dress", "Belt", "Left-shoe", "Right-shoe", "Scarf"]):
|
85 |
+
segments = segmenter(img)
|
86 |
+
mask_list = []
|
87 |
+
for s in segments:
|
88 |
+
if s['label'] in clothes:
|
89 |
+
mask_list.append(s['mask'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
+
if mask_list:
|
92 |
+
final_mask = np.array(mask_list[0])
|
93 |
+
for mask in mask_list[1:]:
|
94 |
+
current_mask = np.array(mask)
|
95 |
+
final_mask = final_mask + current_mask
|
96 |
+
|
97 |
+
final_mask = Image.fromarray(final_mask.astype('uint8') * 255)
|
98 |
+
img = img.convert("RGBA")
|
99 |
+
img.putalpha(final_mask)
|
100 |
|
101 |
+
return img, segments
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
# Streamlit app
|
104 |
st.title("Advanced Fashion Search App")
|
|
|
121 |
query_image = load_image_from_url(st.session_state.query_image_url)
|
122 |
if query_image is not None:
|
123 |
st.session_state.query_image = query_image
|
124 |
+
segmented_image, st.session_state.segmentations = segment_clothing(query_image)
|
125 |
+
st.session_state.segmented_image = segmented_image
|
126 |
if st.session_state.segmentations:
|
127 |
st.session_state.step = 'select_category'
|
128 |
else:
|
|
|
137 |
with col1:
|
138 |
st.image(st.session_state.query_image, caption="Original Image", use_column_width=True)
|
139 |
with col2:
|
140 |
+
st.image(st.session_state.segmented_image, caption="Segmented Image", use_column_width=True)
|
|
|
|
|
141 |
|
142 |
st.subheader("Segmented Clothing Items:")
|
143 |
|
144 |
+
options = list(set(s['label'] for s in st.session_state.segmentations))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
selected_option = st.selectbox("Select a category to search:", options)
|
146 |
|
147 |
if st.button("Search Similar Items"):
|
|
|
150 |
|
151 |
elif st.session_state.step == 'show_results':
|
152 |
st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
|
153 |
+
st.image(st.session_state.segmented_image, caption="Segmented Image", use_column_width=True)
|
154 |
+
|
155 |
+
selected_segment = next(s for s in st.session_state.segmentations if s['label'] == st.session_state.selected_category)
|
156 |
+
mask = np.array(selected_segment['mask'])
|
157 |
+
masked_image = Image.fromarray((np.array(st.session_state.query_image) * mask[:,:,None]).astype('uint8'))
|
158 |
+
|
159 |
+
st.image(masked_image, caption=f"Selected Category: {st.session_state.selected_category}", use_column_width=True)
|
160 |
+
|
161 |
+
query_embedding = get_image_embedding(masked_image)
|
162 |
similar_images = find_similar_images(query_embedding, collection)
|
163 |
|
164 |
st.subheader("Similar Items:")
|
|
|
182 |
st.session_state.segmentations = []
|
183 |
st.session_state.selected_category = None
|
184 |
|
185 |
+
# Text search (optional, you can keep or remove this part)
|
186 |
+
st.sidebar.title("Text Search")
|
187 |
+
query_text = st.sidebar.text_input("Enter search text:")
|
188 |
+
if st.sidebar.button("Search by Text"):
|
189 |
+
if query_text:
|
190 |
+
text_embedding = get_text_embedding(query_text)
|
191 |
+
similar_images = find_similar_images(text_embedding, collection)
|
192 |
+
st.sidebar.subheader("Similar Items:")
|
193 |
+
for img in similar_images:
|
194 |
+
st.sidebar.image(img['info']['image_url'], use_column_width=True)
|
195 |
+
st.sidebar.write(f"Name: {img['info']['name']}")
|
196 |
+
st.sidebar.write(f"Brand: {img['info']['brand']}")
|
197 |
+
category = img['info'].get('category')
|
198 |
+
if category:
|
199 |
+
st.sidebar.write(f"Category: {category}")
|
200 |
+
st.sidebar.write(f"Price: {img['info']['price']}")
|
201 |
+
st.sidebar.write(f"Discount: {img['info']['discount']}%")
|
202 |
+
st.sidebar.write(f"Similarity: {img['similarity']:.2f}")
|
203 |
+
else:
|
204 |
+
st.sidebar.warning("Please enter a search text.")
|
|
|
|
|
205 |
|
206 |
# Display ChromaDB vacuum message
|
207 |
+
st.sidebar.warning("If you've upgraded ChromaDB from a version below 0.6, you may benefit from vacuuming your database")
|