leedoming committed
Commit 82ca7db · verified · 1 Parent(s): c585644

Update app.py

Files changed (1): app.py (+23 -9)
app.py CHANGED
@@ -9,7 +9,8 @@ import json
 import numpy as np
 import cv2
 import chromadb
-from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
+from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
+import torch.nn as nn

 # Load CLIP model and tokenizer
 @st.cache_resource
@@ -25,8 +26,8 @@ clip_model, preprocess_val, tokenizer, device = load_clip_model()
 # Load SegFormer model
 @st.cache_resource
 def load_segformer_model():
-    model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
     processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
+    model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
     return model, processor

 segformer_model, segformer_processor = load_segformer_model()
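Note: AutoModelForSemanticSegmentation dispatches on the checkpoint's config, so for mattmdjaga/segformer_b2_clothes it should load the same SegFormer architecture the previous explicit class did. A quick standalone check, outside the Streamlit app (the printed values are what I would expect, not taken from the commit):

# Standalone sanity check; not part of app.py.
from transformers import AutoModelForSemanticSegmentation, SegformerImageProcessor

processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")

print(type(model).__name__)          # expected: SegformerForSemanticSegmentation
print(len(model.config.id2label))    # expected: 18 labels (background + clothing parts)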
@@ -91,15 +92,23 @@ def find_similar_images(query_embedding, collection, top_k=5):
 def segment_clothing(image):
     inputs = segformer_processor(images=image, return_tensors="pt")
     outputs = segformer_model(**inputs)
-    logits = outputs.logits.squeeze()
-    predicted_mask = logits.argmax(dim=0).numpy()
+    logits = outputs.logits.cpu()
+
+    upsampled_logits = nn.functional.interpolate(
+        logits,
+        size=image.size[::-1],
+        mode="bilinear",
+        align_corners=False,
+    )
+
+    pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()

     categories = segformer_model.config.id2label
     segmented_items = []

     for category_id, category_name in categories.items():
-        if category_id in predicted_mask:
-            mask = (predicted_mask == category_id).astype(np.uint8)
+        if category_id in pred_seg:
+            mask = (pred_seg == category_id).astype(np.uint8)
             contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
             if contours:
                 x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
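The reworked segment_clothing upsamples the quarter-resolution SegFormer logits back to the input size before taking the per-pixel argmax, so the label map (and any bounding boxes derived from it) lines up with the original image; size=image.size[::-1] converts PIL's (width, height) into the (height, width) order that interpolate expects. A minimal standalone sketch of the shape handling, using a dummy tensor in place of real model output (sizes are illustrative):

# Dummy-tensor sketch of the upsampling step; not part of app.py.
import torch
import torch.nn as nn

num_labels = 18                          # classes in segformer_b2_clothes
width, height = 600, 400                 # PIL's image.size is (width, height)
logits = torch.randn(1, num_labels, height // 4, width // 4)  # SegFormer logits are ~1/4 resolution

upsampled = nn.functional.interpolate(
    logits,
    size=(height, width),                # same as image.size[::-1]
    mode="bilinear",
    align_corners=False,
)
pred_seg = upsampled.argmax(dim=1)[0].numpy()
print(pred_seg.shape)                    # (400, 600): one class id per original pixel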
@@ -109,7 +118,7 @@ def segment_clothing(image):
                     'mask': mask
                 })

-    return segmented_items
+    return segmented_items, pred_seg

 def crop_image(image, bbox):
     return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
@@ -135,7 +144,7 @@ if st.session_state.step == 'input':
         query_image = load_image_from_url(st.session_state.query_image_url)
         if query_image is not None:
             st.session_state.query_image = query_image
-            st.session_state.segmentations = segment_clothing(query_image)
+            st.session_state.segmentations, st.session_state.pred_seg = segment_clothing(query_image)
             if st.session_state.segmentations:
                 st.session_state.step = 'select_category'
             else:
@@ -146,7 +155,12 @@ if st.session_state.step == 'input':
         st.warning("Please enter an image URL.")

 elif st.session_state.step == 'select_category':
-    st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
+    col1, col2 = st.columns(2)
+    with col1:
+        st.image(st.session_state.query_image, caption="Original Image", use_column_width=True)
+    with col2:
+        st.image(st.session_state.pred_seg, caption="Segmentation Map", use_column_width=True)
+
     st.subheader("Segmented Clothing Items:")

     for segmentation in st.session_state.segmentations: