leedoming committed
Commit 2dba380 · verified · 1 Parent(s): 10569af

Upload 14 files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ clothesDB_11GmarketMusinsa/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,26 @@
1
+ # Use Python 3.10
2
+ FROM python:3.10-slim
3
+
4
+ # Set the working directory in the container
5
+ WORKDIR /app
6
+
7
+ # Copy the current directory contents into the container at /app
8
+ COPY . /app
9
+
10
+ # Install system dependencies
11
+ RUN apt-get update && apt-get install -y \
12
+ build-essential \
13
+ curl \
14
+ software-properties-common \
15
+ git \
16
+ && rm -rf /var/lib/apt/lists/*
17
+
18
+ # Upgrade pip and install required python packages
19
+ RUN pip install --no-cache-dir --upgrade pip
20
+ RUN pip install --no-cache-dir -r requirements.txt
21
+
22
+ # Make port 8501 available to the world outside this container
23
+ EXPOSE 8501
24
+
25
+ # Run app.py when the container launches
26
+ CMD ["streamlit", "run", "app.py"]
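
Note: the Dockerfile installs dependencies from a requirements.txt that is not part of this upload, so its exact contents are an assumption. A minimal sketch consistent with the imports used by app.py, app_accessary.py, and db_segmentation.py below might look like:

# hypothetical requirements.txt — not included in this commit; package list inferred from the imports below
streamlit
torch
open_clip_torch
transformers
chromadb
pillow
numpy
requests
tqdm
ultralytics
onnxruntime
opencv-python

With such a file in place next to the Dockerfile, the image builds with a standard docker build and serves the Streamlit app on the exposed port 8501.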
README.md CHANGED
@@ -1,13 +1,13 @@
1
- ---
2
- title: Itda Segment
3
- emoji: 🐨
4
- colorFrom: purple
5
- colorTo: blue
6
- sdk: streamlit
7
- sdk_version: 1.39.0
8
- app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: Itda Nosegmentation
3
+ emoji: 🐨
4
+ colorFrom: gray
5
+ colorTo: purple
6
+ sdk: streamlit
7
+ sdk_version: 1.38.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,444 @@
1
+ import streamlit as st
2
+ import open_clip
3
+ import torch
4
+ from PIL import Image
5
+ import numpy as np
6
+ from transformers import pipeline
7
+ import chromadb
8
+ import logging
9
+ import io
10
+ import requests
11
+ from concurrent.futures import ThreadPoolExecutor
12
+
13
+ # Logging setup
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Initialize session state
18
+ if 'image' not in st.session_state:
19
+ st.session_state.image = None
20
+ if 'detected_items' not in st.session_state:
21
+ st.session_state.detected_items = None
22
+ if 'selected_item_index' not in st.session_state:
23
+ st.session_state.selected_item_index = None
24
+ if 'upload_state' not in st.session_state:
25
+ st.session_state.upload_state = 'initial'
26
+ if 'search_clicked' not in st.session_state:
27
+ st.session_state.search_clicked = False
28
+
29
+ # Load models
30
+ @st.cache_resource
31
+ def load_models():
32
+ try:
33
+ # CLIP model
34
+ model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
35
+
36
+ # Segmentation model
37
+ segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
38
+
39
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+ model.to(device)
41
+
42
+ return model, preprocess_val, segmenter, device
43
+ except Exception as e:
44
+ logger.error(f"Error loading models: {e}")
45
+ raise
46
+
47
+ # Load models
48
+ clip_model, preprocess_val, segmenter, device = load_models()
49
+
50
+ # ChromaDB setup
51
+ client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
52
+ collection = client.get_collection(name="clothes")
53
+
54
+ def process_segmentation(image):
55
+ """Segmentation processing"""
56
+ try:
57
+ # Process the pipeline output directly
58
+ output = segmenter(image)
59
+
60
+ if not output:
61
+ logger.warning("No segments found in image")
62
+ return None
63
+
64
+ # Compute the mask area of each segment
65
+ segment_sizes = [np.sum(seg['mask']) for seg in output]
66
+
67
+ if not segment_sizes:
68
+ return None
69
+
70
+ # Select the largest segment
71
+ largest_idx = np.argmax(segment_sizes)
72
+ mask = output[largest_idx]['mask']
73
+
74
+ # Convert the mask if it is not a numpy array
75
+ if not isinstance(mask, np.ndarray):
76
+ mask = np.array(mask)
77
+
78
+ # If the mask is not 2D, use the first channel
79
+ if len(mask.shape) > 2:
80
+ mask = mask[:, :, 0]
81
+
82
+ # Convert the boolean mask to float
83
+ mask = mask.astype(float)
84
+
85
+ logger.info(f"Successfully created mask with shape {mask.shape}")
86
+ return mask
87
+
88
+ except Exception as e:
89
+ logger.error(f"Segmentation error: {str(e)}")
90
+ import traceback
91
+ logger.error(traceback.format_exc())
92
+ return None
93
+
94
+ def download_and_process_image(image_url, metadata_id):
95
+ """Download image from URL and apply segmentation"""
96
+ try:
97
+ response = requests.get(image_url, timeout=10)  # timeout added
98
+ if response.status_code != 200:
99
+ logger.error(f"Failed to download image {metadata_id}: HTTP {response.status_code}")
100
+ return None
101
+
102
+ image = Image.open(io.BytesIO(response.content)).convert('RGB')
103
+ logger.info(f"Successfully downloaded image {metadata_id}")
104
+
105
+ mask = process_segmentation(image)
106
+ if mask is not None:
107
+ features = extract_features(image, mask)
108
+ logger.info(f"Successfully extracted features for image {metadata_id}")
109
+ return features
110
+
111
+ logger.warning(f"No valid mask found for image {metadata_id}")
112
+ return None
113
+
114
+ except Exception as e:
115
+ logger.error(f"Error processing image {metadata_id}: {str(e)}")
116
+ import traceback
117
+ logger.error(traceback.format_exc())
118
+ return None
119
+
120
+ def update_db_with_segmentation():
121
+ """Apply segmentation to every image in the DB and update its features"""
122
+ try:
123
+ logger.info("Starting database update with segmentation")
124
+
125
+ # Create a new collection
126
+ try:
127
+ client.delete_collection("clothes_segmented")
128
+ logger.info("Deleted existing segmented collection")
129
+ except:
130
+ logger.info("No existing segmented collection to delete")
131
+
132
+ new_collection = client.create_collection(
133
+ name="clothes_segmented",
134
+ metadata={"description": "Clothes collection with segmentation-based features"}
135
+ )
136
+ logger.info("Created new segmented collection")
137
+
138
+ # Fetch only the metadata from the existing collection
139
+ try:
140
+ all_items = collection.get(include=['metadatas'])
141
+ total_items = len(all_items['metadatas'])
142
+ logger.info(f"Found {total_items} items in database")
143
+ except Exception as e:
144
+ logger.error(f"Error getting items from collection: {str(e)}")
145
+ # Fall back to an empty list on error
146
+ all_items = {'metadatas': []}
147
+ total_items = 0
148
+
149
+ # Progress bar for status updates
150
+ progress_bar = st.progress(0)
151
+ status_text = st.empty()
152
+
153
+ successful_updates = 0
154
+ failed_updates = 0
155
+
156
+ with ThreadPoolExecutor(max_workers=4) as executor:
157
+ futures = []
158
+ # Only process items that have an image URL
159
+ valid_items = [m for m in all_items['metadatas'] if 'image_url' in m]
160
+
161
+ for metadata in valid_items:
162
+ future = executor.submit(
163
+ download_and_process_image,
164
+ metadata['image_url'],
165
+ metadata.get('id', 'unknown')
166
+ )
167
+ futures.append((metadata, future))
168
+
169
+ # Process results and store them in the new DB
170
+ for idx, (metadata, future) in enumerate(futures):
171
+ try:
172
+ new_features = future.result()
173
+ if new_features is not None:
174
+ item_id = metadata.get('id', str(hash(metadata['image_url'])))
175
+ try:
176
+ new_collection.add(
177
+ embeddings=[new_features.tolist()],
178
+ metadatas=[metadata],
179
+ ids=[item_id]
180
+ )
181
+ successful_updates += 1
182
+ logger.info(f"Successfully added item {item_id}")
183
+ except Exception as e:
184
+ logger.error(f"Error adding item to new collection: {str(e)}")
185
+ failed_updates += 1
186
+ else:
187
+ failed_updates += 1
188
+
189
+ # Update progress
190
+ progress = (idx + 1) / len(futures)
191
+ progress_bar.progress(progress)
192
+ status_text.text(f"Processing: {idx + 1}/{len(futures)} items. Success: {successful_updates}, Failed: {failed_updates}")
193
+
194
+ except Exception as e:
195
+ logger.error(f"Error processing item: {str(e)}")
196
+ failed_updates += 1
197
+ continue
198
+
199
+ # Show the final result
200
+ status_text.text(f"Update completed. Successfully processed: {successful_updates}, Failed: {failed_updates}")
201
+ logger.info(f"Database update completed. Successful: {successful_updates}, Failed: {failed_updates}")
202
+
203
+ # Check whether any items were processed successfully
204
+ if successful_updates > 0:
205
+ return True
206
+ else:
207
+ logger.error("No items were successfully processed")
208
+ return False
209
+
210
+ except Exception as e:
211
+ logger.error(f"Database update error: {str(e)}")
212
+ import traceback
213
+ logger.error(traceback.format_exc())
214
+ return False
215
+
216
+ def extract_features(image, mask=None):
217
+ """Extract CLIP features with segmentation mask"""
218
+ try:
219
+ if mask is not None:
220
+ img_array = np.array(image)
221
+ mask = np.expand_dims(mask, axis=2)
222
+ masked_img = img_array * mask
223
+ masked_img[mask[:,:,0] == 0] = 255  # set the background to white
224
+ image = Image.fromarray(masked_img.astype(np.uint8))
225
+
226
+ image_tensor = preprocess_val(image).unsqueeze(0).to(device)
227
+ with torch.no_grad():
228
+ features = clip_model.encode_image(image_tensor)
229
+ features /= features.norm(dim=-1, keepdim=True)
230
+ return features.cpu().numpy().flatten()
231
+ except Exception as e:
232
+ logger.error(f"Feature extraction error: {e}")
233
+ raise
234
+
235
+ def search_similar_items(features, top_k=10):
236
+ """Search similar items using segmentation-based features"""
237
+ try:
238
+ # Check whether a segmentation-based collection exists
239
+ try:
240
+ search_collection = client.get_collection("clothes_segmented")
241
+ logger.info("Using segmented collection for search")
242
+ except:
243
+ # Fall back to the original collection if it does not
244
+ search_collection = collection
245
+ logger.info("Using original collection for search")
246
+
247
+ results = search_collection.query(
248
+ query_embeddings=[features.tolist()],
249
+ n_results=top_k,
250
+ include=['metadatas', 'distances']
251
+ )
252
+
253
+ if not results or not results['metadatas'] or not results['distances']:
254
+ logger.warning("No results returned from ChromaDB")
255
+ return []
256
+
257
+ similar_items = []
258
+ for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
259
+ try:
260
+ similarity_score = 1 / (1 + float(distance))
261
+ item_data = metadata.copy()
262
+ item_data['similarity_score'] = similarity_score
263
+ similar_items.append(item_data)
264
+ except Exception as e:
265
+ logger.error(f"Error processing search result: {str(e)}")
266
+ continue
267
+
268
+ similar_items.sort(key=lambda x: x['similarity_score'], reverse=True)
269
+ return similar_items
270
+ except Exception as e:
271
+ logger.error(f"Search error: {str(e)}")
272
+ return []
273
+
274
+ def show_similar_items(similar_items):
275
+ """Display similar items in a structured format with similarity scores"""
276
+ if not similar_items:
277
+ st.warning("No similar items found.")
278
+ return
279
+
280
+ st.subheader("Similar Items:")
281
+
282
+ # Display results in two columns
283
+ items_per_row = 2
284
+ for i in range(0, len(similar_items), items_per_row):
285
+ cols = st.columns(items_per_row)
286
+ for j, col in enumerate(cols):
287
+ if i + j < len(similar_items):
288
+ item = similar_items[i + j]
289
+ with col:
290
+ try:
291
+ if 'image_url' in item:
292
+ st.image(item['image_url'], use_column_width=True)
293
+
294
+ # Show the similarity score as a percentage
295
+ similarity_percent = item['similarity_score'] * 100
296
+ st.markdown(f"**Similarity: {similarity_percent:.1f}%**")
297
+
298
+ st.write(f"Brand: {item.get('brand', 'Unknown')}")
299
+ name = item.get('name', 'Unknown')
300
+ if len(name) > 50:  # truncate long names
301
+ name = name[:47] + "..."
302
+ st.write(f"Name: {name}")
303
+
304
+ # Show price information
305
+ price = item.get('price', 0)
306
+ if isinstance(price, (int, float)):
307
+ st.write(f"Price: {price:,}원")
308
+ else:
309
+ st.write(f"Price: {price}")
310
+
311
+ # If discount information is available
312
+ if 'discount' in item and item['discount']:
313
+ st.write(f"Discount: {item['discount']}%")
314
+ if 'original_price' in item:
315
+ st.write(f"Original: {item['original_price']:,}원")
316
+
317
+ st.divider()  # add a divider
318
+
319
+ except Exception as e:
320
+ logger.error(f"Error displaying item: {e}")
321
+ st.error("Error displaying this item")
322
+
323
+ def process_search(image, mask, num_results):
324
+ """Handle the similar-item search"""
325
+ try:
326
+ with st.spinner("Extracting features..."):
327
+ features = extract_features(image, mask)
328
+
329
+ with st.spinner("Finding similar items..."):
330
+ similar_items = search_similar_items(features, top_k=num_results)
331
+
332
+ return similar_items
333
+ except Exception as e:
334
+ logger.error(f"Search processing error: {e}")
335
+ return None
336
+
337
+ # Callback functions
338
+ def handle_file_upload():
339
+ if st.session_state.uploaded_file is not None:
340
+ image = Image.open(st.session_state.uploaded_file).convert('RGB')
341
+ st.session_state.image = image
342
+ st.session_state.upload_state = 'image_uploaded'
343
+ st.rerun()
344
+
345
+ def handle_detection():
346
+ if st.session_state.image is not None:
347
+ detected_items = process_segmentation(st.session_state.image)
348
+ st.session_state.detected_items = detected_items
349
+ st.session_state.upload_state = 'items_detected'
350
+ st.rerun()
351
+
352
+ def handle_search():
353
+ st.session_state.search_clicked = True
354
+
355
+ def admin_interface():
356
+ st.title("Admin Interface - DB Update")
357
+ if st.button("Update DB with Segmentation"):
358
+ with st.spinner("Updating database with segmentation... This may take a while..."):
359
+ success = update_db_with_segmentation()
360
+ if success:
361
+ st.success("Database successfully updated with segmentation-based features!")
362
+ else:
363
+ st.error("Failed to update database. Please check the logs.")
364
+
365
+ def main():
366
+ st.title("Fashion Search App")
367
+
368
+ # Admin controls in sidebar
369
+ st.sidebar.title("Admin Controls")
370
+ if st.sidebar.checkbox("Show Admin Interface"):
371
+ admin_interface()
372
+ st.divider()
373
+
374
+ # File uploader (only shown while upload_state is 'initial')
375
+ if st.session_state.upload_state == 'initial':
376
+ uploaded_file = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
377
+ key='uploaded_file', on_change=handle_file_upload)
378
+
379
+ # An image has been uploaded
380
+ if st.session_state.image is not None:
381
+ st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)
382
+
383
+ if st.session_state.detected_items is None:
384
+ if st.button("Detect Items", key='detect_button', on_click=handle_detection):
385
+ pass
386
+
387
+ # Display detected items
388
+ if st.session_state.detected_items:
389
+ # Show detected items in two columns
390
+ cols = st.columns(2)
391
+ for idx, item in enumerate(st.session_state.detected_items):
392
+ with cols[idx % 2]:
393
+ mask = item['mask']
394
+ masked_img = np.array(st.session_state.image) * np.expand_dims(mask, axis=2)
395
+ st.image(masked_img.astype(np.uint8), caption=f"Detected {item['label']}")
396
+ st.write(f"Item {idx + 1}: {item['label']}")
397
+ st.write(f"Confidence: {item['score']*100:.1f}%")
398
+
399
+ # Item selection
400
+ selected_idx = st.selectbox(
401
+ "Select item to search:",
402
+ range(len(st.session_state.detected_items)),
403
+ format_func=lambda i: f"{st.session_state.detected_items[i]['label']}",
404
+ key='item_selector'
405
+ )
406
+
407
+ # Search controls
408
+ search_col1, search_col2 = st.columns([1, 2])
409
+ with search_col1:
410
+ search_clicked = st.button("Search Similar Items",
411
+ key='search_button',
412
+ type="primary")
413
+ with search_col2:
414
+ num_results = st.slider("Number of results:",
415
+ min_value=1,
416
+ max_value=20,
417
+ value=5,
418
+ key='num_results')
419
+
420
+ # Handle search results
421
+ if search_clicked or st.session_state.get('search_clicked', False):
422
+ st.session_state.search_clicked = True
423
+ selected_mask = st.session_state.detected_items[selected_idx]['mask']
424
+
425
+ # Cache search results in session state
426
+ if 'search_results' not in st.session_state:
427
+ similar_items = process_search(st.session_state.image, selected_mask, num_results)
428
+ st.session_state.search_results = similar_items
429
+
430
+ # Display the cached search results
431
+ if st.session_state.search_results:
432
+ show_similar_items(st.session_state.search_results)
433
+ else:
434
+ st.warning("No similar items found.")
435
+
436
+ # New search button
437
+ if st.button("Start New Search", key='new_search'):
438
+ # Reset all state
439
+ for key in list(st.session_state.keys()):
440
+ del st.session_state[key]
441
+ st.rerun()
442
+
443
+ if __name__ == "__main__":
444
+ main()
app_10281200.py ADDED
@@ -0,0 +1,235 @@
1
+ import streamlit as st
2
+ import open_clip
3
+ import torch
4
+ from PIL import Image
5
+ import numpy as np
6
+ from transformers import pipeline
7
+ import chromadb
8
+ import logging
9
+
10
+ # Logging setup
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+ # Initialize session state
15
+ if 'image' not in st.session_state:
16
+ st.session_state.image = None
17
+ if 'detected_items' not in st.session_state:
18
+ st.session_state.detected_items = None
19
+ if 'selected_item_index' not in st.session_state:
20
+ st.session_state.selected_item_index = None
21
+ if 'upload_state' not in st.session_state:
22
+ st.session_state.upload_state = 'initial'
23
+
24
+ # Load models
25
+ @st.cache_resource
26
+ def load_models():
27
+ try:
28
+ # CLIP model
29
+ model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
30
+
31
+ # Segmentation model
32
+ segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
33
+
34
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+ model.to(device)
36
+
37
+ return model, preprocess_val, segmenter, device
38
+ except Exception as e:
39
+ logger.error(f"Error loading models: {e}")
40
+ raise
41
+
42
+ # Load models
43
+ clip_model, preprocess_val, segmenter, device = load_models()
44
+
45
+ # ChromaDB setup
46
+ client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
47
+ collection = client.get_collection(name="clothes")
48
+
49
+ def process_segmentation(image):
50
+ """Segmentation processing"""
51
+ try:
52
+ segments = segmenter(image)
53
+ valid_items = []
54
+
55
+ for s in segments:
56
+ mask_array = np.array(s['mask'])
57
+ confidence = np.mean(mask_array)
58
+
59
+ valid_items.append({
60
+ 'score': confidence,
61
+ 'label': s['label'],
62
+ 'mask': mask_array
63
+ })
64
+
65
+ return valid_items
66
+ except Exception as e:
67
+ logger.error(f"Segmentation error: {e}")
68
+ return []
69
+
70
+ def extract_features(image, mask=None):
71
+ """Extract CLIP features"""
72
+ try:
73
+ if mask is not None:
74
+ img_array = np.array(image)
75
+ mask = np.expand_dims(mask, axis=2)
76
+ masked_img = img_array * mask
77
+ masked_img[mask[:,:,0] == 0] = 255
78
+ image = Image.fromarray(masked_img.astype(np.uint8))
79
+
80
+ image_tensor = preprocess_val(image).unsqueeze(0).to(device)
81
+ with torch.no_grad():
82
+ features = clip_model.encode_image(image_tensor)
83
+ features /= features.norm(dim=-1, keepdim=True)
84
+ return features.cpu().numpy().flatten()
85
+ except Exception as e:
86
+ logger.error(f"Feature extraction error: {e}")
87
+ raise
88
+
89
+ def search_similar_items(features, top_k=10):
90
+ """Search similar items with distance scores"""
91
+ try:
92
+ results = collection.query(
93
+ query_embeddings=[features.tolist()],
94
+ n_results=top_k,
95
+ include=['metadatas', 'distances']  # include distances
96
+ )
97
+
98
+ similar_items = []
99
+ for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
100
+ # Convert the distance to a similarity score (range 0-1)
101
+ similarity_score = 1 / (1 + distance)
102
+ metadata['similarity_score'] = similarity_score  # add the score to the metadata
103
+ similar_items.append(metadata)
104
+
105
+ return similar_items
106
+ except Exception as e:
107
+ logger.error(f"Search error: {e}")
108
+ return []
109
+
110
+ def show_similar_items(similar_items):
111
+ """Display similar items in a structured format with similarity scores"""
112
+ st.subheader("Similar Items:")
113
+ for item in similar_items:
114
+ col1, col2 = st.columns([1, 2])
115
+ with col1:
116
+ st.image(item['image_url'])
117
+ with col2:
118
+ # Show the similarity score as a percentage
119
+ similarity_percent = item['similarity_score'] * 100
120
+ st.write(f"Similarity: {similarity_percent:.1f}%")
121
+ st.write(f"Brand: {item.get('brand', 'Unknown')}")
122
+ st.write(f"Name: {item.get('name', 'Unknown')}")
123
+ st.write(f"Price: {item.get('price', 'Unknown'):,}원")
124
+ if 'discount' in item:
125
+ st.write(f"Discount: {item['discount']}%")
126
+ if 'original_price' in item:
127
+ st.write(f"Original Price: {item['original_price']:,}원")
128
+
129
+ # Initialize session state
130
+ if 'image' not in st.session_state:
131
+ st.session_state.image = None
132
+ if 'detected_items' not in st.session_state:
133
+ st.session_state.detected_items = None
134
+ if 'selected_item_index' not in st.session_state:
135
+ st.session_state.selected_item_index = None
136
+ if 'upload_state' not in st.session_state:
137
+ st.session_state.upload_state = 'initial'
138
+ if 'search_clicked' not in st.session_state:
139
+ st.session_state.search_clicked = False
140
+
141
+ def reset_state():
142
+ """Reset all session state variables"""
143
+ for key in list(st.session_state.keys()):
144
+ del st.session_state[key]
145
+
146
+ # Callback functions
147
+ def handle_file_upload():
148
+ if st.session_state.uploaded_file is not None:
149
+ image = Image.open(st.session_state.uploaded_file).convert('RGB')
150
+ st.session_state.image = image
151
+ st.session_state.upload_state = 'image_uploaded'
152
+ st.rerun()
153
+
154
+ def handle_detection():
155
+ if st.session_state.image is not None:
156
+ detected_items = process_segmentation(st.session_state.image)
157
+ st.session_state.detected_items = detected_items
158
+ st.session_state.upload_state = 'items_detected'
159
+ st.rerun()
160
+
161
+ def handle_search():
162
+ st.session_state.search_clicked = True
163
+
164
+ def main():
165
+ st.title("포어블랙 fashion demo!!!")
166
+
167
+ # File uploader (only shown while upload_state is 'initial')
168
+ if st.session_state.upload_state == 'initial':
169
+ uploaded_file = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
170
+ key='uploaded_file', on_change=handle_file_upload)
171
+
172
+ # An image has been uploaded
173
+ if st.session_state.image is not None:
174
+ st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)
175
+
176
+ if st.session_state.detected_items is None:
177
+ if st.button("Detect Items", key='detect_button', on_click=handle_detection):
178
+ pass
179
+
180
+ # Display detected items
181
+ if st.session_state.detected_items:
182
+ # Show detected items in two columns
183
+ cols = st.columns(2)
184
+ for idx, item in enumerate(st.session_state.detected_items):
185
+ with cols[idx % 2]:
186
+ mask = item['mask']
187
+ masked_img = np.array(st.session_state.image) * np.expand_dims(mask, axis=2)
188
+ st.image(masked_img.astype(np.uint8), caption=f"Detected {item['label']}")
189
+ st.write(f"Item {idx + 1}: {item['label']}")
190
+ st.write(f"Confidence: {item['score']*100:.1f}%")
191
+
192
+ # Item selection
193
+ selected_idx = st.selectbox(
194
+ "Select item to search:",
195
+ range(len(st.session_state.detected_items)),
196
+ format_func=lambda i: f"{st.session_state.detected_items[i]['label']}",
197
+ key='item_selector'
198
+ )
199
+ st.session_state.selected_item_index = selected_idx
200
+
201
+ # Search for similar items
202
+ col1, col2 = st.columns([1, 2])
203
+ with col1:
204
+ search_button = st.button("Search Similar Items",
205
+ key='search_button',
206
+ on_click=handle_search,
207
+ type="primary")  # highlighted button
208
+ with col2:
209
+ num_results = st.slider("Number of results:",
210
+ min_value=1,
211
+ max_value=20,
212
+ value=5,
213
+ key='num_results')
214
+
215
+ if st.session_state.search_clicked:
216
+ with st.spinner("Searching similar items..."):
217
+ try:
218
+ selected_mask = st.session_state.detected_items[selected_idx]['mask']
219
+ features = extract_features(st.session_state.image, selected_mask)
220
+ similar_items = search_similar_items(features, top_k=num_results)
221
+
222
+ if similar_items:
223
+ show_similar_items(similar_items)
224
+ else:
225
+ st.warning("No similar items found.")
226
+ except Exception as e:
227
+ st.error(f"Error during search: {str(e)}")
228
+
229
+ # New search button
230
+ if st.button("Start New Search ", key='new_search'):
231
+ reset_state()
232
+ st.rerun()
233
+
234
+ if __name__ == "__main__":
235
+ main()
app_accessary.py ADDED
@@ -0,0 +1,216 @@
1
+ import streamlit as st
2
+ import open_clip
3
+ import torch
4
+ import requests
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ import time
8
+ import json
9
+ import numpy as np
10
+ import onnxruntime as ort
11
+ from ultralytics import YOLO
12
+ import cv2
13
+ import chromadb
14
+
15
+ @st.cache_resource
16
+ def load_clip_model():
17
+ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
18
+ tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
19
+
20
+ # Load the fine-tuned model's state_dict
21
+ #state_dict = torch.load('./accessory_clip.pt', map_location=torch.device('cpu'))
22
+ #model.load_state_dict(state_dict)  # apply the state_dict to the model
23
+
24
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
+ model.to(device)
26
+
27
+ return model, preprocess_val, tokenizer, device
28
+
29
+ clip_model, preprocess_val, tokenizer, device = load_clip_model()
30
+
31
+
32
+ @st.cache_resource
33
+ def load_yolo_model():
34
+ return YOLO("./accessaries.pt")
35
+
36
+ yolo_model = load_yolo_model()
37
+
38
+ # URL에서 이미지 로드
39
+ def load_image_from_url(url, max_retries=3):
40
+ for attempt in range(max_retries):
41
+ try:
42
+ response = requests.get(url, timeout=10)
43
+ response.raise_for_status()
44
+ img = Image.open(BytesIO(response.content)).convert('RGB')
45
+ return img
46
+ except (requests.RequestException, Image.UnidentifiedImageError) as e:
47
+ if attempt < max_retries - 1:
48
+ time.sleep(1)
49
+ else:
50
+ return None
51
+
52
+ # ChromaDB 클라이언트 설정
53
+ client = chromadb.PersistentClient(path="./accessaryDB")
54
+ collection = client.get_collection(name="accessary_items_ver2")
55
+
56
+ def get_image_embedding(image):
57
+ image_tensor = preprocess_val(image).unsqueeze(0).to(device)
58
+ with torch.no_grad():
59
+ image_features = clip_model.encode_image(image_tensor)
60
+ image_features /= image_features.norm(dim=-1, keepdim=True)
61
+ return image_features.cpu().numpy()
62
+
63
+ def get_text_embedding(text):
64
+ text_tokens = tokenizer([text]).to(device)
65
+ with torch.no_grad():
66
+ text_features = clip_model.encode_text(text_tokens)
67
+ text_features /= text_features.norm(dim=-1, keepdim=True)
68
+ return text_features.cpu().numpy()
69
+
70
+ def get_all_embeddings_from_collection(collection):
71
+ all_embeddings = collection.get(include=['embeddings'])['embeddings']
72
+ return np.array(all_embeddings)
73
+
74
+ def get_metadata_from_ids(collection, ids):
75
+ results = collection.get(ids=ids)
76
+ return results['metadatas']
77
+
78
+ def find_similar_images(query_embedding, collection, top_k=5):
79
+ database_embeddings = get_all_embeddings_from_collection(collection)
80
+ similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
81
+ top_indices = np.argsort(similarities)[::-1][:top_k]
82
+
83
+ all_data = collection.get(include=['metadatas'])['metadatas']
84
+
85
+ top_metadatas = [all_data[idx] for idx in top_indices]
86
+
87
+ results = []
88
+ for idx, metadata in enumerate(top_metadatas):
89
+ results.append({
90
+ 'info': metadata,
91
+ 'similarity': similarities[top_indices[idx]]
92
+ })
93
+ return results
94
+
95
+ def detect_clothing(image):
96
+ results = yolo_model(image)
97
+ detections = results[0].boxes.data.cpu().numpy()
98
+ categories = []
99
+ for detection in detections:
100
+ x1, y1, x2, y2, conf, cls = detection
101
+ category = yolo_model.names[int(cls)]
102
+ if category in ['Bracelets', 'Broches', 'bag', 'belt', 'earring', 'maangtika', 'necklace', 'nose ring', 'ring', 'tiara']:
103
+ categories.append({
104
+ 'category': category,
105
+ 'bbox': [int(x1), int(y1), int(x2), int(y2)],
106
+ 'confidence': conf
107
+ })
108
+ return categories
109
+
110
+ # Crop the image
111
+ def crop_image(image, bbox):
112
+ return image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
113
+
114
+ # Initialize session state
115
+ if 'step' not in st.session_state:
116
+ st.session_state.step = 'input'
117
+ if 'query_image_url' not in st.session_state:
118
+ st.session_state.query_image_url = ''
119
+ if 'detections' not in st.session_state:
120
+ st.session_state.detections = []
121
+ if 'selected_category' not in st.session_state:
122
+ st.session_state.selected_category = None
123
+
124
+ # Streamlit app
125
+ st.title("Accessory Search App")
126
+
127
+ # Step-by-step handling
128
+ if st.session_state.step == 'input':
129
+ st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
130
+ if st.button("Detect accessory"):
131
+ if st.session_state.query_image_url:
132
+ query_image = load_image_from_url(st.session_state.query_image_url)
133
+ if query_image is not None:
134
+ st.session_state.query_image = query_image
135
+ st.session_state.detections = detect_clothing(query_image)
136
+ if st.session_state.detections:
137
+ st.session_state.step = 'select_category'
138
+ else:
139
+ st.warning("No items detected in the image.")
140
+ else:
141
+ st.error("Failed to load the image. Please try another URL.")
142
+ else:
143
+ st.warning("Please enter an image URL.")
144
+
145
+ elif st.session_state.step == 'select_category':
146
+ st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
147
+ st.subheader("Detected Clothing Items:")
148
+
149
+ for detection in st.session_state.detections:
150
+ col1, col2 = st.columns([1, 3])
151
+ with col1:
152
+ st.write(f"{detection['category']} (Confidence: {detection['confidence']:.2f})")
153
+ with col2:
154
+ cropped_image = crop_image(st.session_state.query_image, detection['bbox'])
155
+ st.image(cropped_image, caption=detection['category'], use_column_width=True)
156
+
157
+ options = [f"{d['category']} (Confidence: {d['confidence']:.2f})" for d in st.session_state.detections]
158
+ selected_option = st.selectbox("Select a category to search:", options)
159
+
160
+ if st.button("Search Similar Items"):
161
+ st.session_state.selected_category = selected_option
162
+ st.session_state.step = 'show_results'
163
+
164
+ elif st.session_state.step == 'show_results':
165
+ st.image(st.session_state.query_image, caption="Query Image", use_column_width=True)
166
+ selected_detection = next(d for d in st.session_state.detections
167
+ if f"{d['category']} (Confidence: {d['confidence']:.2f})" == st.session_state.selected_category)
168
+ cropped_image = crop_image(st.session_state.query_image, selected_detection['bbox'])
169
+ st.image(cropped_image, caption="Cropped Image", use_column_width=True)
170
+ query_embedding = get_image_embedding(cropped_image)
171
+ similar_images = find_similar_images(query_embedding, collection)
172
+
173
+ st.subheader("Similar Items:")
174
+ for img in similar_images:
175
+ col1, col2 = st.columns(2)
176
+ with col1:
177
+ st.image(img['info']['image_url'], use_column_width=True)
178
+ with col2:
179
+ st.write(f"Name: {img['info']['name']}")
180
+ st.write(f"Brand: {img['info']['brand']}")
181
+ category = img['info'].get('category')
182
+ if category:
183
+ st.write(f"Category: {category}")
184
+ st.write(f"Price: {img['info']['price']}")
185
+ st.write(f"Discount: {img['info']['discount']}%")
186
+ st.write(f"Similarity: {img['similarity']:.2f}")
187
+
188
+ if st.button("Start New Search"):
189
+ st.session_state.step = 'input'
190
+ st.session_state.query_image_url = ''
191
+ st.session_state.detections = []
192
+ st.session_state.selected_category = None
193
+
194
+
195
+ else: # Text search
196
+ query_text = st.text_input("Enter search text:")
197
+ if st.button("Search by Text"):
198
+ if query_text:
199
+ text_embedding = get_text_embedding(query_text)
200
+ similar_images = find_similar_images(text_embedding, collection)
201
+ st.subheader("Similar Items:")
202
+ for img in similar_images:
203
+ col1, col2 = st.columns(2)
204
+ with col1:
205
+ st.image(img['info']['image_url'], use_column_width=True)
206
+ with col2:
207
+ st.write(f"Name: {img['info']['name']}")
208
+ st.write(f"Brand: {img['info']['brand']}")
209
+ category = img['info'].get('category')
210
+ if category:
211
+ st.write(f"Category: {category}")
212
+ st.write(f"Price: {img['info']['price']}")
213
+ st.write(f"Discount: {img['info']['discount']}%")
214
+ st.write(f"Similarity: {img['similarity']:.2f}")
215
+ else:
216
+ st.warning("Please enter a search text.")
app_origin.py ADDED
File without changes
clothesDB_11GmarketMusinsa/b4c365a0-3372-4031-9ff8-d457b4123d0a/data_level0.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71986228e55386105f7be1ee4986b077f01efe1d3dcfc901cc1e4a61138af1dc
3
+ size 3212000
clothesDB_11GmarketMusinsa/b4c365a0-3372-4031-9ff8-d457b4123d0a/header.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2cbe47b73ab24fff8fda6ff745cf0c3153d60098f2b086c4e2f11f4dd46c39a9
3
+ size 100
clothesDB_11GmarketMusinsa/b4c365a0-3372-4031-9ff8-d457b4123d0a/index_metadata.pickle ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab67b060118c79acfd5c73f1014457c1edf4745681f9bed95f08a422bd820656
3
+ size 346027
clothesDB_11GmarketMusinsa/b4c365a0-3372-4031-9ff8-d457b4123d0a/length.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4333c839a3d07797dedf8ee38d3a56f0ad5e42e45babb1bf94d830617d043fcb
3
+ size 4000
clothesDB_11GmarketMusinsa/b4c365a0-3372-4031-9ff8-d457b4123d0a/link_lists.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0ba1d7a2bccf360384cbf836c3cd29510268a93a83e100577fca177d77433f64
3
+ size 8420
clothesDB_11GmarketMusinsa/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:509eb23fbf078e21b37f013125481a4bc07f9a34a90a69a2930bb9c0784aabba
3
+ size 17100800
db_creation.log ADDED
@@ -0,0 +1,2 @@
1
+ 2024-10-28 13:02:14,164 - INFO - Loading models...
2
+ 2024-10-28 13:02:14,166 - INFO - Loading models...
db_segmentation.py ADDED
@@ -0,0 +1,221 @@
1
+ import chromadb
2
+ import logging
3
+ import open_clip
4
+ import torch
5
+ from PIL import Image
6
+ import numpy as np
7
+ from transformers import pipeline
8
+ import requests
9
+ import io
10
+ from concurrent.futures import ThreadPoolExecutor
11
+ from tqdm import tqdm
12
+ import os
13
+
14
+ # Logging setup
15
+ logging.basicConfig(
16
+ level=logging.INFO,
17
+ format='%(asctime)s - %(levelname)s - %(message)s',
18
+ handlers=[
19
+ logging.FileHandler('db_creation.log'),
20
+ logging.StreamHandler()
21
+ ]
22
+ )
23
+ logger = logging.getLogger(__name__)
24
+
25
+ def load_models():
26
+ """Load CLIP and segmentation models"""
27
+ try:
28
+ logger.info("Loading models...")
29
+ # CLIP model
30
+ model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
31
+
32
+ # Segmentation model
33
+ segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
34
+
35
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+ logger.info(f"Using device: {device}")
37
+ model.to(device)
38
+
39
+ return model, preprocess_val, segmenter, device
40
+ except Exception as e:
41
+ logger.error(f"Error loading models: {e}")
42
+ raise
43
+
44
+ def process_segmentation(image, segmenter):
45
+ """Apply segmentation to image"""
46
+ try:
47
+ segments = segmenter(image)
48
+ if not segments:
49
+ return None
50
+
51
+ # Select the largest segment
52
+ largest_segment = max(segments, key=lambda s: np.sum(s['mask']))
53
+ mask = np.array(largest_segment['mask'])
54
+
55
+ return mask
56
+
57
+ except Exception as e:
58
+ logger.error(f"Segmentation error: {e}")
59
+ return None
60
+
61
+ def extract_features(image, mask, model, preprocess_val, device):
62
+ """Extract CLIP features with segmentation mask"""
63
+ try:
64
+ if mask is not None:
65
+ img_array = np.array(image)
66
+ mask = np.expand_dims(mask, axis=2)
67
+ masked_img = img_array * mask
68
+ masked_img[mask[:,:,0] == 0] = 255  # set the background to white
69
+ image = Image.fromarray(masked_img.astype(np.uint8))
70
+
71
+ image_tensor = preprocess_val(image).unsqueeze(0).to(device)
72
+ with torch.no_grad():
73
+ features = model.encode_image(image_tensor)
74
+ features /= features.norm(dim=-1, keepdim=True)
75
+ return features.cpu().numpy().flatten()
76
+ except Exception as e:
77
+ logger.error(f"Feature extraction error: {e}")
78
+ return None
79
+
80
+ def download_and_process_image(url, metadata_id, model, preprocess_val, segmenter, device):
81
+ """Download and process single image"""
82
+ try:
83
+ response = requests.get(url, timeout=10)
84
+ if response.status_code != 200:
85
+ logger.error(f"Failed to download image {metadata_id}: HTTP {response.status_code}")
86
+ return None
87
+
88
+ image = Image.open(io.BytesIO(response.content)).convert('RGB')
89
+
90
+ # Apply segmentation
91
+ mask = process_segmentation(image, segmenter)
92
+ if mask is None:
93
+ logger.warning(f"No valid mask found for image {metadata_id}")
94
+ return None
95
+
96
+ # Extract features
97
+ features = extract_features(image, mask, model, preprocess_val, device)
98
+ if features is None:
99
+ logger.warning(f"Failed to extract features for image {metadata_id}")
100
+ return None
101
+
102
+ return features
103
+
104
+ except Exception as e:
105
+ logger.error(f"Error processing image {metadata_id}: {e}")
106
+ return None
107
+
108
+ def create_segmented_db(source_path, target_path, batch_size=100):
109
+ """Create new segmented database from existing one"""
110
+ try:
111
+ logger.info("Loading models...")
112
+ model, preprocess_val, segmenter, device = load_models()
113
+
114
+ # Connect to the source DB
115
+ source_client = chromadb.PersistentClient(path=source_path)
116
+ source_collection = source_client.get_collection(name="clothes")
117
+
118
+ # Create the target DB
119
+ os.makedirs(target_path, exist_ok=True)
120
+ target_client = chromadb.PersistentClient(path=target_path)
121
+
122
+ try:
123
+ target_client.delete_collection("clothes_segmented")
124
+ except:
125
+ pass
126
+
127
+ target_collection = target_client.create_collection(
128
+ name="clothes_segmented",
129
+ metadata={"description": "Clothes collection with segmentation-based features"}
130
+ )
131
+
132
+ # Check the total number of items
133
+ all_items = source_collection.get(include=['metadatas'])
134
+ total_items = len(all_items['metadatas'])
135
+ logger.info(f"Found {total_items} items in source database")
136
+
137
+ # Prepare for batch processing
138
+ successful_updates = 0
139
+ failed_updates = 0
140
+
141
+ # ThreadPoolExecutor setup
142
+ max_workers = min(10, os.cpu_count() or 4)
143
+
144
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
145
+ # Process all data in batches
146
+ for batch_start in tqdm(range(0, total_items, batch_size), desc="Processing batches"):
147
+ batch_end = min(batch_start + batch_size, total_items)
148
+ batch_items = all_items['metadatas'][batch_start:batch_end]
149
+
150
+ # Create futures for all images in the batch
151
+ futures = []
152
+ for metadata in batch_items:
153
+ if 'image_url' in metadata:
154
+ future = executor.submit(
155
+ download_and_process_image,
156
+ metadata['image_url'],
157
+ metadata.get('id', 'unknown'),
158
+ model, preprocess_val, segmenter, device
159
+ )
160
+ futures.append((metadata, future))
161
+
162
+ # Process batch results
163
+ batch_embeddings = []
164
+ batch_metadatas = []
165
+ batch_ids = []
166
+
167
+ for metadata, future in futures:
168
+ try:
169
+ features = future.result()
170
+ if features is not None:
171
+ batch_embeddings.append(features.tolist())
172
+ batch_metadatas.append(metadata)
173
+ batch_ids.append(metadata.get('id', str(hash(metadata['image_url']))))
174
+ successful_updates += 1
175
+ else:
176
+ failed_updates += 1
177
+
178
+ except Exception as e:
179
+ logger.error(f"Error processing batch item: {e}")
180
+ failed_updates += 1
181
+ continue
182
+
183
+ # Store batch data
184
+ if batch_embeddings:
185
+ try:
186
+ target_collection.add(
187
+ embeddings=batch_embeddings,
188
+ metadatas=batch_metadatas,
189
+ ids=batch_ids
190
+ )
191
+ logger.info(f"Added batch of {len(batch_embeddings)} items")
192
+ except Exception as e:
193
+ logger.error(f"Error adding batch to collection: {e}")
194
+ failed_updates += len(batch_embeddings)
195
+ successful_updates -= len(batch_embeddings)
196
+
197
+ # Print the final results
198
+ logger.info(f"Database creation completed.")
199
+ logger.info(f"Successfully processed: {successful_updates}")
200
+ logger.info(f"Failed: {failed_updates}")
201
+ logger.info(f"Total completion rate: {(successful_updates/total_items)*100:.2f}%")
202
+
203
+ return True
204
+
205
+ except Exception as e:
206
+ logger.error(f"Database creation error: {e}")
207
+ return False
208
+
209
+ if __name__ == "__main__":
210
+ # Configuration
211
+ SOURCE_DB_PATH = "./clothesDB_11GmarketMusinsa"  # source DB path
212
+ TARGET_DB_PATH = "./clothesDB_11GmarketMusinsa_segmented"  # new DB path
213
+ BATCH_SIZE = 50  # number of items to process per batch
214
+
215
+ # Run database creation
216
+ success = create_segmented_db(SOURCE_DB_PATH, TARGET_DB_PATH, BATCH_SIZE)
217
+
218
+ if success:
219
+ logger.info("Successfully created segmented database!")
220
+ else:
221
+ logger.error("Failed to create segmented database.")