Spaces:
Running
Running
import streamlit as st | |
import open_clip | |
import torch | |
from PIL import Image | |
import numpy as np | |
from transformers import pipeline | |
import chromadb | |
import logging | |
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Session-state defaults are initialised in a single place, just above
# reset_state(), before main() runs; keeping one copy prevents the key lists
# from drifting apart.
# Model loading
@st.cache_resource
def load_models():
    """Load the image-embedding model and the clothes segmentation pipeline.

    Streamlit re-executes the whole script on every interaction, so the
    function is wrapped in ``st.cache_resource``: the models are built once per
    server process and reused on later reruns instead of being reloaded.

    Returns:
        tuple: ``(model, preprocess_val, segmenter, device)`` where ``model`` is
        the open_clip model moved to ``device`` and set to eval mode,
        ``preprocess_val`` is its validation-time image transform and
        ``segmenter`` is the transformers pipeline for
        ``mattmdjaga/segformer_b2_clothes``.

    Raises:
        Exception: any error raised while creating the models is logged with
        its traceback and re-raised.
    """
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # CLIP-style embedding model (Marqo fashion SigLIP via open_clip's HF hub loader)
        model, _, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
        model.to(device)
        # open_clip returns the model in training mode; inference must use eval
        # mode so dropout / stochastic layers are disabled.
        model.eval()
        # Clothes segmentation model
        segmenter = pipeline(model="mattmdjaga/segformer_b2_clothes")
        return model, preprocess_val, segmenter, device
    except Exception as e:
        logger.exception("Error loading models: %s", e)
        raise
# Load the models at import time. load_models() logs and re-raises on
# failure, so a loading error propagates out of this module-level call.
(
    clip_model,
    preprocess_val,
    segmenter,
    device,
) = load_models()

# ChromaDB: open the persistent store on disk and fetch the "clothes"
# collection that search_similar_items() queries.
client = chromadb.PersistentClient(
    path="./clothesDB_11GmarketMusinsa",
)
collection = client.get_collection(
    name="clothes",
)
def segment_to_item(segment):
    """Turn one segmentation-pipeline result into a detected-item dict.

    Args:
        segment: mapping with at least ``'label'`` and ``'mask'`` keys, as
            yielded by the segmenter pipeline (the mask may be a PIL image or
            an array).

    Returns:
        dict: ``{'score', 'label', 'mask'}`` where ``mask`` is a uint8 array of
        0/1 values and ``score`` is the fraction of pixels the mask covers,
        in ``[0, 1]``.
    """
    # Binarise the mask. Pipeline masks are 8-bit images whose foreground is
    # non-zero (255 in current transformers releases). Keeping 0/255 would
    # overflow uint8 when later multiplied with image pixels, and its mean
    # would not be a 0-1 fraction.
    mask = (np.asarray(segment['mask']) > 0).astype(np.uint8)
    # The clothes segmenter does not provide a per-label confidence, so the
    # item's "score" is its pixel coverage.
    coverage = float(mask.mean()) if mask.size else 0.0
    return {
        'score': coverage,
        'label': segment['label'],
        'mask': mask,
    }


def process_segmentation(image):
    """Run the clothes segmenter on *image* and return one item per detected label.

    Args:
        image: the uploaded RGB PIL image.

    Returns:
        list[dict]: items as built by :func:`segment_to_item`; an empty list if
        segmentation fails (the error is logged with its traceback).
    """
    try:
        return [segment_to_item(s) for s in segmenter(image)]
    except Exception as e:
        logger.exception("Segmentation error: %s", e)
        return []
def mask_image_array(img_array, mask, fill=255):
    """Return a uint8 copy of *img_array* with pixels outside *mask* set to *fill*.

    Args:
        img_array: ``(H, W, C)`` image array.
        mask: ``(H, W)`` array; any non-zero value (1, 255, True, ...) marks a
            pixel to keep.
        fill: value written to every channel of the discarded pixels
            (default 255, i.e. a white background).
    """
    keep = np.asarray(mask) > 0
    # np.where selects pixels instead of multiplying by the mask: a 0/255
    # uint8 mask times uint8 pixels wraps around modulo 256 and corrupts colours.
    return np.where(keep[..., np.newaxis], img_array, fill).astype(np.uint8)


def extract_features(image, mask=None):
    """Return the L2-normalised CLIP image embedding of *image* as a 1-D array.

    Args:
        image: RGB PIL image.
        mask: optional ``(H, W)`` mask; when given, pixels outside it are
            painted white before encoding so only the selected item is embedded.

    Returns:
        numpy.ndarray: flattened embedding divided by its L2 norm.

    Raises:
        Exception: any preprocessing/encoding error is logged and re-raised.
    """
    try:
        if mask is not None:
            image = Image.fromarray(mask_image_array(np.array(image), mask))
        image_tensor = preprocess_val(image).unsqueeze(0).to(device)
        with torch.no_grad():
            features = clip_model.encode_image(image_tensor)
            features /= features.norm(dim=-1, keepdim=True)
        return features.cpu().numpy().flatten()
    except Exception as e:
        logger.exception("Feature extraction error: %s", e)
        raise
def search_similar_items(features, top_k=10):
    """Query the ChromaDB collection for the items nearest to *features*.

    Args:
        features: 1-D embedding (numpy array or sequence of floats).
        top_k: number of neighbours to request.

    Returns:
        list[dict]: one dict per hit, a copy of its stored metadata plus a
        ``'similarity_score'`` key equal to ``1 / (1 + distance)``, which lies
        in (0, 1] for non-negative distances. Returns an empty list if the
        query fails (the error is logged with its traceback).
    """
    try:
        results = collection.query(
            query_embeddings=[np.asarray(features).tolist()],
            n_results=top_k,
            include=['metadatas', 'distances'],
        )
        similar_items = []
        for metadata, distance in zip(results['metadatas'][0], results['distances'][0]):
            # Chroma may return None for records stored without metadata; copy
            # into a fresh dict so one such record cannot abort the whole search.
            item = dict(metadata or {})
            # Convert the distance into a similarity score (higher = closer).
            item['similarity_score'] = 1 / (1 + distance)
            similar_items.append(item)
        return similar_items
    except Exception as e:
        logger.exception("Search error: %s", e)
        return []
def format_price(value):
    """Format a metadata price for display.

    Numbers (int/float, excluding bool) get thousands separators and the
    Korean won suffix; ``None`` becomes ``'Unknown'``; anything else (e.g. a
    pre-formatted string) is shown as-is. This avoids the ValueError that
    ``f"{'Unknown':,}"`` raises for non-numeric values.
    """
    if value is None:
        return 'Unknown'
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return f"{value:,}원"
    return str(value)


def show_similar_items(similar_items):
    """Render each search hit as an image plus its similarity and product details.

    Args:
        similar_items: dicts as returned by ``search_similar_items``; each must
            carry ``'similarity_score'``, other fields are optional.
    """
    st.subheader("Similar Items:")
    for item in similar_items:
        col1, col2 = st.columns([1, 2])
        with col1:
            image_url = item.get('image_url')
            if image_url:
                st.image(image_url)
            else:
                st.caption("No image available")
        with col2:
            # Show the similarity score as a percentage.
            similarity_percent = item['similarity_score'] * 100
            st.write(f"Similarity: {similarity_percent:.1f}%")
            st.write(f"Brand: {item.get('brand', 'Unknown')}")
            st.write(f"Name: {item.get('name', 'Unknown')}")
            st.write(f"Price: {format_price(item.get('price'))}")
            if 'discount' in item:
                st.write(f"Discount: {item['discount']}%")
            if 'original_price' in item:
                st.write(f"Original Price: {format_price(item['original_price'])}")
def _init_session_state():
    """Seed every session-state key the callbacks and main() read.

    Keys that already exist (carried over from an earlier rerun) are left
    untouched; only missing keys receive their default.
    """
    defaults = {
        'image': None,
        'detected_items': None,
        'selected_item_index': None,
        'upload_state': 'initial',
        'search_clicked': False,
    }
    for name, default in defaults.items():
        if name not in st.session_state:
            st.session_state[name] = default


_init_session_state()
def reset_state():
    """Delete every key from ``st.session_state`` so the next run starts fresh."""
    # Snapshot the keys first: deleting while iterating the live mapping is unsafe.
    doomed = [name for name in st.session_state]
    for name in doomed:
        del st.session_state[name]
# Callback functions
def handle_file_upload():
    """``on_change`` callback for the uploader: store the file as an RGB PIL image.

    Sets ``st.session_state.image`` and moves ``upload_state`` to
    ``'image_uploaded'``. No ``st.rerun()`` is issued: Streamlit runs the
    script again right after a callback, and ``st.rerun()`` inside a callback
    is a no-op (recent releases log a warning for it).
    """
    uploaded = st.session_state.uploaded_file
    if uploaded is None:
        return
    st.session_state.image = Image.open(uploaded).convert('RGB')
    st.session_state.upload_state = 'image_uploaded'
def handle_detection():
    """``on_click`` callback for "Detect Items": segment the stored image.

    Stores the result of ``process_segmentation`` (possibly an empty list) in
    ``st.session_state.detected_items`` and moves ``upload_state`` to
    ``'items_detected'``. As with the upload callback, no ``st.rerun()`` is
    needed because the script reruns after the callback returns.
    """
    image = st.session_state.image
    if image is None:
        return
    st.session_state.detected_items = process_segmentation(image)
    st.session_state.upload_state = 'items_detected'
def handle_search():
    """``on_click`` callback for the search button: mark that results should be shown."""
    st.session_state['search_clicked'] = True
def _render_detected_items(items):
    """Show a masked preview of every detected item in two columns; return the selected index."""
    image_array = np.array(st.session_state.image)
    cols = st.columns(2)
    for idx, item in enumerate(items):
        with cols[idx % 2]:
            # Binarise the mask before applying it: segmentation masks may hold
            # 0/255, and multiplying uint8 pixels by 255 wraps around modulo 256,
            # corrupting the preview colours. Background pixels become black.
            keep = np.asarray(item['mask']) > 0
            preview = np.where(keep[..., np.newaxis], image_array, 0).astype(np.uint8)
            st.image(preview, caption=f"Detected {item['label']}")
            st.write(f"Item {idx + 1}: {item['label']}")
            st.write(f"Confidence: {item['score']*100:.1f}%")

    # Item selection
    selected_idx = st.selectbox(
        "Select item to search:",
        range(len(items)),
        format_func=lambda i: f"{items[i]['label']}",
        key='item_selector',
    )
    st.session_state.selected_item_index = selected_idx
    return selected_idx


def _render_search(items, selected_idx):
    """Show the search controls and, once search has been clicked, the similar items."""
    col1, col2 = st.columns([1, 2])
    with col1:
        st.button("Search Similar Items",
                  key='search_button',
                  on_click=handle_search,
                  type="primary")  # highlighted button
    with col2:
        num_results = st.slider("Number of results:",
                                min_value=1,
                                max_value=20,
                                value=5,
                                key='num_results')

    # search_clicked stays True after the first click (only reset_state clears
    # it), so later reruns -- a new selection or slider value -- refresh results.
    if not st.session_state.search_clicked:
        return
    with st.spinner("Searching similar items..."):
        try:
            selected_mask = items[selected_idx]['mask']
            features = extract_features(st.session_state.image, selected_mask)
            similar_items = search_similar_items(features, top_k=num_results)
            if similar_items:
                show_similar_items(similar_items)
            else:
                st.warning("No similar items found.")
        except Exception as e:
            logger.exception("Search failed: %s", e)
            st.error(f"Error during search: {str(e)}")


def main():
    """Render the fashion-search UI for the current Streamlit run.

    Flow, driven by ``st.session_state``:
      1. upload an image (the uploader is shown only while ``upload_state`` is
         ``'initial'``);
      2. click "Detect Items" to segment it into clothing items;
      3. pick an item and search the ChromaDB collection for similar products;
      4. "Start New Search" clears the session state and reruns the script.
    """
    # NOTE(review): this title literal looks mis-encoded (UTF-8 Korean decoded
    # as Greek); restore the intended text from the original repository.
    st.title("ν¬μ΄λΈλ fashion demo!!!")

    # File upload (shown only while upload_state is 'initial').
    if st.session_state.upload_state == 'initial':
        st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
                         key='uploaded_file', on_change=handle_file_upload)

    # Nothing else to render until an image has been uploaded.
    if st.session_state.image is None:
        return

    st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)

    detected_items = st.session_state.detected_items
    if detected_items is None:
        # handle_detection fills detected_items before the next script run.
        st.button("Detect Items", key='detect_button', on_click=handle_detection)
    elif not detected_items:
        # process_segmentation returns [] when it fails or finds nothing.
        st.warning("No items detected.")
    else:
        selected_idx = _render_detected_items(detected_items)
        _render_search(detected_items, selected_idx)

    # New search: wipe all session state and start over.
    if st.button("Start New Search ", key='new_search'):
        reset_state()
        st.rerun()


if __name__ == "__main__":
    main()