leedoming commited on
Commit
7218e6b
·
verified ·
1 Parent(s): 9d7f3f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -95
app.py CHANGED
@@ -9,7 +9,7 @@ import json
9
  import numpy as np
10
  import cv2
11
  import chromadb
12
- from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
13
  import torch.nn as nn
14
  import matplotlib.pyplot as plt
15
 
@@ -24,14 +24,12 @@ def load_clip_model():
24
 
25
  clip_model, preprocess_val, tokenizer, device = load_clip_model()
26
 
27
- # Load SegFormer model
28
  @st.cache_resource
29
- def load_segformer_model():
30
- processor = SegformerImageProcessor.from_pretrained("sayeed99/segformer-b3-fashion")
31
- model = AutoModelForSemanticSegmentation.from_pretrained("sayeed99/segformer-b3-fashion")
32
- return model, processor
33
 
34
- segformer_model, segformer_processor = load_segformer_model()
35
 
36
  # Load ChromaDB
37
  @st.cache_resource
@@ -83,55 +81,24 @@ def find_similar_images(query_embedding, collection, top_k=5):
83
  ]
84
  return results
85
 
86
- def segment_clothing(image):
87
- inputs = segformer_processor(images=image, return_tensors="pt")
88
- outputs = segformer_model(**inputs)
89
- logits = outputs.logits
90
-
91
- upsampled_logits = nn.functional.interpolate(
92
- logits,
93
- size=image.size[::-1],
94
- mode="bilinear",
95
- align_corners=False,
96
- )
97
-
98
- pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy()
99
-
100
- categories = segformer_model.config.id2label
101
- segmented_items = []
102
-
103
- for category_id, category_name in categories.items():
104
- if category_id in pred_seg:
105
- mask = (pred_seg == category_id).astype(np.uint8)
106
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
107
- if contours:
108
- x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
109
- segmented_items.append({
110
- 'category': category_name,
111
- 'bbox': [x, y, x+w, y+h],
112
- 'mask': mask
113
- })
114
 
115
- return segmented_items, pred_seg, categories
116
-
117
- def visualize_segmentation(pred_seg, categories):
118
- plt.figure(figsize=(10, 10))
119
- plt.imshow(pred_seg, cmap='jet')
120
- plt.colorbar(label='Category ID')
121
- plt.title("Segmentation Result")
122
- plt.axis('off')
 
123
 
124
- # Add legend
125
- unique_classes = np.unique(pred_seg)
126
- legend_elements = [plt.Rectangle((0,0),1,1, color=plt.cm.jet(category_id/len(categories)),
127
- label=f"{category_id}: {categories[category_id]}")
128
- for category_id in unique_classes if category_id in categories]
129
- plt.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(1, 0.5))
130
-
131
- return plt
132
-
133
- def crop_image(image, bbox):
134
- return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
135
 
136
  # Streamlit app
137
  st.title("Advanced Fashion Search App")
@@ -154,7 +121,8 @@ if st.session_state.step == 'input':
154
  query_image = load_image_from_url(st.session_state.query_image_url)
155
  if query_image is not None:
156
  st.session_state.query_image = query_image
157
- st.session_state.segmentations, st.session_state.pred_seg, st.session_state.categories = segment_clothing(query_image)
 
158
  if st.session_state.segmentations:
159
  st.session_state.step = 'select_category'
160
  else:
@@ -169,21 +137,11 @@ elif st.session_state.step == 'select_category':
169
  with col1:
170
  st.image(st.session_state.query_image, caption="Original Image", use_column_width=True)
171
  with col2:
172
- seg_fig = visualize_segmentation(st.session_state.pred_seg, st.session_state.categories)
173
- st.pyplot(seg_fig)
174
- plt.close(seg_fig) # Prevent memory leaks
175
 
176
  st.subheader("Segmented Clothing Items:")
177
 
178
- for segmentation in st.session_state.segmentations:
179
- col1, col2 = st.columns([1, 3])
180
- with col1:
181
- st.write(f"{segmentation['category']}")
182
- with col2:
183
- cropped_image = crop_image(st.session_state.query_image, segmentation['bbox'])
184
- st.image(cropped_image, caption=segmentation['category'], use_column_width=True)
185
-
186
- options = [s['category'] for s in st.session_state.segmentations]
187
  selected_option = st.selectbox("Select a category to search:", options)
188
 
189
  if st.button("Search Similar Items"):
@@ -192,11 +150,15 @@ elif st.session_state.step == 'select_category':
192
 
193
  elif st.session_state.step == 'show_results':
194
  st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
195
- selected_segmentation = next(s for s in st.session_state.segmentations
196
- if s['category'] == st.session_state.selected_category)
197
- cropped_image = crop_image(st.session_state.query_image, selected_segmentation['bbox'])
198
- st.image(cropped_image, caption="Cropped Image", use_column_width=True)
199
- query_embedding = get_image_embedding(cropped_image)
 
 
 
 
200
  similar_images = find_similar_images(query_embedding, collection)
201
 
202
  st.subheader("Similar Items:")
@@ -220,28 +182,26 @@ elif st.session_state.step == 'show_results':
220
  st.session_state.segmentations = []
221
  st.session_state.selected_category = None
222
 
223
- else: # Text search
224
- query_text = st.text_input("Enter search text:")
225
- if st.button("Search by Text"):
226
- if query_text:
227
- text_embedding = get_text_embedding(query_text)
228
- similar_images = find_similar_images(text_embedding, collection)
229
- st.subheader("Similar Items:")
230
- for img in similar_images:
231
- col1, col2 = st.columns(2)
232
- with col1:
233
- st.image(img['info']['image_url'], use_column_width=True)
234
- with col2:
235
- st.write(f"Name: {img['info']['name']}")
236
- st.write(f"Brand: {img['info']['brand']}")
237
- category = img['info'].get('category')
238
- if category:
239
- st.write(f"Category: {category}")
240
- st.write(f"Price: {img['info']['price']}")
241
- st.write(f"Discount: {img['info']['discount']}%")
242
- st.write(f"Similarity: {img['similarity']:.2f}")
243
- else:
244
- st.warning("Please enter a search text.")
245
 
246
  # Display ChromaDB vacuum message
247
- st.sidebar.warning("If you've upgraded ChromaDB from a version below 0.6, you may benefit from vacuuming your database. Run 'chromadb utils vacuum --help' for more information.")
 
9
  import numpy as np
10
  import cv2
11
  import chromadb
12
+ from transformers import pipeline
13
  import torch.nn as nn
14
  import matplotlib.pyplot as plt
15
 
 
24
 
25
  clip_model, preprocess_val, tokenizer, device = load_clip_model()
26
 
27
+ # Load Clothing Segmentation model
28
  @st.cache_resource
29
+ def load_segmentation_model():
30
+ return pipeline(model="mattmdjaga/segformer_b2_clothes")
 
 
31
 
32
+ segmenter = load_segmentation_model()
33
 
34
  # Load ChromaDB
35
  @st.cache_resource
 
81
  ]
82
  return results
83
 
84
+ def segment_clothing(img, clothes=["Hat", "Upper-clothes", "Skirt", "Pants", "Dress", "Belt", "Left-shoe", "Right-shoe", "Scarf"]):
85
+ segments = segmenter(img)
86
+ mask_list = []
87
+ for s in segments:
88
+ if s['label'] in clothes:
89
+ mask_list.append(s['mask'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
+ if mask_list:
92
+ final_mask = np.array(mask_list[0])
93
+ for mask in mask_list[1:]:
94
+ current_mask = np.array(mask)
95
+ final_mask = final_mask + current_mask
96
+
97
+ final_mask = Image.fromarray(final_mask.astype('uint8') * 255)
98
+ img = img.convert("RGBA")
99
+ img.putalpha(final_mask)
100
 
101
+ return img, segments
 
 
 
 
 
 
 
 
 
 
102
 
103
  # Streamlit app
104
  st.title("Advanced Fashion Search App")
 
121
  query_image = load_image_from_url(st.session_state.query_image_url)
122
  if query_image is not None:
123
  st.session_state.query_image = query_image
124
+ segmented_image, st.session_state.segmentations = segment_clothing(query_image)
125
+ st.session_state.segmented_image = segmented_image
126
  if st.session_state.segmentations:
127
  st.session_state.step = 'select_category'
128
  else:
 
137
  with col1:
138
  st.image(st.session_state.query_image, caption="Original Image", use_column_width=True)
139
  with col2:
140
+ st.image(st.session_state.segmented_image, caption="Segmented Image", use_column_width=True)
 
 
141
 
142
  st.subheader("Segmented Clothing Items:")
143
 
144
+ options = list(set(s['label'] for s in st.session_state.segmentations))
 
 
 
 
 
 
 
 
145
  selected_option = st.selectbox("Select a category to search:", options)
146
 
147
  if st.button("Search Similar Items"):
 
150
 
151
  elif st.session_state.step == 'show_results':
152
  st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
153
+ st.image(st.session_state.segmented_image, caption="Segmented Image", use_column_width=True)
154
+
155
+ selected_segment = next(s for s in st.session_state.segmentations if s['label'] == st.session_state.selected_category)
156
+ mask = np.array(selected_segment['mask'])
157
+ masked_image = Image.fromarray((np.array(st.session_state.query_image) * mask[:,:,None]).astype('uint8'))
158
+
159
+ st.image(masked_image, caption=f"Selected Category: {st.session_state.selected_category}", use_column_width=True)
160
+
161
+ query_embedding = get_image_embedding(masked_image)
162
  similar_images = find_similar_images(query_embedding, collection)
163
 
164
  st.subheader("Similar Items:")
 
182
  st.session_state.segmentations = []
183
  st.session_state.selected_category = None
184
 
185
+ # Text search (optional, you can keep or remove this part)
186
+ st.sidebar.title("Text Search")
187
+ query_text = st.sidebar.text_input("Enter search text:")
188
+ if st.sidebar.button("Search by Text"):
189
+ if query_text:
190
+ text_embedding = get_text_embedding(query_text)
191
+ similar_images = find_similar_images(text_embedding, collection)
192
+ st.sidebar.subheader("Similar Items:")
193
+ for img in similar_images:
194
+ st.sidebar.image(img['info']['image_url'], use_column_width=True)
195
+ st.sidebar.write(f"Name: {img['info']['name']}")
196
+ st.sidebar.write(f"Brand: {img['info']['brand']}")
197
+ category = img['info'].get('category')
198
+ if category:
199
+ st.sidebar.write(f"Category: {category}")
200
+ st.sidebar.write(f"Price: {img['info']['price']}")
201
+ st.sidebar.write(f"Discount: {img['info']['discount']}%")
202
+ st.sidebar.write(f"Similarity: {img['similarity']:.2f}")
203
+ else:
204
+ st.sidebar.warning("Please enter a search text.")
 
 
205
 
206
  # Display ChromaDB vacuum message
207
+ st.sidebar.warning("If you've upgraded ChromaDB from a version below 0.6, you may benefit from vacuuming your database")