Spaces:
Running
Running
Upload 14 files
Browse files- .gitattributes +1 -0
- Dockerfile +26 -0
- README.md +13 -13
- app.py +444 -0
- app_10281200.py +235 -0
- app_accessary.py +216 -0
- app_origin.py +0 -0
- clothesDB_11GmarketMusinsa/b4c365a0-3372-4031-9ff8-d457b4123d0a/data_level0.bin +3 -0
- clothesDB_11GmarketMusinsa/b4c365a0-3372-4031-9ff8-d457b4123d0a/header.bin +3 -0
- clothesDB_11GmarketMusinsa/b4c365a0-3372-4031-9ff8-d457b4123d0a/index_metadata.pickle +3 -0
- clothesDB_11GmarketMusinsa/b4c365a0-3372-4031-9ff8-d457b4123d0a/length.bin +3 -0
- clothesDB_11GmarketMusinsa/b4c365a0-3372-4031-9ff8-d457b4123d0a/link_lists.bin +3 -0
- clothesDB_11GmarketMusinsa/chroma.sqlite3 +3 -0
- db_creation.log +2 -0
- db_segmentation.py +221 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
clothesDB_11GmarketMusinsa/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Use Python 3.10
FROM python:3.10-slim

# Set the working directory in the container
WORKDIR /app

# Install system dependencies first so this layer stays cached across
# source-code changes (COPY . invalidates every later layer).
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    software-properties-common \
    git \
    && rm -rf /var/lib/apt/lists/*

# Copy the current directory contents into the container at /app
COPY . /app

# Upgrade pip and install required python packages
RUN pip install --no-cache-dir --upgrade pip
RUN pip install --no-cache-dir -r requirements.txt

# Make port 8501 available to the world outside this container
EXPOSE 8501

# Run app.py when the container launches
CMD ["streamlit", "run", "app.py"]
|
README.md
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
-
---
|
2 |
-
title: Itda
|
3 |
-
emoji: 🐨
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
-
sdk: streamlit
|
7 |
-
sdk_version: 1.
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
license: apache-2.0
|
11 |
-
---
|
12 |
-
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
+
---
|
2 |
+
title: Itda Nosegmentation
|
3 |
+
emoji: 🐨
|
4 |
+
colorFrom: gray
|
5 |
+
colorTo: purple
|
6 |
+
sdk: streamlit
|
7 |
+
sdk_version: 1.38.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: apache-2.0
|
11 |
+
---
|
12 |
+
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import open_clip
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
from transformers import pipeline
|
7 |
+
import chromadb
|
8 |
+
import logging
|
9 |
+
import io
|
10 |
+
import requests
|
11 |
+
from concurrent.futures import ThreadPoolExecutor
|
12 |
+
|
13 |
+
# 로깅 설정
|
14 |
+
logging.basicConfig(level=logging.INFO)
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
# Initialize session state
|
18 |
+
if 'image' not in st.session_state:
|
19 |
+
st.session_state.image = None
|
20 |
+
if 'detected_items' not in st.session_state:
|
21 |
+
st.session_state.detected_items = None
|
22 |
+
if 'selected_item_index' not in st.session_state:
|
23 |
+
st.session_state.selected_item_index = None
|
24 |
+
if 'upload_state' not in st.session_state:
|
25 |
+
st.session_state.upload_state = 'initial'
|
26 |
+
if 'search_clicked' not in st.session_state:
|
27 |
+
st.session_state.search_clicked = False
|
28 |
+
|
29 |
+
# Load models
|
30 |
+
@st.cache_resource
def load_models():
    """Load and cache the CLIP encoder, its preprocessing transform, and
    the clothes-segmentation pipeline.

    Returns:
        tuple: (clip_model, preprocess_val, segmenter, device)

    Raises:
        Exception: re-raised after logging when any model fails to load.
    """
    try:
        # Fashion-tuned SigLIP encoder plus its validation transform.
        clip, _, preprocess = open_clip.create_model_and_transforms(
            'hf-hub:Marqo/marqo-fashionSigLIP'
        )

        # SegFormer-based clothes segmentation pipeline.
        seg = pipeline(model="mattmdjaga/segformer_b2_clothes")

        dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        clip.to(dev)

        return clip, preprocess, seg, dev
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise
|
46 |
+
|
47 |
+
# 모델 로드
|
48 |
+
clip_model, preprocess_val, segmenter, device = load_models()
|
49 |
+
|
50 |
+
# ChromaDB 설정
|
51 |
+
client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
|
52 |
+
collection = client.get_collection(name="clothes")
|
53 |
+
|
54 |
+
def process_segmentation(image):
    """Run clothes segmentation and return a binary mask for the largest segment.

    Args:
        image: PIL.Image to segment.

    Returns:
        2-D numpy float array with values 0.0/1.0, or None when no segment
        is found or segmentation fails.
    """
    try:
        output = segmenter(image)

        if not output:
            logger.warning("No segments found in image")
            return None

        # Pixel area of each segment; keep only the largest one.
        segment_sizes = [np.sum(np.array(seg['mask'])) for seg in output]

        if not segment_sizes:
            return None

        largest_idx = int(np.argmax(segment_sizes))
        mask = output[largest_idx]['mask']

        # The pipeline may return a PIL image; coerce to numpy.
        if not isinstance(mask, np.ndarray):
            mask = np.array(mask)

        # Collapse multi-channel masks to a single channel.
        if len(mask.shape) > 2:
            mask = mask[:, :, 0]

        # BUG FIX: HF image-segmentation masks carry values 0/255, so the
        # old `mask.astype(float)` produced 0.0/255.0 and downstream
        # `image * mask` scaled pixels far out of uint8 range. Binarize
        # to 0.0/1.0 instead.
        mask = (mask > 0).astype(float)

        logger.info(f"Successfully created mask with shape {mask.shape}")
        return mask

    except Exception as e:
        logger.error(f"Segmentation error: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return None
|
93 |
+
|
94 |
+
def download_and_process_image(image_url, metadata_id):
    """Download image from URL and apply segmentation"""
    try:
        response = requests.get(image_url, timeout=10)  # bounded wait
        if response.status_code != 200:
            logger.error(f"Failed to download image {metadata_id}: HTTP {response.status_code}")
            return None

        image = Image.open(io.BytesIO(response.content)).convert('RGB')
        logger.info(f"Successfully downloaded image {metadata_id}")

        # Guard clause: bail out early when segmentation produces no mask.
        mask = process_segmentation(image)
        if mask is None:
            logger.warning(f"No valid mask found for image {metadata_id}")
            return None

        features = extract_features(image, mask)
        logger.info(f"Successfully extracted features for image {metadata_id}")
        return features

    except Exception as e:
        logger.error(f"Error processing image {metadata_id}: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return None
|
119 |
+
|
120 |
+
def update_db_with_segmentation():
    """Apply segmentation to every image in the DB and rebuild its features.

    Recreates the "clothes_segmented" collection from the metadata of the
    existing "clothes" collection: each image is re-downloaded, segmented,
    re-embedded, and stored. Progress is shown in the Streamlit UI.

    Returns:
        bool: True when at least one item was processed successfully.
    """
    try:
        logger.info("Starting database update with segmentation")

        # Recreate the target collection from scratch.
        # BUG FIX: was a bare `except:` (also caught SystemExit/KeyboardInterrupt).
        try:
            client.delete_collection("clothes_segmented")
            logger.info("Deleted existing segmented collection")
        except Exception:
            logger.info("No existing segmented collection to delete")

        new_collection = client.create_collection(
            name="clothes_segmented",
            metadata={"description": "Clothes collection with segmentation-based features"}
        )
        logger.info("Created new segmented collection")

        # Fetch only metadata from the source collection.
        try:
            all_items = collection.get(include=['metadatas'])
            total_items = len(all_items['metadatas'])
            logger.info(f"Found {total_items} items in database")
        except Exception as e:
            logger.error(f"Error getting items from collection: {str(e)}")
            # Fall back to an empty work list on failure.
            all_items = {'metadatas': []}
            total_items = 0

        # Progress reporting widgets.
        progress_bar = st.progress(0)
        status_text = st.empty()

        successful_updates = 0
        failed_updates = 0

        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = []
            # Only items that actually carry an image URL can be processed.
            valid_items = [m for m in all_items['metadatas'] if 'image_url' in m]

            for metadata in valid_items:
                future = executor.submit(
                    download_and_process_image,
                    metadata['image_url'],
                    metadata.get('id', 'unknown')
                )
                futures.append((metadata, future))

            # Collect results and store them in the new collection.
            for idx, (metadata, future) in enumerate(futures):
                try:
                    new_features = future.result()
                    if new_features is not None:
                        item_id = metadata.get('id', str(hash(metadata['image_url'])))
                        try:
                            new_collection.add(
                                embeddings=[new_features.tolist()],
                                metadatas=[metadata],
                                ids=[item_id]
                            )
                            successful_updates += 1
                            logger.info(f"Successfully added item {item_id}")
                        except Exception as e:
                            logger.error(f"Error adding item to new collection: {str(e)}")
                            failed_updates += 1
                    else:
                        failed_updates += 1

                    # Update UI progress.
                    progress = (idx + 1) / len(futures)
                    progress_bar.progress(progress)
                    status_text.text(f"Processing: {idx + 1}/{len(futures)} items. Success: {successful_updates}, Failed: {failed_updates}")

                except Exception as e:
                    logger.error(f"Error processing item: {str(e)}")
                    failed_updates += 1
                    continue

        # Final summary.
        status_text.text(f"Update completed. Successfully processed: {successful_updates}, Failed: {failed_updates}")
        logger.info(f"Database update completed. Successful: {successful_updates}, Failed: {failed_updates}")

        if successful_updates > 0:
            return True
        logger.error("No items were successfully processed")
        return False

    except Exception as e:
        logger.error(f"Database update error: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return False
|
215 |
+
|
216 |
+
def extract_features(image, mask=None):
    """Extract an L2-normalized CLIP embedding, optionally restricted to a mask.

    Args:
        image: PIL.Image.
        mask: optional 2-D array where non-zero pixels mark the garment.
            Both 0/1 and 0/255 masks are accepted.

    Returns:
        1-D numpy float array with unit norm.

    Raises:
        Exception: re-raised after logging on any failure.
    """
    try:
        if mask is not None:
            img_array = np.array(image)
            # ROBUSTNESS: binarize so a 0/255 mask (as emitted by the HF
            # segmentation pipeline) doesn't scale pixel values out of range.
            mask = (np.asarray(mask) > 0).astype(float)
            mask = np.expand_dims(mask, axis=2)
            masked_img = img_array * mask
            masked_img[mask[:, :, 0] == 0] = 255  # white background
            image = Image.fromarray(masked_img.astype(np.uint8))

        image_tensor = preprocess_val(image).unsqueeze(0).to(device)
        with torch.no_grad():
            features = clip_model.encode_image(image_tensor)
            features /= features.norm(dim=-1, keepdim=True)
            return features.cpu().numpy().flatten()
    except Exception as e:
        logger.error(f"Feature extraction error: {e}")
        raise
|
234 |
+
|
235 |
+
def search_similar_items(features, top_k=10):
    """Query ChromaDB for the items nearest to `features`.

    Prefers the segmentation-based collection when it exists; falls back
    to the original collection otherwise.

    Args:
        features: 1-D numpy embedding.
        top_k: number of results to request.

    Returns:
        list of metadata dicts, each with an added 'similarity_score' in
        (0, 1], sorted descending; empty list on failure.
    """
    try:
        # BUG FIX: was a bare `except:` — narrow to Exception so
        # SystemExit/KeyboardInterrupt still propagate.
        try:
            search_collection = client.get_collection("clothes_segmented")
            logger.info("Using segmented collection for search")
        except Exception:
            # No segmented collection yet — use the original one.
            search_collection = collection
            logger.info("Using original collection for search")

        results = search_collection.query(
            query_embeddings=[features.tolist()],
            n_results=top_k,
            include=['metadatas', 'distances']
        )

        if not results or not results['metadatas'] or not results['distances']:
            logger.warning("No results returned from ChromaDB")
            return []

        similar_items = []
        for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
            try:
                # Map distance [0, inf) to a similarity score in (0, 1].
                similarity_score = 1 / (1 + float(distance))
                item_data = metadata.copy()
                item_data['similarity_score'] = similarity_score
                similar_items.append(item_data)
            except Exception as e:
                logger.error(f"Error processing search result: {str(e)}")
                continue

        similar_items.sort(key=lambda x: x['similarity_score'], reverse=True)
        return similar_items
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        return []
|
273 |
+
|
274 |
+
def show_similar_items(similar_items):
    """Display similar items in a structured format with similarity scores"""
    if not similar_items:
        st.warning("No similar items found.")
        return

    st.subheader("Similar Items:")

    # Render the results two per row.
    items_per_row = 2
    rows = [similar_items[i:i + items_per_row]
            for i in range(0, len(similar_items), items_per_row)]

    for row in rows:
        cols = st.columns(items_per_row)
        for col, item in zip(cols, row):
            with col:
                try:
                    if 'image_url' in item:
                        st.image(item['image_url'], use_column_width=True)

                    # Similarity as a percentage.
                    similarity_percent = item['similarity_score'] * 100
                    st.markdown(f"**Similarity: {similarity_percent:.1f}%**")

                    st.write(f"Brand: {item.get('brand', 'Unknown')}")
                    name = item.get('name', 'Unknown')
                    if len(name) > 50:  # truncate long names
                        name = name[:47] + "..."
                    st.write(f"Name: {name}")

                    # Price: thousands separator only for numeric values.
                    price = item.get('price', 0)
                    if isinstance(price, (int, float)):
                        st.write(f"Price: {price:,}원")
                    else:
                        st.write(f"Price: {price}")

                    # Optional discount information.
                    if 'discount' in item and item['discount']:
                        st.write(f"Discount: {item['discount']}%")
                        if 'original_price' in item:
                            st.write(f"Original: {item['original_price']:,}원")

                    st.divider()  # separator between items

                except Exception as e:
                    logger.error(f"Error displaying item: {e}")
                    st.error("Error displaying this item")
|
322 |
+
|
323 |
+
def process_search(image, mask, num_results):
    """Embed the masked image and return its most similar catalog items.

    Returns the match list, or None when anything fails.
    """
    try:
        with st.spinner("Extracting features..."):
            query_features = extract_features(image, mask)

        with st.spinner("Finding similar items..."):
            matches = search_similar_items(query_features, top_k=num_results)

        return matches
    except Exception as e:
        logger.error(f"Search processing error: {e}")
        return None
|
336 |
+
|
337 |
+
# Callback functions
|
338 |
+
def handle_file_upload():
    """File-uploader callback: stash the uploaded image in session state."""
    uploaded = st.session_state.uploaded_file
    if uploaded is None:
        return
    st.session_state.image = Image.open(uploaded).convert('RGB')
    st.session_state.upload_state = 'image_uploaded'
    st.rerun()
|
344 |
+
|
345 |
+
def handle_detection():
    """Detect-button callback: segment the uploaded image and advance state.

    NOTE(review): in this file process_segmentation() returns a single 2-D
    mask array (or None), but main() iterates st.session_state.detected_items
    expecting a list of {'mask', 'label', 'score'} dicts — confirm which
    contract is intended; as written the detected-items UI cannot consume
    this value.
    """
    if st.session_state.image is not None:
        detected_items = process_segmentation(st.session_state.image)
        st.session_state.detected_items = detected_items
        st.session_state.upload_state = 'items_detected'
        st.rerun()
|
351 |
+
|
352 |
+
def handle_search():
    """Search-button callback: record that the user requested a search."""
    st.session_state.search_clicked = True
|
354 |
+
|
355 |
+
def admin_interface():
    """Render the admin panel that rebuilds the segmented DB on demand."""
    st.title("Admin Interface - DB Update")
    if not st.button("Update DB with Segmentation"):
        return
    with st.spinner("Updating database with segmentation... This may take a while..."):
        ok = update_db_with_segmentation()
        if ok:
            st.success("Database successfully updated with segmentation-based features!")
        else:
            st.error("Failed to update database. Please check the logs.")
|
364 |
+
|
365 |
+
def main():
    """Streamlit entry point: upload → detect → select item → search flow."""
    st.title("Fashion Search App")

    # Admin controls in sidebar
    st.sidebar.title("Admin Controls")
    if st.sidebar.checkbox("Show Admin Interface"):
        admin_interface()
        st.divider()

    # File uploader (shown only while upload_state is 'initial').
    if st.session_state.upload_state == 'initial':
        uploaded_file = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
                                         key='uploaded_file', on_change=handle_file_upload)

    # An image has been uploaded.
    if st.session_state.image is not None:
        st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)

        if st.session_state.detected_items is None:
            # The on_click callback does the work; the button's return is unused.
            if st.button("Detect Items", key='detect_button', on_click=handle_detection):
                pass

        # Show detected items.
        # NOTE(review): this loop expects a list of {'mask','label','score'}
        # dicts, but process_segmentation() in this file returns a single
        # 2-D mask array — confirm the intended contract before relying on
        # this UI path.
        if st.session_state.detected_items:
            # Two-column grid of detected items.
            cols = st.columns(2)
            for idx, item in enumerate(st.session_state.detected_items):
                with cols[idx % 2]:
                    mask = item['mask']
                    masked_img = np.array(st.session_state.image) * np.expand_dims(mask, axis=2)
                    st.image(masked_img.astype(np.uint8), caption=f"Detected {item['label']}")
                    st.write(f"Item {idx + 1}: {item['label']}")
                    st.write(f"Confidence: {item['score']*100:.1f}%")

            # Item selection.
            selected_idx = st.selectbox(
                "Select item to search:",
                range(len(st.session_state.detected_items)),
                format_func=lambda i: f"{st.session_state.detected_items[i]['label']}",
                key='item_selector'
            )

            # Search controls.
            search_col1, search_col2 = st.columns([1, 2])
            with search_col1:
                search_clicked = st.button("Search Similar Items",
                                           key='search_button',
                                           type="primary")
            with search_col2:
                num_results = st.slider("Number of results:",
                                        min_value=1,
                                        max_value=20,
                                        value=5,
                                        key='num_results')

            # Run / display the search.
            if search_clicked or st.session_state.get('search_clicked', False):
                st.session_state.search_clicked = True
                selected_mask = st.session_state.detected_items[selected_idx]['mask']

                # Cache the results in session state.
                # NOTE(review): results are computed only once; changing the
                # slider or the selected item does not refresh them until a
                # new search resets state — confirm this is intended.
                if 'search_results' not in st.session_state:
                    similar_items = process_search(st.session_state.image, selected_mask, num_results)
                    st.session_state.search_results = similar_items

                # Display the cached results.
                if st.session_state.search_results:
                    show_similar_items(st.session_state.search_results)
                else:
                    st.warning("No similar items found.")

                # New-search button: wipe all state and start over.
                if st.button("Start New Search", key='new_search'):
                    # Reset all session state.
                    for key in list(st.session_state.keys()):
                        del st.session_state[key]
                    st.rerun()
|
442 |
+
|
443 |
+
if __name__ == "__main__":
|
444 |
+
main()
|
app_10281200.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import open_clip
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
from transformers import pipeline
|
7 |
+
import chromadb
|
8 |
+
import logging
|
9 |
+
|
10 |
+
# 로깅 설정
|
11 |
+
logging.basicConfig(level=logging.INFO)
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
+
|
14 |
+
# Initialize session state
|
15 |
+
if 'image' not in st.session_state:
|
16 |
+
st.session_state.image = None
|
17 |
+
if 'detected_items' not in st.session_state:
|
18 |
+
st.session_state.detected_items = None
|
19 |
+
if 'selected_item_index' not in st.session_state:
|
20 |
+
st.session_state.selected_item_index = None
|
21 |
+
if 'upload_state' not in st.session_state:
|
22 |
+
st.session_state.upload_state = 'initial'
|
23 |
+
|
24 |
+
# Load models 안녕
|
25 |
+
@st.cache_resource
|
26 |
+
def load_models():
|
27 |
+
try:
|
28 |
+
# CLIP 모델
|
29 |
+
model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
|
30 |
+
|
31 |
+
# 세그멘테이션 모델
|
32 |
+
segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
|
33 |
+
|
34 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
35 |
+
model.to(device)
|
36 |
+
|
37 |
+
return model, preprocess_val, segmenter, device
|
38 |
+
except Exception as e:
|
39 |
+
logger.error(f"Error loading models: {e}")
|
40 |
+
raise
|
41 |
+
|
42 |
+
# 모델 로드
|
43 |
+
clip_model, preprocess_val, segmenter, device = load_models()
|
44 |
+
|
45 |
+
# ChromaDB 설정
|
46 |
+
client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
|
47 |
+
collection = client.get_collection(name="clothes")
|
48 |
+
|
49 |
+
def process_segmentation(image):
    """Segment the image and return one entry per detected garment.

    Returns:
        list of dicts: {'score': float in [0, 1] (fraction of the image the
        mask covers), 'label': str, 'mask': 2-D numpy array}. Empty list
        on failure.
    """
    try:
        segments = segmenter(image)
        valid_items = []

        for s in segments:
            mask_array = np.array(s['mask'])
            # BUG FIX: pipeline masks are 0/255 images, so np.mean of the
            # raw array is up to 255 and the "Confidence: score*100%"
            # display was meaningless. Use the covered-pixel fraction
            # (0..1) instead. NOTE(review): this is coverage, not a model
            # confidence — the pipeline output shown here carries no score.
            confidence = float(np.mean(mask_array > 0))

            valid_items.append({
                'score': confidence,
                'label': s['label'],
                'mask': mask_array
            })

        return valid_items
    except Exception as e:
        logger.error(f"Segmentation error: {e}")
        return []
|
69 |
+
|
70 |
+
def extract_features(image, mask=None):
|
71 |
+
"""Extract CLIP features"""
|
72 |
+
try:
|
73 |
+
if mask is not None:
|
74 |
+
img_array = np.array(image)
|
75 |
+
mask = np.expand_dims(mask, axis=2)
|
76 |
+
masked_img = img_array * mask
|
77 |
+
masked_img[mask[:,:,0] == 0] = 255
|
78 |
+
image = Image.fromarray(masked_img.astype(np.uint8))
|
79 |
+
|
80 |
+
image_tensor = preprocess_val(image).unsqueeze(0).to(device)
|
81 |
+
with torch.no_grad():
|
82 |
+
features = clip_model.encode_image(image_tensor)
|
83 |
+
features /= features.norm(dim=-1, keepdim=True)
|
84 |
+
return features.cpu().numpy().flatten()
|
85 |
+
except Exception as e:
|
86 |
+
logger.error(f"Feature extraction error: {e}")
|
87 |
+
raise
|
88 |
+
|
89 |
+
def search_similar_items(features, top_k=10):
    """Search similar items with distance scores"""
    try:
        query_result = collection.query(
            query_embeddings=[features.tolist()],
            n_results=top_k,
            include=['metadatas', 'distances']  # distances needed for scoring
        )

        matches = []
        pairs = zip(query_result['metadatas'][0], query_result['distances'][0])
        for metadata, distance in pairs:
            # Convert distance to a similarity score in (0, 1].
            metadata['similarity_score'] = 1 / (1 + distance)
            matches.append(metadata)

        return matches
    except Exception as e:
        logger.error(f"Search error: {e}")
        return []
|
109 |
+
|
110 |
+
def show_similar_items(similar_items):
    """Display similar items in a structured format with similarity scores"""
    st.subheader("Similar Items:")
    for item in similar_items:
        col1, col2 = st.columns([1, 2])
        with col1:
            st.image(item['image_url'])
        with col2:
            # Similarity as a percentage.
            similarity_percent = item['similarity_score'] * 100
            st.write(f"Similarity: {similarity_percent:.1f}%")
            st.write(f"Brand: {item.get('brand', 'Unknown')}")
            st.write(f"Name: {item.get('name', 'Unknown')}")
            # BUG FIX: f"{'Unknown':,}" raises ValueError — apply the
            # thousands separator only to numeric prices.
            price = item.get('price', 'Unknown')
            if isinstance(price, (int, float)):
                st.write(f"Price: {price:,}원")
            else:
                st.write(f"Price: {price}")
            if 'discount' in item:
                st.write(f"Discount: {item['discount']}%")
            if 'original_price' in item:
                st.write(f"Original Price: {item['original_price']:,}원")
|
128 |
+
|
129 |
+
# Initialize session state
|
130 |
+
if 'image' not in st.session_state:
|
131 |
+
st.session_state.image = None
|
132 |
+
if 'detected_items' not in st.session_state:
|
133 |
+
st.session_state.detected_items = None
|
134 |
+
if 'selected_item_index' not in st.session_state:
|
135 |
+
st.session_state.selected_item_index = None
|
136 |
+
if 'upload_state' not in st.session_state:
|
137 |
+
st.session_state.upload_state = 'initial'
|
138 |
+
if 'search_clicked' not in st.session_state:
|
139 |
+
st.session_state.search_clicked = False
|
140 |
+
|
141 |
+
def reset_state():
|
142 |
+
"""Reset all session state variables"""
|
143 |
+
for key in list(st.session_state.keys()):
|
144 |
+
del st.session_state[key]
|
145 |
+
|
146 |
+
# Callback functions
|
147 |
+
def handle_file_upload():
|
148 |
+
if st.session_state.uploaded_file is not None:
|
149 |
+
image = Image.open(st.session_state.uploaded_file).convert('RGB')
|
150 |
+
st.session_state.image = image
|
151 |
+
st.session_state.upload_state = 'image_uploaded'
|
152 |
+
st.rerun()
|
153 |
+
|
154 |
+
def handle_detection():
|
155 |
+
if st.session_state.image is not None:
|
156 |
+
detected_items = process_segmentation(st.session_state.image)
|
157 |
+
st.session_state.detected_items = detected_items
|
158 |
+
st.session_state.upload_state = 'items_detected'
|
159 |
+
st.rerun()
|
160 |
+
|
161 |
+
def handle_search():
|
162 |
+
st.session_state.search_clicked = True
|
163 |
+
|
164 |
+
def main():
|
165 |
+
st.title("포어블랙 fashion demo!!!")
|
166 |
+
|
167 |
+
# 파일 업로더 (upload_state가 initial일 때만 표시)
|
168 |
+
if st.session_state.upload_state == 'initial':
|
169 |
+
uploaded_file = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
|
170 |
+
key='uploaded_file', on_change=handle_file_upload)
|
171 |
+
|
172 |
+
# 이미지가 업로드된 상태 df
|
173 |
+
if st.session_state.image is not None:
|
174 |
+
st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)
|
175 |
+
|
176 |
+
if st.session_state.detected_items is None:
|
177 |
+
if st.button("Detect Items", key='detect_button', on_click=handle_detection):
|
178 |
+
pass
|
179 |
+
|
180 |
+
# 검출된 아이템 표시d
|
181 |
+
if st.session_state.detected_items:
|
182 |
+
# 감지된 아이템들d을 2열로 표시
|
183 |
+
cols = st.columns(2)
|
184 |
+
for idx, item in enumerate(st.session_state.detected_items):
|
185 |
+
with cols[idx % 2]:
|
186 |
+
mask = item['mask']
|
187 |
+
masked_img = np.array(st.session_state.image) * np.expand_dims(mask, axis=2)
|
188 |
+
st.image(masked_img.astype(np.uint8), caption=f"Detected {item['label']}")
|
189 |
+
st.write(f"Item {idx + 1}: {item['label']}")
|
190 |
+
st.write(f"Confidence: {item['score']*100:.1f}%")
|
191 |
+
|
192 |
+
# 아이템 선택
|
193 |
+
selected_idx = st.selectbox(
|
194 |
+
"Select item to search:",
|
195 |
+
range(len(st.session_state.detected_items)),
|
196 |
+
format_func=lambda i: f"{st.session_state.detected_items[i]['label']}",
|
197 |
+
key='item_selector'
|
198 |
+
)
|
199 |
+
st.session_state.selected_item_index = selected_idx
|
200 |
+
|
201 |
+
# 유사 아이템 검색
|
202 |
+
col1, col2 = st.columns([1, 2])
|
203 |
+
with col1:
|
204 |
+
search_button = st.button("Search Similar Items",
|
205 |
+
key='search_button',
|
206 |
+
on_click=handle_search,
|
207 |
+
type="primary") # 강조된 버튼
|
208 |
+
with col2:
|
209 |
+
num_results = st.slider("Number of results:",
|
210 |
+
min_value=1,
|
211 |
+
max_value=20,
|
212 |
+
value=5,
|
213 |
+
key='num_results')
|
214 |
+
|
215 |
+
if st.session_state.search_clicked:
|
216 |
+
with st.spinner("Searching similar items..."):
|
217 |
+
try:
|
218 |
+
selected_mask = st.session_state.detected_items[selected_idx]['mask']
|
219 |
+
features = extract_features(st.session_state.image, selected_mask)
|
220 |
+
similar_items = search_similar_items(features, top_k=num_results)
|
221 |
+
|
222 |
+
if similar_items:
|
223 |
+
show_similar_items(similar_items)
|
224 |
+
else:
|
225 |
+
st.warning("No similar items found.")
|
226 |
+
except Exception as e:
|
227 |
+
st.error(f"Error during search: {str(e)}")
|
228 |
+
|
229 |
+
# 새 검색 버튼
|
230 |
+
if st.button("Start New Search ", key='new_search'):
|
231 |
+
reset_state()
|
232 |
+
st.rerun()
|
233 |
+
|
234 |
+
if __name__ == "__main__":
|
235 |
+
main()
|
app_accessary.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import open_clip
|
3 |
+
import torch
|
4 |
+
import requests
|
5 |
+
from PIL import Image
|
6 |
+
from io import BytesIO
|
7 |
+
import time
|
8 |
+
import json
|
9 |
+
import numpy as np
|
10 |
+
import onnxruntime as ort
|
11 |
+
from ultralytics import YOLO
|
12 |
+
import cv2
|
13 |
+
import chromadb
|
14 |
+
|
15 |
+
@st.cache_resource
def load_clip_model():
    """Load the Marqo fashion-SigLIP CLIP model, its eval transform, and tokenizer.

    Cached by Streamlit so the model is instantiated only once per process.
    Returns (model, preprocess_val, tokenizer, device).
    """
    hub_id = 'hf-hub:Marqo/marqo-fashionSigLIP'
    model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(hub_id)
    tokenizer = open_clip.get_tokenizer(hub_id)

    # NOTE: loading a fine-tuned checkpoint (./accessory_clip.pt) was disabled
    # in the original code; the base hub weights are used as-is.

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    return model, preprocess_val, tokenizer, device

clip_model, preprocess_val, tokenizer, device = load_clip_model()
|
30 |
+
|
31 |
+
|
32 |
+
@st.cache_resource
def load_yolo_model():
    # Accessory detector: YOLO weights loaded from a local file.
    # Cached by Streamlit so the weights are read from disk only once.
    return YOLO("./accessaries.pt")

yolo_model = load_yolo_model()
|
37 |
+
|
38 |
+
# Load an image from a URL, retrying transient failures.
def load_image_from_url(url, max_retries=3):
    """Fetch *url* and return it as an RGB PIL image.

    Retries up to *max_retries* times (sleeping 1s between attempts) on
    network or decode errors; returns None when every attempt fails.
    """
    attempts_left = max_retries
    while attempts_left > 0:
        attempts_left -= 1
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert('RGB')
        except (requests.RequestException, Image.UnidentifiedImageError):
            if attempts_left > 0:
                time.sleep(1)
    return None
|
51 |
+
|
52 |
+
# ChromaDB client setup: persistent on-disk store holding the accessory
# item embeddings and their metadata.
client = chromadb.PersistentClient(path="./accessaryDB")
collection = client.get_collection(name="accessary_items_ver2")
|
55 |
+
|
56 |
+
def get_image_embedding(image):
    """Return the L2-normalised CLIP embedding of a PIL image as a numpy array."""
    tensor = preprocess_val(image).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = clip_model.encode_image(tensor)
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)
    return embedding.cpu().numpy()
|
62 |
+
|
63 |
+
def get_text_embedding(text):
    """Return the L2-normalised CLIP embedding of a text query as a numpy array."""
    tokens = tokenizer([text]).to(device)
    with torch.no_grad():
        embedding = clip_model.encode_text(tokens)
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)
    return embedding.cpu().numpy()
|
69 |
+
|
70 |
+
def get_all_embeddings_from_collection(collection):
    """Fetch every stored embedding from *collection* as a 2-D numpy array."""
    payload = collection.get(include=['embeddings'])
    return np.array(payload['embeddings'])
|
73 |
+
|
74 |
+
def get_metadata_from_ids(collection, ids):
    """Look up and return the metadata records for the given document *ids*."""
    return collection.get(ids=ids)['metadatas']
|
77 |
+
|
78 |
+
def find_similar_images(query_embedding, collection, top_k=5):
    """Return the *top_k* items most similar to *query_embedding*.

    Embeddings and metadata are fetched in a single collection.get() call so
    the two lists are guaranteed to be aligned (the original code fetched them
    in two separate calls). Embeddings are pre-normalised, so a dot product is
    the cosine similarity.

    Args:
        query_embedding: (1, dim) numpy array.
        collection: chromadb collection exposing get(include=...).
        top_k: number of results to return.

    Returns:
        List of {'info': metadata_dict, 'similarity': score}, best match first.
        Empty list when the collection holds no embeddings.
    """
    payload = collection.get(include=['embeddings', 'metadatas'])
    database_embeddings = np.array(payload['embeddings'])
    metadatas = payload['metadatas']

    if len(database_embeddings) == 0:
        return []  # nothing stored yet: nothing to rank

    # atleast_1d guards the single-item collection, where squeeze() would
    # produce a 0-d scalar that cannot be argsort-indexed.
    similarities = np.atleast_1d(np.dot(database_embeddings, query_embedding.T).squeeze())
    top_indices = np.argsort(similarities)[::-1][:top_k]

    return [{'info': metadatas[idx], 'similarity': similarities[idx]}
            for idx in top_indices]
|
94 |
+
|
95 |
+
def detect_clothing(image):
    """Run the YOLO accessory detector over *image*.

    Returns a list of {'category', 'bbox': [x1, y1, x2, y2], 'confidence'}
    for every detection whose class is one of the known accessory labels.
    """
    accessory_labels = {'Bracelets', 'Broches', 'bag', 'belt', 'earring',
                        'maangtika', 'necklace', 'nose ring', 'ring', 'tiara'}
    results = yolo_model(image)
    boxes = results[0].boxes.data.cpu().numpy()

    found = []
    for x1, y1, x2, y2, conf, cls in boxes:
        label = yolo_model.names[int(cls)]
        if label in accessory_labels:
            found.append({
                'category': label,
                'bbox': [int(x1), int(y1), int(x2), int(y2)],
                'confidence': conf,
            })
    return found
|
109 |
+
|
110 |
+
# Crop a detection's bounding box out of the full image.
def crop_image(image, bbox):
    """Return the region of *image* delimited by bbox = [x1, y1, x2, y2]."""
    left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]
    return image.crop((left, top, right, bottom))
|
113 |
+
|
114 |
+
# Session-state defaults. The app is a small state machine driven by
# st.session_state.step: 'input' -> 'select_category' -> 'show_results'.
if 'step' not in st.session_state:
    st.session_state.step = 'input'
if 'query_image_url' not in st.session_state:
    st.session_state.query_image_url = ''
if 'detections' not in st.session_state:
    st.session_state.detections = []
if 'selected_category' not in st.session_state:
    st.session_state.selected_category = None

# Streamlit app
st.title("Accessary Search App")

# Step-by-step flow.
if st.session_state.step == 'input':
    # Step 1: take an image URL and run accessory detection on it.
    st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
    if st.button("Detect accesseary"):
        if st.session_state.query_image_url:
            query_image = load_image_from_url(st.session_state.query_image_url)
            if query_image is not None:
                st.session_state.query_image = query_image
                st.session_state.detections = detect_clothing(query_image)
                if st.session_state.detections:
                    # Detections found: advance to category selection.
                    st.session_state.step = 'select_category'
                else:
                    st.warning("No items detected in the image.")
            else:
                st.error("Failed to load the image. Please try another URL.")
        else:
            st.warning("Please enter an image URL.")

elif st.session_state.step == 'select_category':
    # Step 2: show every detection with its crop and let the user pick one.
    st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
    st.subheader("Detected Clothing Items:")

    for detection in st.session_state.detections:
        col1, col2 = st.columns([1, 3])
        with col1:
            st.write(f"{detection['category']} (Confidence: {detection['confidence']:.2f})")
        with col2:
            cropped_image = crop_image(st.session_state.query_image, detection['bbox'])
            st.image(cropped_image, caption=detection['category'], use_column_width=True)

    # The selectbox label doubles as the lookup key used in 'show_results'.
    options = [f"{d['category']} (Confidence: {d['confidence']:.2f})" for d in st.session_state.detections]
    selected_option = st.selectbox("Select a category to search:", options)

    if st.button("Search Similar Items"):
        st.session_state.selected_category = selected_option
        st.session_state.step = 'show_results'

elif st.session_state.step == 'show_results':
    # Step 3: crop the chosen detection, embed it, and show nearest matches.
    st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
    # Recover the detection whose label string matches the saved selection.
    selected_detection = next(d for d in st.session_state.detections
                              if f"{d['category']} (Confidence: {d['confidence']:.2f})" == st.session_state.selected_category)
    cropped_image = crop_image(st.session_state.query_image, selected_detection['bbox'])
    st.image(cropped_image, caption="Cropped Image", use_column_width=True)
    query_embedding = get_image_embedding(cropped_image)
    similar_images = find_similar_images(query_embedding, collection)

    st.subheader("Similar Items:")
    for img in similar_images:
        col1, col2 = st.columns(2)
        with col1:
            st.image(img['info']['image_url'], use_column_width=True)
        with col2:
            st.write(f"Name: {img['info']['name']}")
            st.write(f"Brand: {img['info']['brand']}")
            category = img['info'].get('category')
            if category:
                st.write(f"Category: {category}")
            st.write(f"Price: {img['info']['price']}")
            st.write(f"Discount: {img['info']['discount']}%")
            st.write(f"Similarity: {img['similarity']:.2f}")

    # Reset all state and return to the input step.
    if st.button("Start New Search"):
        st.session_state.step = 'input'
        st.session_state.query_image_url = ''
        st.session_state.detections = []
        st.session_state.selected_category = None


# NOTE(review): this branch only runs when step is none of the three values
# above, which the code never sets — it looks like leftover text-search UI
# from an earlier version; confirm whether it is reachable / still wanted.
else:  # Text search
    query_text = st.text_input("Enter search text:")
    if st.button("Search by Text"):
        if query_text:
            text_embedding = get_text_embedding(query_text)
            similar_images = find_similar_images(text_embedding, collection)
            st.subheader("Similar Items:")
            for img in similar_images:
                col1, col2 = st.columns(2)
                with col1:
                    st.image(img['info']['image_url'], use_column_width=True)
                with col2:
                    st.write(f"Name: {img['info']['name']}")
                    st.write(f"Brand: {img['info']['brand']}")
                    category = img['info'].get('category')
                    if category:
                        st.write(f"Category: {category}")
                    st.write(f"Price: {img['info']['price']}")
                    st.write(f"Discount: {img['info']['discount']}%")
                    st.write(f"Similarity: {img['similarity']:.2f}")
        else:
            st.warning("Please enter a search text.")
|
app_origin.py
ADDED
File without changes
|
clothesDB_11GmarketMusinsa/b4c365a0-3372-4031-9ff8-d457b4123d0a/data_level0.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:71986228e55386105f7be1ee4986b077f01efe1d3dcfc901cc1e4a61138af1dc
|
3 |
+
size 3212000
|
clothesDB_11GmarketMusinsa/b4c365a0-3372-4031-9ff8-d457b4123d0a/header.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2cbe47b73ab24fff8fda6ff745cf0c3153d60098f2b086c4e2f11f4dd46c39a9
|
3 |
+
size 100
|
clothesDB_11GmarketMusinsa/b4c365a0-3372-4031-9ff8-d457b4123d0a/index_metadata.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ab67b060118c79acfd5c73f1014457c1edf4745681f9bed95f08a422bd820656
|
3 |
+
size 346027
|
clothesDB_11GmarketMusinsa/b4c365a0-3372-4031-9ff8-d457b4123d0a/length.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4333c839a3d07797dedf8ee38d3a56f0ad5e42e45babb1bf94d830617d043fcb
|
3 |
+
size 4000
|
clothesDB_11GmarketMusinsa/b4c365a0-3372-4031-9ff8-d457b4123d0a/link_lists.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0ba1d7a2bccf360384cbf836c3cd29510268a93a83e100577fca177d77433f64
|
3 |
+
size 8420
|
clothesDB_11GmarketMusinsa/chroma.sqlite3
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:509eb23fbf078e21b37f013125481a4bc07f9a34a90a69a2930bb9c0784aabba
|
3 |
+
size 17100800
|
db_creation.log
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
2024-10-28 13:02:14,164 - INFO - Loading models...
|
2 |
+
2024-10-28 13:02:14,166 - INFO - Loading models...
|
db_segmentation.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import chromadb
|
2 |
+
import logging
|
3 |
+
import open_clip
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
import numpy as np
|
7 |
+
from transformers import pipeline
|
8 |
+
import requests
|
9 |
+
import io
|
10 |
+
from concurrent.futures import ThreadPoolExecutor
|
11 |
+
from tqdm import tqdm
|
12 |
+
import os
|
13 |
+
|
14 |
+
# Logging setup: INFO level, mirrored to db_creation.log and the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('db_creation.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
def load_models():
    """Load the CLIP encoder and the clothes-segmentation pipeline.

    Returns (model, preprocess_val, segmenter, device). Any loading
    failure is logged and re-raised.
    """
    try:
        logger.info("Loading models...")
        # CLIP encoder (Marqo fashion SigLIP); only the eval transform is kept.
        model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')

        # Clothes segmentation pipeline.
        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")

        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        logger.info(f"Using device: {device}")
        model.to(device)

        return model, preprocess_val, segmenter, device
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise
|
43 |
+
|
44 |
+
def process_segmentation(image, segmenter):
    """Run the clothes segmenter on *image* and return the largest segment's mask.

    Returns a numpy array mask, or None when segmentation yields nothing
    or raises (the error is logged).
    """
    try:
        segments = segmenter(image)
        if not segments:
            return None

        # Keep only the segment covering the most pixels.
        biggest = max(segments, key=lambda seg: np.sum(seg['mask']))
        return np.array(biggest['mask'])

    except Exception as e:
        logger.error(f"Segmentation error: {e}")
        return None
|
60 |
+
|
61 |
+
def extract_features(image, mask, model, preprocess_val, device):
    """Encode *image* with the CLIP model, optionally masking out the background.

    When *mask* is given, pixels outside it are painted white before encoding.
    Returns a flat, L2-normalised numpy feature vector, or None on failure
    (the error is logged).
    """
    try:
        if mask is not None:
            pixels = np.array(image)
            mask3 = np.expand_dims(mask, axis=2)
            masked = pixels * mask3
            masked[mask3[:, :, 0] == 0] = 255  # paint background white
            image = Image.fromarray(masked.astype(np.uint8))

        tensor = preprocess_val(image).unsqueeze(0).to(device)
        with torch.no_grad():
            encoded = model.encode_image(tensor)
        encoded = encoded / encoded.norm(dim=-1, keepdim=True)
        return encoded.cpu().numpy().flatten()
    except Exception as e:
        logger.error(f"Feature extraction error: {e}")
        return None
|
79 |
+
|
80 |
+
def download_and_process_image(url, metadata_id, model, preprocess_val, segmenter, device):
    """Download one product image and return its segmentation-masked CLIP features.

    Returns None (after logging the reason) when the download fails with a
    non-200 status, no segmentation mask is found, or feature extraction fails.
    """
    try:
        response = requests.get(url, timeout=10)
        if response.status_code != 200:
            logger.error(f"Failed to download image {metadata_id}: HTTP {response.status_code}")
            return None

        image = Image.open(io.BytesIO(response.content)).convert('RGB')

        # Apply segmentation; bail out when no usable mask is produced.
        mask = process_segmentation(image, segmenter)
        if mask is None:
            logger.warning(f"No valid mask found for image {metadata_id}")
            return None

        # Extract CLIP features from the masked image.
        features = extract_features(image, mask, model, preprocess_val, device)
        if features is None:
            logger.warning(f"Failed to extract features for image {metadata_id}")
            return None

        return features

    except Exception as e:
        logger.error(f"Error processing image {metadata_id}: {e}")
        return None
|
107 |
+
|
108 |
+
def create_segmented_db(source_path, target_path, batch_size=100):
    """Create a new segmentation-based database from an existing one.

    Reads every item of the "clothes" collection at *source_path*, downloads
    each item's image, applies clothes segmentation, re-extracts CLIP features
    from the masked image, and writes the results into a fresh
    "clothes_segmented" collection at *target_path*.

    Args:
        source_path: path of the existing chromadb persistent store.
        target_path: path where the new store is created (directory is made
            if missing).
        batch_size: number of items processed and committed per batch.

    Returns:
        True on success, False when a top-level error aborts the run.
    """
    try:
        logger.info("Loading models...")
        model, preprocess_val, segmenter, device = load_models()

        # Connect to the source DB.
        source_client = chromadb.PersistentClient(path=source_path)
        source_collection = source_client.get_collection(name="clothes")

        # Create the target DB.
        os.makedirs(target_path, exist_ok=True)
        target_client = chromadb.PersistentClient(path=target_path)

        # Drop a stale collection from a previous run, if any.
        # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
        try:
            target_client.delete_collection("clothes_segmented")
        except Exception:
            pass  # collection did not exist yet — nothing to delete

        target_collection = target_client.create_collection(
            name="clothes_segmented",
            metadata={"description": "Clothes collection with segmentation-based features"}
        )

        # Total item count.
        all_items = source_collection.get(include=['metadatas'])
        total_items = len(all_items['metadatas'])
        logger.info(f"Found {total_items} items in source database")

        # Batch-processing counters.
        successful_updates = 0
        failed_updates = 0

        # Thread pool for concurrent image download + processing.
        max_workers = min(10, os.cpu_count() or 4)

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Walk the data in batches.
            for batch_start in tqdm(range(0, total_items, batch_size), desc="Processing batches"):
                batch_end = min(batch_start + batch_size, total_items)
                batch_items = all_items['metadatas'][batch_start:batch_end]

                # Submit one future per image in the batch.
                futures = []
                for metadata in batch_items:
                    if 'image_url' in metadata:
                        future = executor.submit(
                            download_and_process_image,
                            metadata['image_url'],
                            metadata.get('id', 'unknown'),
                            model, preprocess_val, segmenter, device
                        )
                        futures.append((metadata, future))

                # Collect batch results.
                batch_embeddings = []
                batch_metadatas = []
                batch_ids = []

                for metadata, future in futures:
                    try:
                        features = future.result()
                        if features is not None:
                            batch_embeddings.append(features.tolist())
                            batch_metadatas.append(metadata)
                            # Fall back to a hash of the URL when the item has no id.
                            batch_ids.append(metadata.get('id', str(hash(metadata['image_url']))))
                            successful_updates += 1
                        else:
                            failed_updates += 1

                    except Exception as e:
                        logger.error(f"Error processing batch item: {e}")
                        failed_updates += 1
                        continue

                # Persist the batch.
                if batch_embeddings:
                    try:
                        target_collection.add(
                            embeddings=batch_embeddings,
                            metadatas=batch_metadatas,
                            ids=batch_ids
                        )
                        logger.info(f"Added batch of {len(batch_embeddings)} items")
                    except Exception as e:
                        logger.error(f"Error adding batch to collection: {e}")
                        failed_updates += len(batch_embeddings)
                        successful_updates -= len(batch_embeddings)

        # Final summary.
        logger.info(f"Database creation completed.")
        logger.info(f"Successfully processed: {successful_updates}")
        logger.info(f"Failed: {failed_updates}")
        if total_items:  # guard against ZeroDivisionError on an empty source DB
            logger.info(f"Total completion rate: {(successful_updates/total_items)*100:.2f}%")

        return True

    except Exception as e:
        logger.error(f"Database creation error: {e}")
        return False
|
208 |
+
|
209 |
+
if __name__ == "__main__":
    # Configuration
    SOURCE_DB_PATH = "./clothesDB_11GmarketMusinsa"  # path of the original DB
    TARGET_DB_PATH = "./clothesDB_11GmarketMusinsa_segmented"  # path of the new DB
    BATCH_SIZE = 50  # number of items processed per batch

    # Run the DB creation.
    success = create_segmented_db(SOURCE_DB_PATH, TARGET_DB_PATH, BATCH_SIZE)

    if success:
        logger.info("Successfully created segmented database!")
    else:
        logger.error("Failed to create segmented database.")
|