Update app.py
app.py
CHANGED
@@ -9,7 +9,8 @@ import json
 import numpy as np
 import cv2
 import chromadb
-from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
+from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
+import torch.nn as nn
 
 # Load CLIP model and tokenizer
 @st.cache_resource
@@ -25,8 +26,8 @@ clip_model, preprocess_val, tokenizer, device = load_clip_model()
 # Load SegFormer model
 @st.cache_resource
 def load_segformer_model():
-    model = SegformerForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
     processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
+    model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
     return model, processor
 
 segformer_model, segformer_processor = load_segformer_model()
@@ -91,15 +92,23 @@ def find_similar_images(query_embedding, collection, top_k=5):
 def segment_clothing(image):
     inputs = segformer_processor(images=image, return_tensors="pt")
     outputs = segformer_model(**inputs)
-    logits = outputs.logits.
-
+    logits = outputs.logits.cpu()
+
+    upsampled_logits = nn.functional.interpolate(
+        logits,
+        size=image.size[::-1],
+        mode="bilinear",
+        align_corners=False,
+    )
+
+    pred_seg = upsampled_logits.argmax(dim=1)[0].numpy()
 
     categories = segformer_model.config.id2label
     segmented_items = []
 
     for category_id, category_name in categories.items():
-        if category_id in
-            mask = (
+        if category_id in pred_seg:
+            mask = (pred_seg == category_id).astype(np.uint8)
             contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
             if contours:
                 x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
@@ -109,7 +118,7 @@ def segment_clothing(image):
                     'mask': mask
                 })
 
-    return segmented_items
+    return segmented_items, pred_seg
 
 def crop_image(image, bbox):
     return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
@@ -135,7 +144,7 @@ if st.session_state.step == 'input':
         query_image = load_image_from_url(st.session_state.query_image_url)
         if query_image is not None:
             st.session_state.query_image = query_image
-            st.session_state.segmentations = segment_clothing(query_image)
+            st.session_state.segmentations, st.session_state.pred_seg = segment_clothing(query_image)
             if st.session_state.segmentations:
                 st.session_state.step = 'select_category'
             else:
@@ -146,7 +155,12 @@ if st.session_state.step == 'input':
         st.warning("Please enter an image URL.")
 
 elif st.session_state.step == 'select_category':
-
+    col1, col2 = st.columns(2)
+    with col1:
+        st.image(st.session_state.query_image, caption="Original Image", use_column_width=True)
+    with col2:
+        st.image(st.session_state.pred_seg, caption="Segmentation Map", use_column_width=True)
+
     st.subheader("Segmented Clothing Items:")
 
     for segmentation in st.session_state.segmentations:
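The change follows the standard SegFormer post-processing recipe: the model emits logits at reduced resolution, so they are bilinearly upsampled to the input size before a per-pixel argmax produces the class map. Below is a minimal standalone sketch of that pipeline outside Streamlit, using the same mattmdjaga/segformer_b2_clothes checkpoint and the same calls as the diff; the input file "photo.jpg" is a placeholder.

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation

# Placeholder input image; any RGB photo works.
image = Image.open("photo.jpg").convert("RGB")

processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# SegFormer logits come out smaller than the input; upsample them back to
# the image size. PIL's size is (width, height) while interpolate expects
# (height, width), hence the [::-1].
upsampled = nn.functional.interpolate(
    outputs.logits.cpu(),
    size=image.size[::-1],
    mode="bilinear",
    align_corners=False,
)
pred_seg = upsampled.argmax(dim=1)[0].numpy()  # (H, W) array of class ids

# One binary mask per clothing category actually present in the image,
# mirroring the loop in segment_clothing().
for category_id, category_name in model.config.id2label.items():
    if category_id in pred_seg:
        mask = (pred_seg == category_id).astype(np.uint8)
        print(category_name, int(mask.sum()), "pixels")

Returning pred_seg alongside the per-item list is what lets the new select_category step render the segmentation map next to the original image in the second st.columns() column.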
|