# leedoming's picture
# Update app.py
# c9fcc85 verified
import streamlit as st
import open_clip
import torch
from PIL import Image
import numpy as np
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation
import chromadb
import logging
import io
import requests
from concurrent.futures import ThreadPoolExecutor
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
from chromadb.utils.data_loaders import ImageLoader
# Logging: INFO-level root configuration plus a logger named after this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class CustomFashionEmbeddingFunction:
    """ChromaDB embedding function mixing Marqo-FashionSigLIP features with colour histograms.

    Every input image becomes one row of length EMBEDDING_DIM:

        normalize(CLIP_WEIGHT * normalize(clip) + COLOR_WEIGHT * normalize(color))

    where ``clip`` is the open_clip image embedding zero-padded / truncated to
    EMBEDDING_DIM entries and ``color`` is the 24-bin HSV histogram from
    ``extract_color_histogram`` with each bin repeated to fill EMBEDDING_DIM.
    """

    EMBEDDING_DIM = 768  # length of every returned embedding row
    CLIP_WEIGHT = 0.7  # weight of the CLIP component in the combined vector
    COLOR_WEIGHT = 0.3  # weight of the colour-histogram component
    REQUEST_TIMEOUT = 10  # seconds to wait when downloading an image URL

    def __init__(self):
        self.model, _, self.preprocess = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # open_clip documents that models are created in train mode; inference must use eval mode.
        self.model = self.model.to(self.device).eval()

    def _load_image(self, img):
        """Return ``img`` as an RGB PIL image.

        Accepts a URL (str, downloaded with a timeout), encoded image bytes, a
        numpy array (cast to uint8) or a PIL image. Any other object is returned
        unchanged.

        Raises:
            requests.HTTPError for a failed download; PIL errors for undecodable data.
        """
        if isinstance(img, str):
            response = requests.get(img, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert('RGB')
        if isinstance(img, bytes):
            return Image.open(io.BytesIO(img)).convert('RGB')
        if isinstance(img, np.ndarray):
            return Image.fromarray(img.astype('uint8')).convert('RGB')
        if isinstance(img, Image.Image):
            return img.convert('RGB')
        return img

    def _combine(self, clip_emb, color_feat):
        """Fuse one CLIP embedding and one colour histogram into a unit-length vector."""
        dim = self.EMBEDDING_DIM
        # Zero-pad or truncate the CLIP embedding to `dim` entries.
        if clip_emb.shape[0] < dim:
            clip_emb = np.concatenate([clip_emb, np.zeros(dim - clip_emb.shape[0])])
        else:
            clip_emb = clip_emb[:dim]
        # Stretch the histogram to `dim` by repeating each bin (24 bins * 32 = 768).
        color_expanded = np.repeat(color_feat, dim // color_feat.shape[0])
        # Normalise each part separately so the weights control their relative influence.
        clip_emb = clip_emb / (np.linalg.norm(clip_emb) + 1e-8)
        color_expanded = color_expanded / (np.linalg.norm(color_expanded) + 1e-8)
        combined = clip_emb * self.CLIP_WEIGHT + color_expanded * self.COLOR_WEIGHT
        return combined / (np.linalg.norm(combined) + 1e-8)

    def __call__(self, input):
        """Embed a batch of images.

        Args:
            input: sequence of images: URLs (str), encoded image bytes, numpy
                arrays or PIL images. The parameter name ``input`` is kept
                because ChromaDB calls embedding functions with that keyword.

        Returns:
            np.ndarray of shape (len(input), EMBEDDING_DIM); each row has unit L2 norm.
            An empty input yields an empty (0, EMBEDDING_DIM) array.

        Raises:
            Any download/decoding/model error, after logging it.
        """
        try:
            # Decode every input exactly once; CLIP and the colour histogram reuse it.
            images = [self._load_image(img) for img in input]
            if not images:
                return np.empty((0, self.EMBEDDING_DIM))

            # Batched CLIP image embeddings.
            batch = torch.cat([self.preprocess(img).unsqueeze(0) for img in images]).to(self.device)
            with torch.no_grad():
                clip_features = self.model.encode_image(batch).cpu().numpy()

            combined_embeddings = [
                self._combine(clip_emb, self.extract_color_histogram(img))
                for clip_emb, img in zip(clip_features, images)
            ]
            return np.array(combined_embeddings)
        except Exception as e:
            logger.error(f"Error in embedding function: {e}")
            raise

    def extract_color_histogram(self, image):
        """Return a 24-dim HSV colour histogram (8 bins each for H, S, V).

        Each channel's 8 bins are normalised to sum to 1. ``image`` may be
        anything ``_load_image`` accepts. This is best effort: on any failure the
        error is logged and a zero vector of length 24 is returned.
        """
        try:
            # Normalise to RGB first so every input mode goes through the same RGB->HSV conversion.
            hsv_pixels = np.array(self._load_image(image).convert('HSV'))
            channel_hists = []
            for channel in range(3):  # H, S, V
                counts = np.histogram(hsv_pixels[:, :, channel], bins=8, range=(0, 256))[0]
                channel_hists.append(counts / (counts.sum() + 1e-8))
            return np.concatenate(channel_hists)
        except Exception as e:
            logger.error(f"Color histogram extraction error: {e}")
            return np.zeros(24)
# Seed per-session UI state; keys that already exist (from earlier reruns) are left untouched.
for _state_key, _state_default in (
    ('image', None),
    ('detected_items', None),
    ('selected_item_index', None),
    ('upload_state', 'initial'),
    ('search_clicked', False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Clothes segmentation model, cached by Streamlit via st.cache_resource.
@st.cache_resource
def load_segmentation_model():
    """Load the ``mattmdjaga/segformer_b2_clothes`` checkpoint.

    Returns:
        (model, image_processor). The model is moved to CUDA when it is available.

    Raises:
        Any error raised while loading, after logging it.
    """
    checkpoint = "mattmdjaga/segformer_b2_clothes"
    try:
        processor = AutoImageProcessor.from_pretrained(checkpoint)
        seg_model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint)
        seg_model = seg_model.to('cuda') if torch.cuda.is_available() else seg_model
    except Exception as e:
        logger.error(f"Error loading segmentation model: {e}")
        raise
    return seg_model, processor
# ChromaDB setup
def setup_multimodal_collection(collection_name="fashion_multimodal_v2",
                                db_path="./fashion_multimodal_db"):
    """Open the multimodal fashion collection, creating it if it cannot be fetched.

    Args:
        collection_name: Collection to open. If fetching fails, a collection is
            created under this SAME name, so lookup, creation and log messages
            always agree.
        db_path: Directory used by the ChromaDB PersistentClient.

    Returns:
        The chromadb Collection, configured with CustomFashionEmbeddingFunction
        and an ImageLoader.

    Raises:
        Any error from client setup or collection creation, after logging it.
    """
    try:
        client = chromadb.PersistentClient(path=db_path)
        collection_kwargs = {
            "name": collection_name,
            "embedding_function": CustomFashionEmbeddingFunction(),
            "data_loader": ImageLoader(),
        }
        try:
            collection = client.get_collection(**collection_kwargs)
            logger.info(f"Successfully connected to existing {collection_name} collection")
        except Exception as e:
            # Usually means the collection does not exist yet; create it with the same name.
            logger.warning(f"Could not open collection {collection_name} ({e}); creating it")
            collection = client.create_collection(**collection_kwargs)
            logger.info(f"Created new {collection_name} collection")
        return collection
    except Exception as e:
        logger.error(f"Error setting up multimodal collection: {e}")
        raise
def process_segmentation(image):
    """Split an image into per-class masks using the cached segmentation model.

    Args:
        image: PIL image. ``image.size`` is read as (width, height).

    Returns:
        One dict per non-background class id present in the predicted mask:
          - 'label': class name from ``model.config.id2label``, or
            ``"Item_<id>"`` when the id has no name,
          - 'score': placeholder confidence, always 1.0,
          - 'mask': float array of shape (height, width); 1.0 where the pixel
            was assigned to this class, else 0.0.
        This is best effort: on any error the traceback is logged and [] is returned.
    """
    try:
        model, image_processor = load_segmentation_model()

        # Preprocess and place tensors on the same device the model was moved to.
        inputs = image_processor(image, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.to('cuda') for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        # Upsample logits to the input resolution before taking the per-pixel argmax.
        logits = outputs.logits.cpu()
        upsampled_logits = torch.nn.functional.interpolate(
            logits,
            size=image.size[::-1],  # PIL size is (width, height); interpolate wants (height, width)
            mode="bilinear",
            align_corners=False,
        )
        # Only one image is in the batch: keep its (height, width) class-id map.
        seg_mask = upsampled_logits.argmax(dim=1)[0].numpy()

        id2label = getattr(model.config, "id2label", None) or {}
        processed_items = []
        for label_idx in np.unique(seg_mask):
            if label_idx == 0:  # background
                continue
            processed_items.append({
                'label': id2label.get(int(label_idx), f"Item_{label_idx}"),
                'score': 1.0,  # placeholder: no per-segment confidence is computed
                'mask': (seg_mask == label_idx).astype(float),
            })

        logger.info(f"Successfully processed {len(processed_items)} segments")
        return processed_items
    except Exception as e:
        # logger.exception records the message plus the full traceback.
        logger.exception(f"Segmentation error: {e}")
        return []
@st.cache_resource
def _get_fashion_embedding_function():
    """Build CustomFashionEmbeddingFunction once (it loads a CLIP model) and reuse it across searches."""
    return CustomFashionEmbeddingFunction()


def _distance_to_similarity(distance, space):
    """Map a ChromaDB query distance to a similarity score clamped to [0, 100].

    Assumes query and stored embeddings are unit-normalised; CustomFashionEmbeddingFunction
    normalises its output. TODO confirm for pre-existing collections. ChromaDB's "l2"
    space reports SQUARED L2 distance, so cos = 1 - d / 2. The "cosine" and "ip" spaces
    report 1 - cos and 1 - dot, so cos = 1 - d.
    """
    if space == "l2":
        cosine_similarity = 1 - distance / 2
    else:
        cosine_similarity = 1 - distance
    score = (cosine_similarity + 1) / 2 * 100
    return min(max(score, 0.0), 100.0)


def search_similar_items(image, mask=None, top_k=10):
    """Search both multimodal collections for items similar to (a masked region of) an image.

    Args:
        image: PIL image used as the query.
        mask: Optional (height, width) array. The RGB pixels are multiplied by it
            so that only the selected region contributes to the query.
        top_k: Number of results requested from each collection and returned overall.

    Returns:
        Up to ``top_k`` metadata dicts (copies), each with an added
        'similarity_score' in [0, 100], sorted best first. Items sharing an
        'image_url' are de-duplicated; items without one are all kept. Returns []
        if no collection is available or on any unexpected error.
    """
    try:
        client = chromadb.PersistentClient(path="./fashion_multimodal_db")
        embedding_function = _get_fashion_embedding_function()
        data_loader = ImageLoader()

        # Open whichever of the two collections exist.
        collections = []
        for name in ["fashion_multimodal", "fashion_multimodal_v2"]:
            try:
                collection = client.get_collection(
                    name=name,
                    embedding_function=embedding_function,
                    data_loader=data_loader
                )
            except Exception as e:
                logger.error(f"Error getting collection {name}: {e}")
                continue
            collections.append(collection)
            logger.info(f"Successfully connected to {name} collection")

        if not collections:
            logger.error("No collections available for search")
            return []

        # Black out everything outside the mask so only the selected item is embedded.
        if mask is not None:
            mask_3d = np.stack([mask] * 3, axis=-1)
            query_image = Image.fromarray((np.array(image) * mask_3d).astype(np.uint8))
        else:
            query_image = image
        query_array = np.array(query_image)

        all_results = []
        for collection in collections:
            try:
                # ChromaDB defaults to the "l2" space when none was configured.
                space = (collection.metadata or {}).get("hnsw:space", "l2")
                results = collection.query(
                    query_images=[query_array],
                    n_results=top_k,
                    include=['metadatas', 'distances']
                )
            except Exception as e:
                logger.error(f"Error searching in collection: {e}")
                continue
            metadatas = results.get('metadatas') or [[]]
            distances = results.get('distances') or [[]]
            for metadata, distance in zip(metadatas[0], distances[0]):
                item_data = dict(metadata or {})  # copy; tolerate records without metadata
                item_data['similarity_score'] = _distance_to_similarity(distance, space)
                all_results.append(item_data)

        # Best first; drop repeated image URLs across collections and stop at top_k.
        seen_urls = set()
        unique_results = []
        for item in sorted(all_results, key=lambda x: x['similarity_score'], reverse=True):
            url = item.get('image_url')
            if url:
                if url in seen_urls:
                    continue
                seen_urls.add(url)
            unique_results.append(item)
            if len(unique_results) >= top_k:
                break
        return unique_results
    except Exception as e:
        logger.error(f"Multimodal search error: {e}")
        return []
# def update_db_with_multimodal():
# """Update the DB using the multimodal approach"""
# try:
# # Create a new collection
# collection = setup_multimodal_collection()
# # Fetch the data from the existing collection
# client = chromadb.PersistentClient(path="./clothesDB_11GmarketMusinsa")
# old_collection = client.get_collection("clothes")
# old_data = old_collection.get(include=['metadatas'])
# total_items = len(old_data['metadatas'])
# progress_bar = st.progress(0)
# status_text = st.empty()
# batch_size = 100
# successful_updates = 0
# failed_updates = 0
# for i in range(0, total_items, batch_size):
# batch = old_data['metadatas'][i:i + batch_size]
# images = []
# valid_metadatas = []
# valid_ids = []
# for metadata in batch:
# try:
# if 'image_url' in metadata:
# response = requests.get(metadata['image_url'])
# img = Image.open(io.BytesIO(response.content)).convert('RGB')
# images.append(np.array(img))
# valid_metadatas.append(metadata)
# valid_ids.append(metadata.get('id', str(hash(metadata['image_url']))))
# successful_updates += 1
# except Exception as e:
# logger.error(f"Error processing image: {e}")
# failed_updates += 1
# continue
# if images:
# collection.add(
# ids=valid_ids,
# images=images,
# metadatas=valid_metadatas
# )
# # Update progress
# progress = (i + len(batch)) / total_items
# progress_bar.progress(progress)
# status_text.text(f"Processing: {i + len(batch)}/{total_items} items. "
# f"Success: {successful_updates}, Failed: {failed_updates}")
# status_text.text(f"Update completed. Successfully processed: {successful_updates}, "
# f"Failed: {failed_updates}")
# return True
# except Exception as e:
# logger.error(f"Multimodal DB update error: {e}")
# return False
def show_similar_items(similar_items):
    """Render search results two per row: image, similarity, brand, name and pricing.

    Args:
        similar_items: list of dicts as returned by search_similar_items. Each
            must carry 'similarity_score' and may carry 'image_url', 'brand',
            'name', 'price', 'discount' and 'original_price'.
    """
    if not similar_items:
        st.warning("No similar items found.")
        return

    def format_price(value):
        # Numbers get thousands separators and the won suffix; anything else is shown as-is.
        if isinstance(value, (int, float)):
            return f"{value:,}원"
        return f"{value}"

    st.subheader("Similar Items:")
    items_per_row = 2
    for row_start in range(0, len(similar_items), items_per_row):
        cols = st.columns(items_per_row)
        # zip stops early on a short last row, leaving the spare column empty.
        for col, item in zip(cols, similar_items[row_start:row_start + items_per_row]):
            with col:
                try:
                    if 'image_url' in item:
                        st.image(item['image_url'], use_column_width=True)
                    st.markdown(f"**Similarity: {item['similarity_score']:.1f}%**")
                    st.write(f"Brand: {item.get('brand', 'Unknown')}")
                    name = str(item.get('name', 'Unknown'))
                    if len(name) > 50:
                        name = name[:47] + "..."
                    st.write(f"Name: {name}")
                    st.write(f"Price: {format_price(item.get('price', 0))}")
                    if item.get('discount'):
                        st.write(f"Discount: {item['discount']}%")
                    if 'original_price' in item:
                        # May be stored as a string; ':,' would raise on non-numbers.
                        st.write(f"Original: {format_price(item['original_price'])}")
                    st.divider()
                except Exception as e:
                    logger.error(f"Error displaying item: {e}")
                    st.error("Error displaying this item")
def process_search(image, mask, num_results):
    """Run search_similar_items while a spinner is shown.

    Returns the list of similar items, or None if the search raised.
    """
    try:
        with st.spinner("Finding similar items..."):
            return search_similar_items(image, mask, num_results)
    except Exception as err:
        logger.error(f"Search processing error: {err}")
        return None
def handle_file_upload():
    """on_change callback of the file uploader: decode the upload and advance the UI state.

    On success the RGB image is stored in ``st.session_state.image`` and
    ``upload_state`` becomes 'image_uploaded'. Streamlit reruns the script after
    a widget callback returns, so no explicit st.rerun() is needed; it is a
    no-op inside callbacks.
    """
    uploaded = st.session_state.uploaded_file
    if uploaded is None:
        return
    try:
        image = Image.open(uploaded).convert('RGB')
    except Exception as e:
        # UI boundary: report unreadable/corrupt files instead of crashing the callback.
        logger.error(f"Could not open uploaded image: {e}")
        st.error("Could not read the uploaded file as an image.")
        return
    st.session_state.image = image
    st.session_state.upload_state = 'image_uploaded'
def handle_detection():
    """on_click callback of "Detect Items": segment the stored image and save the result.

    Stores the list returned by process_segmentation in
    ``st.session_state.detected_items`` and sets ``upload_state`` to
    'items_detected'. The script reruns automatically after the callback,
    so st.rerun() (a no-op in callbacks) is not called.
    """
    if st.session_state.image is None:
        return
    st.session_state.detected_items = process_segmentation(st.session_state.image)
    st.session_state.upload_state = 'items_detected'
def handle_search():
    """Callback that records a search request in ``st.session_state['search_clicked']``."""
    st.session_state['search_clicked'] = True
def _render_admin_sidebar():
    """Sidebar admin panel with a button that rebuilds the multimodal database."""
    st.sidebar.title("Admin Controls")
    if not st.sidebar.checkbox("Show Admin Interface"):
        return
    if not st.sidebar.button("Update Database (Multimodal)"):
        return
    # update_db_with_multimodal may not be defined (its definition can be
    # commented out); resolve it at click time so the button reports an error
    # instead of crashing the app with NameError.
    update_db = globals().get("update_db_with_multimodal")
    if update_db is None:
        st.sidebar.error("Database update is not available in this build")
        return
    with st.spinner("Updating database with multimodal support..."):
        success = update_db()
    if success:
        st.sidebar.success("Database updated successfully!")
    else:
        st.sidebar.error("Failed to update database")


def _render_detected_items(detected_items):
    """Show each detected item's masked image, label and confidence in a two-column grid."""
    cols = st.columns(2)
    image_array = np.array(st.session_state.image)
    for idx, item in enumerate(detected_items):
        with cols[idx % 2]:
            try:
                label = item.get('label', 'Unknown')
                if item.get('mask') is not None:
                    masked_img = image_array * np.expand_dims(item['mask'], axis=2)
                    st.image(masked_img.astype(np.uint8), caption=f"Detected {label}")
                st.write(f"Item {idx + 1}: {label}")
                score = item.get('score')
                if isinstance(score, (int, float)):
                    st.write(f"Confidence: {score*100:.1f}%")
                else:
                    st.write("Confidence: N/A")
            except Exception as e:
                logger.error(f"Error displaying item {idx}: {str(e)}")
                st.error(f"Error displaying item {idx}")


def _render_search(detected_items):
    """Let the user pick a detected item, run a similarity search on click, and show the results."""
    valid_items = [i for i, item in enumerate(detected_items) if item.get('mask') is not None]
    if not valid_items:
        st.warning("No valid items detected for search.")
        return

    selected_idx = st.selectbox(
        "Select item to search:",
        valid_items,
        format_func=lambda i: f"{detected_items[i].get('label', 'Unknown')}",
        key='item_selector'
    )

    search_col1, search_col2 = st.columns([1, 2])
    with search_col1:
        search_clicked = st.button("Search Similar Items",
                                   key='search_button',
                                   type="primary")
    with search_col2:
        num_results = st.slider("Number of results:",
                                min_value=1,
                                max_value=20,
                                value=5,
                                key='num_results')

    if search_clicked:
        # Recompute on every click so results follow the current selection and count
        # (previously the first result set was cached forever).
        st.session_state.search_clicked = True
        st.session_state.search_results = process_search(st.session_state.image,
                                                         detected_items[selected_idx]['mask'],
                                                         num_results)

    # Results persist across reruns until the next click.
    if 'search_results' in st.session_state:
        if st.session_state.search_results:
            show_similar_items(st.session_state.search_results)
        else:
            st.warning("No similar items found.")


def main():
    """Streamlit entry point: upload an image, detect items, and search for similar products.

    Progress between reruns lives in ``st.session_state``: 'upload_state',
    'image', 'detected_items', 'search_clicked' and 'search_results'.
    """
    st.title("Fashion Search App")
    _render_admin_sidebar()
    st.divider()

    # File uploader: shown only until an image has been accepted.
    if st.session_state.upload_state == 'initial':
        st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
                         key='uploaded_file', on_change=handle_file_upload)

    if st.session_state.image is None:
        return

    st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)

    detected_items = st.session_state.detected_items
    if detected_items is None:
        st.button("Detect Items", key='detect_button', on_click=handle_detection)
    elif not detected_items:
        st.warning("No items detected in this image.")
    else:
        _render_detected_items(detected_items)
        _render_search(detected_items)

    # Always reachable once an image exists: wipe all session state and start over.
    if st.button("Start New Search", key='new_search'):
        for key in list(st.session_state.keys()):
            del st.session_state[key]
        st.rerun()
# Script entry point; `streamlit run` also executes this file as __main__.
if __name__ == "__main__":
    # Console marker that the script started ("시작" means "start"); the literal was mis-encoded before.
    print('시작')
    main()