leedoming commited on
Commit
28455d6
·
verified ·
1 Parent(s): f0cd321

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -39
app.py CHANGED
@@ -11,15 +11,12 @@ import cv2
11
  import chromadb
12
  from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
13
  import torch.nn as nn
14
- import warnings
15
-
16
- # Suppress specific warnings
17
- warnings.filterwarnings("ignore", category=UserWarning, module="transformers.utils.deprecation")
18
 
19
  # Load CLIP model and tokenizer
20
  @st.cache_resource
21
  def load_clip_model():
22
- model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
23
  tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  model.to(device)
@@ -30,12 +27,21 @@ clip_model, preprocess_val, tokenizer, device = load_clip_model()
30
  # Load SegFormer model
31
  @st.cache_resource
32
  def load_segformer_model():
33
- processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
34
- model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
35
  return model, processor
36
 
37
  segformer_model, segformer_processor = load_segformer_model()
38
 
 
 
 
 
 
 
 
 
 
39
  # Helper functions
40
  def load_image_from_url(url, max_retries=3):
41
  for attempt in range(max_retries):
@@ -50,15 +56,6 @@ def load_image_from_url(url, max_retries=3):
50
  else:
51
  return None
52
 
53
- # Load ChromaDB
54
- @st.cache_resource
55
- def load_chromadb():
56
- client = chromadb.PersistentClient(path="./clothesDB")
57
- collection = client.get_collection(name="clothes_items_ver3")
58
- return collection
59
-
60
- collection = load_chromadb()
61
-
62
  def get_image_embedding(image):
63
  image_tensor = preprocess_val(image).unsqueeze(0).to(device)
64
  with torch.no_grad():
@@ -73,35 +70,23 @@ def get_text_embedding(text):
73
  text_features /= text_features.norm(dim=-1, keepdim=True)
74
  return text_features.cpu().numpy()
75
 
76
- def get_all_embeddings_from_collection(collection):
77
- all_embeddings = collection.get(include=['embeddings'])['embeddings']
78
- return np.array(all_embeddings)
79
-
80
- def get_metadata_from_ids(collection, ids):
81
- results = collection.get(ids=ids)
82
- return results['metadatas']
83
-
84
  def find_similar_images(query_embedding, collection, top_k=5):
85
- database_embeddings = get_all_embeddings_from_collection(collection)
86
  similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
87
  top_indices = np.argsort(similarities)[::-1][:top_k]
88
 
89
  all_data = collection.get(include=['metadatas'])['metadatas']
90
 
91
- top_metadatas = [all_data[idx] for idx in top_indices]
92
-
93
- results = []
94
- for idx, metadata in enumerate(top_metadatas):
95
- results.append({
96
- 'info': metadata,
97
- 'similarity': similarities[top_indices[idx]]
98
- })
99
  return results
100
 
101
  def segment_clothing(image):
102
  inputs = segformer_processor(images=image, return_tensors="pt")
103
  outputs = segformer_model(**inputs)
104
- logits = outputs.logits.cpu()
105
 
106
  upsampled_logits = nn.functional.interpolate(
107
  logits,
@@ -110,7 +95,7 @@ def segment_clothing(image):
110
  align_corners=False,
111
  )
112
 
113
- pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()
114
 
115
  categories = segformer_model.config.id2label
116
  segmented_items = []
@@ -127,7 +112,23 @@ def segment_clothing(image):
127
  'mask': mask
128
  })
129
 
130
- return segmented_items, pred_seg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
  def crop_image(image, bbox):
133
  return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
@@ -153,7 +154,7 @@ if st.session_state.step == 'input':
153
  query_image = load_image_from_url(st.session_state.query_image_url)
154
  if query_image is not None:
155
  st.session_state.query_image = query_image
156
- st.session_state.segmentations, st.session_state.pred_seg = segment_clothing(query_image)
157
  if st.session_state.segmentations:
158
  st.session_state.step = 'select_category'
159
  else:
@@ -168,8 +169,10 @@ elif st.session_state.step == 'select_category':
168
  with col1:
169
  st.image(st.session_state.query_image, caption="Original Image", use_column_width=True)
170
  with col2:
171
- st.image(st.session_state.pred_seg, caption="Segmentation Map", use_column_width=True)
172
-
 
 
173
  st.subheader("Segmented Clothing Items:")
174
 
175
  for segmentation in st.session_state.segmentations:
 
11
  import chromadb
12
  from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
13
  import torch.nn as nn
14
+ import matplotlib.pyplot as plt
 
 
 
15
 
16
  # Load CLIP model and tokenizer
17
  @st.cache_resource
18
  def load_clip_model():
19
+ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
20
  tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
21
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  model.to(device)
 
27
  # Load SegFormer model
28
  @st.cache_resource
29
  def load_segformer_model():
30
+ processor = SegformerImageProcessor.from_pretrained("sayeed99/segformer-b3-fashion")
31
+ model = AutoModelForSemanticSegmentation.from_pretrained("sayeed99/segformer-b3-fashion")
32
  return model, processor
33
 
34
  segformer_model, segformer_processor = load_segformer_model()
35
 
36
+ # Load ChromaDB
37
+ @st.cache_resource
38
+ def load_chromadb():
39
+ client = chromadb.PersistentClient(path="./clothesDB")
40
+ collection = client.get_collection(name="clothes_items_ver3")
41
+ return collection
42
+
43
+ collection = load_chromadb()
44
+
45
  # Helper functions
46
  def load_image_from_url(url, max_retries=3):
47
  for attempt in range(max_retries):
 
56
  else:
57
  return None
58
 
 
 
 
 
 
 
 
 
 
59
  def get_image_embedding(image):
60
  image_tensor = preprocess_val(image).unsqueeze(0).to(device)
61
  with torch.no_grad():
 
70
  text_features /= text_features.norm(dim=-1, keepdim=True)
71
  return text_features.cpu().numpy()
72
 
 
 
 
 
 
 
 
 
73
  def find_similar_images(query_embedding, collection, top_k=5):
74
+ database_embeddings = np.array(collection.get(include=['embeddings'])['embeddings'])
75
  similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
76
  top_indices = np.argsort(similarities)[::-1][:top_k]
77
 
78
  all_data = collection.get(include=['metadatas'])['metadatas']
79
 
80
+ results = [
81
+ {'info': all_data[idx], 'similarity': similarities[idx]}
82
+ for idx in top_indices
83
+ ]
 
 
 
 
84
  return results
85
 
86
  def segment_clothing(image):
87
  inputs = segformer_processor(images=image, return_tensors="pt")
88
  outputs = segformer_model(**inputs)
89
+ logits = outputs.logits
90
 
91
  upsampled_logits = nn.functional.interpolate(
92
  logits,
 
95
  align_corners=False,
96
  )
97
 
98
+ pred_seg = upsampled_logits.argmax(dim=1)[0].cpu().numpy()
99
 
100
  categories = segformer_model.config.id2label
101
  segmented_items = []
 
112
  'mask': mask
113
  })
114
 
115
+ return segmented_items, pred_seg, categories
116
+
117
+ def visualize_segmentation(pred_seg, categories):
118
+ plt.figure(figsize=(10, 10))
119
+ plt.imshow(pred_seg, cmap='jet')
120
+ plt.colorbar(label='Category ID')
121
+ plt.title("Segmentation Result")
122
+ plt.axis('off')
123
+
124
+ # Add legend
125
+ unique_classes = np.unique(pred_seg)
126
+ legend_elements = [plt.Rectangle((0,0),1,1, color=plt.cm.jet(category_id/len(categories)),
127
+ label=f"{category_id}: {categories[category_id]}")
128
+ for category_id in unique_classes if category_id in categories]
129
+ plt.legend(handles=legend_elements, loc='center left', bbox_to_anchor=(1, 0.5))
130
+
131
+ return plt
132
 
133
  def crop_image(image, bbox):
134
  return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
 
154
  query_image = load_image_from_url(st.session_state.query_image_url)
155
  if query_image is not None:
156
  st.session_state.query_image = query_image
157
+ st.session_state.segmentations, st.session_state.pred_seg, st.session_state.categories = segment_clothing(query_image)
158
  if st.session_state.segmentations:
159
  st.session_state.step = 'select_category'
160
  else:
 
169
  with col1:
170
  st.image(st.session_state.query_image, caption="Original Image", use_column_width=True)
171
  with col2:
172
+ seg_fig = visualize_segmentation(st.session_state.pred_seg, st.session_state.categories)
173
+ st.pyplot(seg_fig)
174
+ plt.close(seg_fig) # Prevent memory leaks
175
+
176
  st.subheader("Segmented Clothing Items:")
177
 
178
  for segmentation in st.session_state.segmentations: