# itda-segment / app.py
# Author: leedoming. Last commit: "Update app.py" (ee78aad, verified).
import streamlit as st
import open_clip
import torch
from PIL import Image
import numpy as np
from transformers import pipeline
import chromadb
import logging
import io
import requests
from concurrent.futures import ThreadPoolExecutor
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Session-state defaults: seed each key only when it is absent so that values
# set during earlier script runs are preserved.
for _state_key, _state_default in (
    ('image', None),
    ('detected_items', None),
    ('selected_item_index', None),
    ('upload_state', 'initial'),
    ('search_clicked', False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
del _state_key, _state_default
# Load models
@st.cache_resource
def load_models():
    """Load the image-embedding model and the clothing segmentation pipeline.

    Decorated with ``st.cache_resource`` so the resources are created once and
    reused across Streamlit reruns.

    Returns:
        tuple: ``(model, preprocess_val, segmenter, device)``:
            - ``model``: the open_clip model, moved to ``device`` and in eval mode.
            - ``preprocess_val``: its evaluation image transform.
            - ``segmenter``: a ``transformers`` pipeline for
              ``mattmdjaga/segformer_b2_clothes``.
            - ``device``: CUDA when available, otherwise CPU.

    Raises:
        Exception: any error while creating the models is logged and re-raised.
    """
    try:
        # CLIP-style embedding model and its evaluation-time transform
        model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
        # Clothing segmentation model
        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        # open_clip creates models in training mode (per its README); switch to
        # inference mode so embeddings are computed deterministically.
        model.eval()
        return model, preprocess_val, segmenter, device
    except Exception as e:
        logger.error(f"Error loading models: {e}")
        raise
# Load the (cached) models once when the script runs.
clip_model, preprocess_val, segmenter, device = load_models()

# ChromaDB: open the on-disk store and its pre-built "clothes" collection.
client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
collection = client.get_collection("clothes")
def extract_color_histogram(image, mask=None):
    """Compute a normalized HSV colour histogram of an RGB image.

    Args:
        image: RGB image (PIL image or array-like of shape (H, W, 3)).
        mask: optional array-like of shape (H, W). A trailing singleton channel,
            i.e. (H, W, 1), is also accepted. Pixels whose mask value is > 0 are
            counted, so both 0/1 and 0/255 masks work. If None, every pixel is used.

    Returns:
        np.ndarray of 24 floats: 8-bin histograms of the H, S and V channels,
        concatenated in that order. Each is normalized to sum to ~1. All zeros
        when no pixel is selected or when extraction fails (the error is logged).

    Note:
        Earlier versions fed an (N, 3) array to PIL, which treated it as a
        grayscale image, so their "HSV" features were not real HSV. Embeddings
        stored with that version will not match these features until rebuilt.
    """
    try:
        img_array = np.array(image)
        if mask is not None:
            mask = np.asarray(mask)
            if mask.ndim > 2:
                mask = mask[..., 0]
            # Select raw pixel values directly. Multiplying by a 0/255 mask first
            # would overflow the later uint8 cast.
            valid_pixels = img_array[mask > 0]
        else:
            valid_pixels = img_array.reshape(-1, 3)

        if len(valid_pixels) == 0:
            return np.zeros(24)  # 8 bins * 3 channels = 24 features

        # Lay the pixels out as an N x 1 RGB image so PIL converts each pixel's
        # colour to HSV. A plain (N, 3) uint8 array would be read as grayscale.
        rgb_column = valid_pixels.reshape(-1, 1, 3).astype(np.uint8)
        hsv_pixels = np.asarray(Image.fromarray(rgb_column).convert('HSV')).reshape(-1, 3)

        histograms = []
        for channel in range(3):  # H, S, V
            hist = np.histogram(hsv_pixels[:, channel], bins=8, range=(0, 256))[0]
            histograms.append(hist / (hist.sum() + 1e-8))  # epsilon avoids division by zero
        return np.concatenate(histograms)
    except Exception as e:
        logger.error(f"Color histogram extraction error: {e}")
        return np.zeros(24)
def process_segmentation(image):
    """Run the clothing segmenter on an image and normalize its output.

    Args:
        image: input image, passed unchanged to the module-level ``segmenter``.

    Returns:
        list[dict]: one dict per segment that has a mask, with these keys:
            - ``'label'``: the segment label, or ``'Unknown'`` if absent.
            - ``'score'``: the pipeline's value, or 1.0 when the key is absent.
              The value itself may be None.
            - ``'mask'``: a 2-D float array containing only 0.0 and 1.0.
        An empty list is returned when nothing is found, or when segmentation
        raises (the error and traceback are logged).
    """
    try:
        # Process the pipeline output directly
        output = segmenter(image)
        if not output:
            logger.warning("No segments found in image")
            return []

        processed_items = []
        for segment in output:
            label = segment.get('label', 'Unknown')
            mask = segment.get('mask')
            if mask is None:
                logger.warning(f"No mask found for segment with label {label}")
                continue  # segments without a mask are skipped
            mask = np.asarray(mask)
            if mask.ndim > 2:
                mask = mask[:, :, 0]  # multi-channel mask: keep the first channel
            processed_items.append({
                'label': label,
                'score': segment.get('score', 1.0),
                # Binarize. The pipeline may return masks as 0/255 images, but
                # callers multiply pixel values by this mask, so it must be 0/1.
                'mask': (mask > 0).astype(float),
            })

        logger.info(f"Successfully processed {len(processed_items)} segments")
        return processed_items
    except Exception as e:
        logger.error(f"Segmentation error: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return []
def extract_features(image, mask=None):
    """Build a 768-dim embedding that combines CLIP features and a colour histogram.

    Args:
        image: RGB PIL image.
        mask: optional array-like of shape (H, W) selecting the item. A trailing
            singleton channel is also accepted. Any value > 0 counts as foreground.
            When given, background pixels are painted white before CLIP encoding,
            and only foreground pixels feed the colour histogram.

    Returns:
        np.ndarray of shape (768,), L2-normalized:
            - the first 744 dims hold the truncated, normalized CLIP embedding,
              weighted by 0.7;
            - the last 24 dims hold the normalized colour histogram, weighted by 0.3.

    Raises:
        Exception: any failure is logged and re-raised.
    """
    try:
        clip_input = image
        if mask is not None:
            mask = np.asarray(mask)
            if mask.ndim > 2:
                mask = mask[:, :, 0]
            # Binarize so a 0/255 mask cannot push pixel values past the uint8 range.
            mask = (mask > 0).astype(float)
            img_array = np.array(image)
            masked_img = img_array * mask[:, :, np.newaxis]
            masked_img[mask == 0] = 255  # Set background to white
            clip_input = Image.fromarray(masked_img.astype(np.uint8))

        # CLIP features
        image_tensor = preprocess_val(clip_input).unsqueeze(0).to(device)
        with torch.no_grad():
            clip_features = clip_model.encode_image(image_tensor)
            clip_features /= clip_features.norm(dim=-1, keepdim=True)
        clip_features = clip_features.cpu().numpy().flatten()

        # Colour features. extract_color_histogram expects an (H, W) mask, so pass
        # the 2-D mask rather than an expanded copy.
        color_features = extract_color_histogram(image, mask)

        total_dim = 768  # must match the collection's embedding dimensionality
        clip_dim = total_dim - color_features.shape[0]  # 744 for a 24-bin histogram
        clip_features = clip_features[:clip_dim]  # trim CLIP features to make room for colour

        clip_features_normalized = clip_features / (np.linalg.norm(clip_features) + 1e-8)
        color_features_normalized = color_features / (np.linalg.norm(color_features) + 1e-8)

        clip_weight = 0.7
        color_weight = 0.3
        combined_features = np.zeros(total_dim)
        combined_features[:clip_dim] = clip_features_normalized * clip_weight
        combined_features[clip_dim:] = color_features_normalized * color_weight

        # Ensure final normalization
        return combined_features / (np.linalg.norm(combined_features) + 1e-8)
    except Exception as e:
        logger.error(f"Feature extraction error: {e}")
        raise
def download_and_process_image(image_url, metadata_id):
    """Fetch an image, segment it, and embed its largest segment.

    Args:
        image_url: URL of the image to download (10 s request timeout).
        metadata_id: identifier used only in log messages.

    Returns:
        The feature vector from ``extract_features`` for the segment whose mask
        has the largest sum. Returns None when the HTTP status is not 200, no
        segment is found, or any error occurs (errors are logged with a traceback).
    """
    try:
        response = requests.get(image_url, timeout=10)
        status = response.status_code
        if status != 200:
            logger.error(f"Failed to download image {metadata_id}: HTTP {status}")
            return None

        image = Image.open(io.BytesIO(response.content)).convert('RGB')
        logger.info(f"Successfully downloaded image {metadata_id}")

        segments = process_segmentation(image)
        if not segments:
            logger.warning(f"No valid mask found for image {metadata_id}")
            return None

        # Embed the segment covering the largest mask area.
        # NOTE(review): if the segmenter emits a background-type label, it may be
        # the largest segment; confirm whether it should be excluded.
        biggest = max(segments, key=lambda seg: np.sum(seg['mask']))
        features = extract_features(image, biggest['mask'])
        logger.info(f"Successfully extracted features for image {metadata_id}")
        return features
    except Exception as e:
        logger.error(f"Error processing image {metadata_id}: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return None
def update_db_with_segmentation():
    """Rebuild the ``clothes_segmented`` collection with segmentation-aware embeddings.

    Any existing ``clothes_segmented`` collection is dropped and recreated. Then,
    for every item of the module-level ``collection`` whose metadata has an
    ``image_url``, the image is downloaded, segmented and re-embedded via
    ``download_and_process_image``, using 4 worker threads. Each successful
    embedding is stored with the item's original metadata. Progress is shown
    with a Streamlit progress bar and a status line.

    Item ids are chosen in this order, always stored as ``str``:
        1. the metadata's ``'id'`` value, if present;
        2. otherwise the item's id in the source collection;
        3. otherwise a SHA-1 of its image URL (stable across runs, unlike ``hash()``).

    Returns:
        bool: True if at least one item was added, otherwise False. False is also
        returned when an unexpected error aborts the update; that error is logged
        with a traceback.
    """
    import hashlib

    try:
        logger.info("Starting database update with segmentation and color features")
        # Recreate the target collection so no stale embeddings survive the rebuild.
        try:
            client.delete_collection("clothes_segmented")
            logger.info("Deleted existing segmented collection")
        except Exception:  # e.g. the collection does not exist yet
            logger.info("No existing segmented collection to delete")
        new_collection = client.create_collection(
            name="clothes_segmented",
            metadata={"description": "Clothes collection with segmentation and color features"}
        )
        logger.info("Created new segmented collection")

        # Only ids and metadata are needed from the source; embeddings are recomputed.
        try:
            all_items = collection.get(include=['metadatas'])
            logger.info(f"Found {len(all_items['metadatas'])} items in database")
        except Exception as e:
            logger.error(f"Error getting items from collection: {str(e)}")
            all_items = {'ids': [], 'metadatas': []}

        metadatas = all_items['metadatas'] or []
        source_ids = all_items.get('ids') or [None] * len(metadatas)
        # Only items with an image URL can be re-embedded; metadata entries may be None.
        valid_items = [
            (source_id, metadata)
            for source_id, metadata in zip(source_ids, metadatas)
            if metadata and 'image_url' in metadata
        ]

        # Progress bar for status display
        progress_bar = st.progress(0)
        status_text = st.empty()
        successful_updates = 0
        failed_updates = 0

        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = []
            for source_id, metadata in valid_items:
                future = executor.submit(
                    download_and_process_image,
                    metadata['image_url'],
                    metadata.get('id', 'unknown')
                )
                futures.append((source_id, metadata, future))

            total = len(futures)
            # Collect results in submission order and store them in the new collection.
            for idx, (source_id, metadata, future) in enumerate(futures):
                try:
                    new_features = future.result()
                    if new_features is None:
                        failed_updates += 1
                    else:
                        item_id = metadata.get('id')
                        if item_id is None:
                            item_id = source_id
                        if item_id is None:
                            item_id = hashlib.sha1(
                                str(metadata['image_url']).encode('utf-8')
                            ).hexdigest()
                        item_id = str(item_id)  # Chroma ids must be strings
                        try:
                            new_collection.add(
                                embeddings=[new_features.tolist()],
                                metadatas=[metadata],
                                ids=[item_id]
                            )
                            successful_updates += 1
                            logger.info(f"Successfully added item {item_id}")
                        except Exception as e:
                            logger.error(f"Error adding item to new collection: {str(e)}")
                            failed_updates += 1
                except Exception as e:
                    logger.error(f"Error processing item: {str(e)}")
                    failed_updates += 1
                # Advance progress for every item, including failures.
                progress_bar.progress((idx + 1) / total)
                status_text.text(f"Processing: {idx + 1}/{total} items. Success: {successful_updates}, Failed: {failed_updates}")

        # Show the final result
        status_text.text(f"Update completed. Successfully processed: {successful_updates}, Failed: {failed_updates}")
        logger.info(f"Database update completed. Successful: {successful_updates}, Failed: {failed_updates}")

        if successful_updates > 0:
            return True
        logger.error("No items were successfully processed")
        return False
    except Exception as e:
        logger.error(f"Database update error: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return False
def search_similar_items(features, top_k=10):
    """Query ChromaDB for the items closest to a feature vector.

    The ``clothes_segmented`` collection is used when it exists; otherwise the
    module-level ``collection`` is used.

    Args:
        features: query embedding (np.ndarray), as produced by ``extract_features``.
        top_k: maximum number of results to request.

    Returns:
        list[dict]: copies of the matching items' metadata, sorted best first.
        Each copy gets an added ``'similarity_score'`` on a 0-100 scale, where
        higher means more similar. The list is empty when nothing is found or
        the query fails (the error is logged).
    """
    try:
        # Use the segmented collection when it exists
        try:
            search_collection = client.get_collection("clothes_segmented")
            logger.info("Using segmented collection for search")
        except Exception:
            # Otherwise fall back to the original collection
            search_collection = collection
            logger.info("Using original collection for search")

        # Chroma returns *distances* (lower = closer). 'scores' is not a valid
        # include value and made every query raise.
        results = search_collection.query(
            query_embeddings=[features.tolist()],
            n_results=top_k,
            include=['metadatas', 'distances']
        )
        if not results or not results['metadatas'] or not results['distances']:
            logger.warning("No results returned from ChromaDB")
            return []

        # Convert distances to a cosine-style similarity. The query vector is
        # unit length; this assumes stored embeddings are too. The distance
        # metric comes from the collection's 'hnsw:space' metadata (Chroma's
        # default is 'l2').
        space = (getattr(search_collection, 'metadata', None) or {}).get('hnsw:space', 'l2')

        similar_items = []
        for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
            try:
                if space == 'l2':
                    # Squared L2 between unit vectors: d = 2 - 2*cos
                    similarity = 1.0 - distance / 2.0
                else:
                    # 'cosine' (d = 1 - cos) and 'ip' (d = 1 - dot)
                    similarity = 1.0 - distance
                item_data = metadata.copy()
                item_data['similarity_score'] = similarity * 100.0  # percent scale for display
                similar_items.append(item_data)
            except Exception as e:
                logger.error(f"Error processing search result: {str(e)}")
                continue

        similar_items.sort(key=lambda x: x['similarity_score'], reverse=True)
        return similar_items
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        return []
def _format_price(value):
    """Format a price: ints/floats as '12,345원', any other value via str()."""
    if isinstance(value, (int, float)):
        return f"{value:,}원"
    return f"{value}"


def show_similar_items(similar_items):
    """Render search results in a two-column Streamlit grid.

    Each item card shows:
        - its image, when ``'image_url'`` is present;
        - ``'similarity_score'`` formatted as a percentage;
        - brand, and name truncated to 50 characters;
        - price and, when ``'discount'`` is truthy, the discount and original price.

    Numeric prices get thousands separators and a "원" suffix; other values are
    shown as-is. A failure while rendering one item is logged and reported in
    place of that item.

    Args:
        similar_items: list of result dicts, as returned by ``search_similar_items``.
    """
    if not similar_items:
        st.warning("No similar items found.")
        return
    st.subheader("Similar Items:")
    # Display results in two columns
    items_per_row = 2
    for i in range(0, len(similar_items), items_per_row):
        cols = st.columns(items_per_row)
        for col, item in zip(cols, similar_items[i:i + items_per_row]):
            with col:
                try:
                    if 'image_url' in item:
                        st.image(item['image_url'], use_column_width=True)
                    # Similarity score shown as a percentage
                    similarity_percent = item['similarity_score']
                    st.markdown(f"**Similarity: {similarity_percent:.1f}%**")
                    st.write(f"Brand: {item.get('brand', 'Unknown')}")
                    name = str(item.get('name', 'Unknown'))
                    if len(name) > 50:  # shorten long names
                        name = name[:47] + "..."
                    st.write(f"Name: {name}")
                    # Price info; non-numeric values are shown as-is instead of crashing
                    st.write(f"Price: {_format_price(item.get('price', 0))}")
                    # Discount info, when present
                    if item.get('discount'):
                        st.write(f"Discount: {item['discount']}%")
                        if 'original_price' in item:
                            st.write(f"Original: {_format_price(item['original_price'])}")
                    st.divider()  # separator between items
                except Exception as e:
                    logger.error(f"Error displaying item: {e}")
                    st.error("Error displaying this item")
def process_search(image, mask, num_results):
    """Embed the selected item and look up similar products, with spinners.

    Args:
        image: the uploaded RGB image.
        mask: segmentation mask of the item to search for.
        num_results: number of results to request.

    Returns:
        The list returned by ``search_similar_items``. Returns None if feature
        extraction or the search raised; the error is logged.
    """
    try:
        with st.spinner("Extracting features..."):
            query_vector = extract_features(image, mask)
        with st.spinner("Finding similar items..."):
            return search_similar_items(query_vector, top_k=num_results)
    except Exception as e:
        logger.error(f"Search processing error: {e}")
        return None
def handle_file_upload():
    """``on_change`` callback for the file uploader.

    Opens the uploaded file as an RGB image, stores it in
    ``st.session_state.image`` and sets ``upload_state`` to ``'image_uploaded'``.

    Streamlit reruns the script after a callback finishes, so no explicit
    ``st.rerun()`` is needed; inside a callback it would be a no-op. Files that
    Pillow cannot open are logged and reported with ``st.error`` instead of
    raising inside the callback.
    """
    uploaded = st.session_state.uploaded_file
    if uploaded is None:
        return
    try:
        image = Image.open(uploaded).convert('RGB')
    except OSError as e:  # includes PIL.UnidentifiedImageError
        logger.error(f"Could not open uploaded image: {e}")
        st.error("Could not read the uploaded file as an image.")
        return
    st.session_state.image = image
    st.session_state.upload_state = 'image_uploaded'
def handle_detection():
    """``on_click`` callback for the "Detect Items" button.

    Segments ``st.session_state.image`` (if one is loaded), stores the result in
    ``st.session_state.detected_items`` and sets ``upload_state`` to
    ``'items_detected'``.

    Streamlit reruns the script after the callback finishes, so no explicit
    ``st.rerun()`` is needed; inside a callback it would be a no-op.
    """
    if st.session_state.image is None:
        return
    st.session_state.detected_items = process_segmentation(st.session_state.image)
    st.session_state.upload_state = 'items_detected'
def handle_search():
    """Callback that records that a similarity search was requested.

    Sets ``st.session_state.search_clicked`` to True. It is not referenced as a
    widget callback anywhere in this file.
    """
    st.session_state['search_clicked'] = True
def _show_detected_items(image, detected_items):
    """Render each detected segment (masked image, label, confidence) in two columns."""
    cols = st.columns(2)
    image_array = np.array(image)
    for idx, item in enumerate(detected_items):
        with cols[idx % 2]:
            try:
                if item.get('mask') is not None:
                    # Binarize so both 0/1 and 0/255 masks keep the original pixel values.
                    mask = np.expand_dims(np.asarray(item['mask']) > 0, axis=2)
                    masked_img = image_array * mask
                    st.image(masked_img.astype(np.uint8), caption=f"Detected {item.get('label', 'Unknown')}")
                st.write(f"Item {idx + 1}: {item.get('label', 'Unknown')}")
                # Show the score only when it is present and numeric
                score = item.get('score')
                if score is not None and isinstance(score, (int, float)):
                    st.write(f"Confidence: {score*100:.1f}%")
                else:
                    st.write("Confidence: N/A")
            except Exception as e:
                logger.error(f"Error displaying item {idx}: {str(e)}")
                st.error(f"Error displaying item {idx}")


def main():
    """Streamlit entry point: upload an image, detect clothing items, search similar products.

    The flow is driven by ``st.session_state``:
      1. While ``upload_state == 'initial'``, a file uploader is shown;
         ``handle_file_upload`` stores the image.
      2. Once an image exists, it is displayed with a "Detect Items" button;
         ``handle_detection`` runs segmentation.
      3. Detected items are shown. The user picks one and a result count, then
         "Search Similar Items" runs the search. Results are cached in
         ``st.session_state.search_results`` together with the (item, count)
         that produced them in ``st.session_state.search_query``. They are
         recomputed when the button is pressed again or that selection changes.
      4. "Start New Search" clears all session state and reruns.
    """
    st.title("Fashion Search App")
    # Admin controls in sidebar
    st.sidebar.title("Admin Controls")
    if st.sidebar.checkbox("Show Admin Interface"):
        # Admin interface (to be implemented if needed)
        st.sidebar.warning("Admin interface is not implemented yet.")
    st.divider()

    # File uploader
    if st.session_state.upload_state == 'initial':
        st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
                         key='uploaded_file', on_change=handle_file_upload)

    # An image has been uploaded
    if st.session_state.image is not None:
        st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)
        if st.session_state.detected_items is None:
            st.button("Detect Items", key='detect_button', on_click=handle_detection)

    detected_items = st.session_state.detected_items
    if detected_items is not None and len(detected_items) == 0:
        # Detection ran but found nothing; without this the page would show nothing.
        st.warning("No items detected in the image.")

    if not detected_items:
        return

    # Show detected items
    _show_detected_items(st.session_state.image, detected_items)

    valid_items = [i for i, item in enumerate(detected_items) if item.get('mask') is not None]
    if not valid_items:
        st.warning("No valid items detected for search.")
        return

    # Item selection
    selected_idx = st.selectbox(
        "Select item to search:",
        valid_items,
        format_func=lambda i: f"{detected_items[i].get('label', 'Unknown')}",
        key='item_selector'
    )

    # Search controls
    search_col1, search_col2 = st.columns([1, 2])
    with search_col1:
        search_clicked = st.button("Search Similar Items",
                                   key='search_button',
                                   type="primary")
    with search_col2:
        num_results = st.slider("Number of results:",
                                min_value=1,
                                max_value=20,
                                value=5,
                                key='num_results')

    # Handle search results
    if search_clicked or st.session_state.get('search_clicked', False):
        st.session_state.search_clicked = True
        selected_item = detected_items[selected_idx]
        if selected_item.get('mask') is None:
            st.error("Selected item has no valid mask for search.")
            return

        # Re-run the search when explicitly requested, or when the cached results
        # were produced for a different item or result count; otherwise reuse them.
        query = (selected_idx, num_results)
        if (search_clicked
                or 'search_results' not in st.session_state
                or st.session_state.get('search_query') != query):
            st.session_state.search_results = process_search(
                st.session_state.image, selected_item['mask'], num_results)
            st.session_state.search_query = query

        # Show the stored search results
        if st.session_state.search_results:
            show_similar_items(st.session_state.search_results)
        else:
            st.warning("No similar items found.")

        # New search button
        if st.button("Start New Search", key='new_search'):
            # Reset all state
            for key in list(st.session_state.keys()):
                del st.session_state[key]
            st.rerun()
# Entry point when this file is executed as the main script.
if __name__ == '__main__':
    main()