|
import streamlit as st |
|
import open_clip |
|
import torch |
|
from PIL import Image |
|
import numpy as np |
|
from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation |
|
import chromadb |
|
import logging |
|
import io |
|
import requests |
|
from concurrent.futures import ThreadPoolExecutor |
|
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction |
|
from chromadb.utils.data_loaders import ImageLoader |
|
|
|
|
|
# Root logging at INFO level, plus a logger named after this module for the
# rest of the file to use.
_LOG_LEVEL = logging.INFO
logging.basicConfig(level=_LOG_LEVEL)
logger = logging.getLogger(name=__name__)
|
|
|
class CustomFashionEmbeddingFunction:
    """Chroma embedding function for fashion images.

    Each image is embedded as a blend of Marqo fashion-SigLIP image features
    and an HSV colour histogram:

        normalize(CLIP_WEIGHT * normalize(clip)
                  + COLOR_WEIGHT * normalize(repeat(hist, COLOR_REPEAT)))

    where ``clip`` is zero-padded or truncated to ``EMBEDDING_DIM`` values.
    Changing any constant below changes the produced vectors, so existing
    collections would have to be re-embedded.
    """

    # Length of every output vector; CLIP features are padded/truncated to it.
    EMBEDDING_DIM = 768
    # Bins per H/S/V channel in the colour histogram.
    COLOR_HIST_BINS = 8
    # H, S and V histograms concatenated.
    COLOR_HIST_DIM = 3 * COLOR_HIST_BINS
    # Repeat factor that stretches the histogram to EMBEDDING_DIM (768 // 24 = 32).
    COLOR_REPEAT = EMBEDDING_DIM // COLOR_HIST_DIM
    CLIP_WEIGHT = 0.7
    COLOR_WEIGHT = 0.3
    # Seconds to wait when downloading an image given as a URL.
    REQUEST_TIMEOUT = 30

    def __init__(self):
        """Load the Marqo fashion-SigLIP model and its evaluation transform."""
        self.model, _, self.preprocess = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # open_clip returns models in train mode; switch to eval mode so layers
        # like dropout behave deterministically at inference time.
        self.model = self.model.to(self.device).eval()

    def __call__(self, input):
        """Embed a batch of images for Chroma.

        The parameter must be named ``input`` because Chroma calls embedding
        functions with that keyword.

        Args:
            input: sequence of images. Each is a URL string (downloaded),
                encoded image bytes, a numpy array, or an object that
                ``self.preprocess`` accepts directly (e.g. a PIL image).

        Returns:
            ndarray of shape ``(len(input), EMBEDDING_DIM)``; every row is
            L2-normalised.

        Raises:
            Any download / decoding / model error, after logging it.
        """
        try:
            # Decode each input exactly once; both feature extractors reuse it.
            images = [self._to_pil(img) for img in input]
            if not images:
                return np.empty((0, self.EMBEDDING_DIM))

            clip_features = self._encode_clip(images)
            return np.array([
                self._combine(clip_emb, self.extract_color_histogram(img))
                for clip_emb, img in zip(clip_features, images)
            ])
        except Exception as e:
            logger.exception(f"Error in embedding function: {e}")
            raise

    def _to_pil(self, img):
        """Decode one input into an RGB PIL image.

        URL strings are downloaded, bytes are decoded, and numpy arrays are
        cast to uint8. Any other object is returned unchanged.
        """
        if isinstance(img, str):
            response = requests.get(img, timeout=self.REQUEST_TIMEOUT)
            # Fail with a clear HTTP error instead of a confusing decode error.
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content)).convert('RGB')
        if isinstance(img, bytes):
            return Image.open(io.BytesIO(img)).convert('RGB')
        if isinstance(img, np.ndarray):
            return Image.fromarray(img.astype('uint8')).convert('RGB')
        return img

    def _encode_clip(self, images):
        """Run the SigLIP image encoder on a batch; return features as a CPU numpy array."""
        batch = torch.cat([self.preprocess(img).unsqueeze(0) for img in images]).to(self.device)
        with torch.no_grad():
            features = self.model.encode_image(batch)
        return features.cpu().numpy()

    def _combine(self, clip_emb, color_feat):
        """Blend one CLIP vector and one colour histogram into a unit vector of EMBEDDING_DIM values."""
        dim = self.EMBEDDING_DIM
        if clip_emb.shape[0] < dim:
            clip_emb = np.concatenate([clip_emb, np.zeros(dim - clip_emb.shape[0])])
        else:
            clip_emb = clip_emb[:dim]

        # Repeat each histogram bin so the colour vector matches the CLIP vector's length.
        color_expanded = np.repeat(color_feat, self.COLOR_REPEAT)

        clip_emb = clip_emb / (np.linalg.norm(clip_emb) + 1e-8)
        color_expanded = color_expanded / (np.linalg.norm(color_expanded) + 1e-8)

        combined = clip_emb * self.CLIP_WEIGHT + color_expanded * self.COLOR_WEIGHT
        return combined / (np.linalg.norm(combined) + 1e-8)

    def extract_color_histogram(self, image):
        """Return an HSV colour histogram of ``image`` as a flat vector.

        ``image`` may be a URL string, encoded bytes, a numpy array or a PIL
        image. The result concatenates the H, S and V histograms
        (``COLOR_HIST_BINS`` bins each over 0-255), each normalised to sum to ~1.
        On any failure the error is logged and a zero vector of length
        ``COLOR_HIST_DIM`` is returned, so one bad image does not abort a batch.
        """
        try:
            # Convert to RGB first so the HSV conversion always starts from a
            # 3-channel image.
            hsv_pixels = np.array(self._to_pil(image).convert('RGB').convert('HSV'))

            histograms = []
            for channel in range(3):
                counts = np.histogram(hsv_pixels[:, :, channel], bins=self.COLOR_HIST_BINS, range=(0, 256))[0]
                histograms.append(counts / (counts.sum() + 1e-8))
            return np.concatenate(histograms)
        except Exception as e:
            logger.error(f"Color histogram extraction error: {e}")
            return np.zeros(self.COLOR_HIST_DIM)
|
|
|
|
|
def _init_session_state():
    """Seed the Streamlit session keys this app reads.

    Only missing keys are written, so values from earlier runs in the same
    session are preserved.
    """
    defaults = (
        ('image', None),
        ('detected_items', None),
        ('selected_item_index', None),
        ('upload_state', 'initial'),
        ('search_clicked', False),
    )
    for key, value in defaults:
        if key not in st.session_state:
            st.session_state[key] = value


_init_session_state()
|
|
|
|
|
@st.cache_resource
def load_segmentation_model():
    """Load the SegFormer clothes-segmentation model and its image processor.

    The result is cached by ``st.cache_resource``. The model is moved to CUDA
    when a GPU is available.

    Returns:
        ``(model, image_processor)``.

    Raises:
        Any loading error, after logging it.
    """
    checkpoint = "mattmdjaga/segformer_b2_clothes"
    try:
        processor = AutoImageProcessor.from_pretrained(checkpoint)
        seg_model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint)
        if torch.cuda.is_available():
            seg_model = seg_model.to('cuda')
    except Exception as err:
        logger.error(f"Error loading segmentation model: {err}")
        raise
    return seg_model, processor
|
|
|
|
|
def setup_multimodal_collection():
    """Open, or create if missing, the persistent multimodal Chroma collection.

    The collection lives in ``./fashion_multimodal_db``. It is wired to
    ``CustomFashionEmbeddingFunction`` and an ``ImageLoader``.

    Returns:
        The ``fashion_multimodal_v2`` collection.

    Raises:
        Any error from creating the client, the embedding function or the
        collection, after logging it.
    """
    # One name for both lookup and creation. It must be one of the names
    # search_similar_items queries, otherwise newly created data is never searched.
    collection_name = "fashion_multimodal_v2"

    try:
        client = chromadb.PersistentClient(path="./fashion_multimodal_db")
        embedding_function = CustomFashionEmbeddingFunction()
        data_loader = ImageLoader()

        try:
            collection = client.get_collection(
                name=collection_name,
                embedding_function=embedding_function,
                data_loader=data_loader,
            )
            logger.info(f"Successfully connected to existing {collection_name} collection")
            return collection
        except Exception as e:
            # The exception type for a missing collection differs across Chroma
            # versions, so any lookup failure falls through to creation.
            logger.warning(f"Could not open collection {collection_name} ({e}); creating it")

        collection = client.create_collection(
            name=collection_name,
            embedding_function=embedding_function,
            data_loader=data_loader,
        )
        logger.info(f"Created new {collection_name} collection")
        return collection

    except Exception as e:
        logger.error(f"Error setting up multimodal collection: {e}")
        raise
|
|
|
def process_segmentation(image):
    """Segment clothing in ``image`` and return one entry per predicted class.

    Args:
        image: PIL image. ``image.size`` is read as (width, height) to upsample
            the logits back to the input resolution.

    Returns:
        A list of dicts ``{'label': 'Item_<class id>', 'score': 1.0, 'mask': mask}``.
        ``mask`` is a float array of 0.0/1.0 with the image's height and width.
        There is one entry per class id present in the prediction, except
        class 0, which is skipped (presumably background; confirm against the
        model's label map). ``score`` is a constant placeholder, not a model
        confidence. Returns ``[]`` on any error; the traceback is logged.
    """
    try:
        model, image_processor = load_segmentation_model()

        inputs = image_processor(image, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.to('cuda') for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        # PIL size is (W, H); interpolate expects (H, W).
        upsampled_logits = torch.nn.functional.interpolate(
            outputs.logits.cpu(),
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False,
        )

        # A single image was processed, so take the only map in the batch.
        seg_map = upsampled_logits.argmax(dim=1)[0].numpy()

        processed_items = [
            {
                'label': f"Item_{label_idx}",
                'score': 1.0,
                'mask': (seg_map == label_idx).astype(float),
            }
            for label_idx in np.unique(seg_map)
            if label_idx != 0
        ]

        logger.info(f"Successfully processed {len(processed_items)} segments")
        return processed_items

    except Exception as e:
        # logger.exception records the traceback without a manual traceback import.
        logger.exception(f"Segmentation error: {str(e)}")
        return []
|
|
|
def search_similar_items(image, mask=None, top_k=10): |
|
"""๋ ๊ฐ์ ๋ฉํฐ๋ชจ๋ฌ ์ปฌ๋ ์
์์ ๊ฒ์ ์ํ""" |
|
try: |
|
client = chromadb.PersistentClient(path="./fashion_multimodal_db") |
|
embedding_function = CustomFashionEmbeddingFunction() |
|
data_loader = ImageLoader() |
|
|
|
|
|
collections = [] |
|
collection_names = ["fashion_multimodal", "fashion_multimodal_v2"] |
|
|
|
for name in collection_names: |
|
try: |
|
collection = client.get_collection( |
|
name=name, |
|
embedding_function=embedding_function, |
|
data_loader=data_loader |
|
) |
|
collections.append(collection) |
|
logger.info(f"Successfully connected to {name} collection") |
|
except Exception as e: |
|
logger.error(f"Error getting collection {name}: {e}") |
|
continue |
|
|
|
if not collections: |
|
logger.error("No collections available for search") |
|
return [] |
|
|
|
|
|
if mask is not None: |
|
mask_3d = np.stack([mask] * 3, axis=-1) |
|
masked_image = np.array(image) * mask_3d |
|
query_image = Image.fromarray(masked_image.astype(np.uint8)) |
|
else: |
|
query_image = image |
|
|
|
|
|
all_results = [] |
|
|
|
for collection in collections: |
|
try: |
|
results = collection.query( |
|
query_images=[np.array(query_image)], |
|
n_results=top_k, |
|
include=['metadatas', 'distances'] |
|
) |
|
|
|
if results and 'metadatas' in results: |
|
for metadata, distance in zip(results['metadatas'][0], results['distances'][0]): |
|
|
|
cosine_similarity = 1 - (distance ** 2 / 2) |
|
similarity_score = ((cosine_similarity + 1) / 2) * 100 |
|
|
|
item_data = metadata.copy() |
|
item_data['similarity_score'] = similarity_score |
|
all_results.append(item_data) |
|
|
|
except Exception as e: |
|
logger.error(f"Error searching in collection: {e}") |
|
continue |
|
|
|
|
|
|
|
seen_urls = set() |
|
unique_results = [] |
|
|
|
for item in sorted(all_results, key=lambda x: x['similarity_score'], reverse=True): |
|
url = item.get('image_url', '') |
|
if url not in seen_urls: |
|
seen_urls.add(url) |
|
unique_results.append(item) |
|
|
|
|
|
if len(unique_results) >= top_k: |
|
break |
|
|
|
return unique_results |
|
|
|
except Exception as e: |
|
logger.error(f"Multimodal search error: {e}") |
|
return [] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def show_similar_items(similar_items): |
|
"""Display similar items in a structured format with similarity scores""" |
|
if not similar_items: |
|
st.warning("No similar items found.") |
|
return |
|
|
|
st.subheader("Similar Items:") |
|
|
|
items_per_row = 2 |
|
for i in range(0, len(similar_items), items_per_row): |
|
cols = st.columns(items_per_row) |
|
for j, col in enumerate(cols): |
|
if i + j < len(similar_items): |
|
item = similar_items[i + j] |
|
with col: |
|
try: |
|
if 'image_url' in item: |
|
st.image(item['image_url'], use_column_width=True) |
|
|
|
st.markdown(f"**Similarity: {item['similarity_score']:.1f}%**") |
|
|
|
st.write(f"Brand: {item.get('brand', 'Unknown')}") |
|
name = item.get('name', 'Unknown') |
|
if len(name) > 50: |
|
name = name[:47] + "..." |
|
st.write(f"Name: {name}") |
|
|
|
price = item.get('price', 0) |
|
if isinstance(price, (int, float)): |
|
st.write(f"Price: {price:,}์") |
|
else: |
|
st.write(f"Price: {price}") |
|
|
|
if 'discount' in item and item['discount']: |
|
st.write(f"Discount: {item['discount']}%") |
|
if 'original_price' in item: |
|
st.write(f"Original: {item['original_price']:,}์") |
|
|
|
st.divider() |
|
|
|
except Exception as e: |
|
logger.error(f"Error displaying item: {e}") |
|
st.error("Error displaying this item") |
|
|
|
def process_search(image, mask, num_results):
    """Run the similar-item search while showing a Streamlit spinner.

    Returns whatever ``search_similar_items`` returns, or None if an exception
    escapes (the exception is logged).
    """
    try:
        with st.spinner("Finding similar items..."):
            found = search_similar_items(image, mask, num_results)
    except Exception as exc:
        logger.error(f"Search processing error: {exc}")
        return None
    return found
|
|
|
def handle_file_upload():
    """``on_change`` callback for the uploader: store the image as RGB and advance the UI state.

    Streamlit reruns the script automatically after a widget callback, and
    ``st.rerun()`` inside a callback is a no-op, so no explicit rerun is issued.
    """
    uploaded = st.session_state.uploaded_file
    if uploaded is None:
        return
    st.session_state.image = Image.open(uploaded).convert('RGB')
    st.session_state.upload_state = 'image_uploaded'
|
|
|
def handle_detection():
    """``on_click`` callback for "Detect Items": segment the stored image and advance the UI state.

    No explicit ``st.rerun()``: Streamlit reruns after the callback anyway, and
    calling it inside a callback is a no-op.
    """
    if st.session_state.image is None:
        return
    st.session_state.detected_items = process_segmentation(st.session_state.image)
    st.session_state.upload_state = 'items_detected'
|
|
|
def handle_search():
    """Callback that sets the ``search_clicked`` session flag to True.

    Not attached to any widget in the visible code.
    """
    st.session_state["search_clicked"] = True
|
|
|
def _render_admin_sidebar():
    """Sidebar with an optional admin panel containing the database-update button."""
    st.sidebar.title("Admin Controls")
    if not st.sidebar.checkbox("Show Admin Interface"):
        return
    if st.sidebar.button("Update Database (Multimodal)"):
        with st.spinner("Updating database with multimodal support..."):
            # NOTE(review): update_db_with_multimodal is not defined anywhere in the
            # visible source; confirm it exists, otherwise this raises NameError.
            success = update_db_with_multimodal()
        if success:
            st.sidebar.success("Database updated successfully!")
        else:
            st.sidebar.error("Failed to update database")
    st.divider()


def _render_detected_items(detected_items):
    """Show each detected segment (masked preview, label, score) in a two-column grid."""
    cols = st.columns(2)
    image_array = np.array(st.session_state.image)
    for idx, item in enumerate(detected_items):
        with cols[idx % 2]:
            try:
                label = item.get('label', 'Unknown')
                mask = item.get('mask')
                if mask is not None:
                    # Black out every pixel outside this segment.
                    masked_img = image_array * np.expand_dims(mask, axis=2)
                    st.image(masked_img.astype(np.uint8), caption=f"Detected {label}")

                st.write(f"Item {idx + 1}: {label}")
                score = item.get('score')
                if isinstance(score, (int, float)):
                    st.write(f"Confidence: {score*100:.1f}%")
                else:
                    st.write("Confidence: N/A")
            except Exception as e:
                logger.error(f"Error displaying item {idx}: {str(e)}")
                st.error(f"Error displaying item {idx}")


def _render_search_section(detected_items):
    """Render the item picker, search button and result-count slider, then the results."""
    valid_items = [i for i, item in enumerate(detected_items) if item.get('mask') is not None]
    if not valid_items:
        st.warning("No valid items detected for search.")
        return

    selected_idx = st.selectbox(
        "Select item to search:",
        valid_items,
        format_func=lambda i: f"{detected_items[i].get('label', 'Unknown')}",
        key='item_selector',
    )

    search_col1, search_col2 = st.columns([1, 2])
    with search_col1:
        search_clicked = st.button("Search Similar Items",
                                   key='search_button',
                                   type="primary")
    with search_col2:
        num_results = st.slider("Number of results:",
                                min_value=1,
                                max_value=20,
                                value=5,
                                key='num_results')

    # Once a search has been requested, keep showing results on later reruns.
    if not (search_clicked or st.session_state.get('search_clicked', False)):
        return
    st.session_state.search_clicked = True

    selected_item = detected_items[selected_idx]
    if selected_item.get('mask') is None:
        st.error("Selected item has no valid mask for search.")
        return

    # Cached results are only valid for the (item, count) they were computed
    # for. Search again when the button is pressed or either value changes.
    query = (selected_idx, num_results)
    if search_clicked or st.session_state.get('search_query') != query:
        st.session_state.search_results = process_search(st.session_state.image,
                                                         selected_item['mask'],
                                                         num_results)
        st.session_state.search_query = query

    if st.session_state.search_results:
        show_similar_items(st.session_state.search_results)
    else:
        st.warning("No similar items found.")


def main():
    """Render the Fashion Search App: upload an image, detect items, search for similar items."""
    st.title("Fashion Search App")

    _render_admin_sidebar()

    if st.session_state.upload_state == 'initial':
        # handle_file_upload stores the image and moves upload_state away from
        # 'initial', which hides this uploader on later runs.
        st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'],
                         key='uploaded_file', on_change=handle_file_upload)

    if st.session_state.image is not None:
        st.image(st.session_state.image, caption="Uploaded Image", use_column_width=True)
        if st.session_state.detected_items is None:
            st.button("Detect Items", key='detect_button', on_click=handle_detection)

    detected_items = st.session_state.detected_items
    if detected_items:
        _render_detected_items(detected_items)
        _render_search_section(detected_items)
    elif detected_items is not None:
        # Detection ran but returned no items (process_segmentation also returns [] on error).
        st.warning("No items detected in this image.")

    if st.button("Start New Search", key='new_search'):
        for key in list(st.session_state.keys()):
            del st.session_state[key]
        st.rerun()
|
|
|
def _run_as_script():
    """Script entry point: print a start marker to stdout, then render the app."""
    # NOTE(review): the marker literal appears to be mis-encoded text; kept verbatim.
    print('์์')
    main()


if __name__ == "__main__":
    _run_as_script()