# itda / app.py
# leedoming — last commit "Update app.py" (b404c7a, verified)
import base64
import contextlib
import json
import os
import pickle
import tempfile
import time
from io import BytesIO

import cv2
import matplotlib.pyplot as plt
import numpy as np
import open_clip
import requests
import streamlit as st
import torch
from inference_sdk import InferenceHTTPClient
from PIL import Image

# Load model, evaluation transform, and tokenizer
@st.cache_resource
def load_model():
    """Load the Marqo FashionSigLIP model once per Streamlit server process.

    Returns:
        (model, preprocess_val, tokenizer, device): the model on ``device`` in
        eval mode, the deterministic evaluation image transform, the matching
        text tokenizer, and the ``torch.device`` in use.
    """
    model_name = 'hf-hub:Marqo/marqo-fashionSigLIP'
    # create_model_and_transforms returns (model, preprocess_train,
    # preprocess_val). The training transform applies random augmentation and
    # must not be used to embed images for retrieval.
    model, _preprocess_train, preprocess_val = open_clip.create_model_and_transforms(model_name)
    tokenizer = open_clip.get_tokenizer(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Inference only: switch off training-mode behaviour (e.g. dropout).
    model.eval()
    return model, preprocess_val, tokenizer, device


model, preprocess_val, tokenizer, device = load_model()
# Load the product catalogue
@st.cache_data
def load_data():
    """Return the parsed JSON catalogue stored in ``musinsa-final.json``.

    Wrapped in Streamlit's data cache so the file is not re-read on reruns.
    Presumably a list of product dicts (see process_database_cached).
    """
    catalogue_path = 'musinsa-final.json'
    with open(catalogue_path, 'r', encoding='utf-8') as catalogue_file:
        parsed = json.load(catalogue_file)
    return parsed


data = load_data()
def setup_roboflow_client(api_key):
    """Create a Roboflow ``InferenceHTTPClient`` authenticated with ``api_key``."""
    endpoint = "https://outline.roboflow.com"
    client = InferenceHTTPClient(api_url=endpoint, api_key=api_key)
    return client
def download_and_process_image(image_url, timeout=10):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        image_url: URL of the image to fetch.
        timeout: seconds to wait for the server. Without a timeout, ``requests``
            can block indefinitely on a stalled connection.

    Returns:
        The decoded ``PIL.Image`` in RGB mode, or ``None`` if downloading or
        decoding failed (the error is shown in the Streamlit UI).
    """
    try:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Normalise every non-RGB mode (RGBA, P, LA, L, CMYK, ...). Callers save
        # the result as JPEG, which cannot store alpha or palette images.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image
    except Exception as e:
        # Deliberate best-effort: one bad URL must not abort processing of the
        # whole catalogue, so report it and let the caller skip the item.
        st.error(f"Error downloading/processing image: {str(e)}")
        return None
def segment_image_and_get_categories(image_path, client):
    """Segment clothing with Roboflow and black out everything else.

    The image file is sent (base64-encoded) to the ``closet/1`` Roboflow model.
    Every predicted polygon is filled into a mask over an 800x600 resized copy
    of the image, and pixels outside all polygons are set to black.

    Args:
        image_path: path of an image file readable by OpenCV and PIL.
        client: Roboflow ``InferenceHTTPClient`` (see setup_roboflow_client).

    Returns:
        ``(segmented_image, categories)``. ``segmented_image`` is a PIL image and
        ``categories`` is a list of ``"class (confidence)"`` strings.
        - No predictions: the resized image is returned with an empty list.
        - Any error: the error is shown in the UI, and the file is reopened
          unmodified and returned with an empty list.
    """
    try:
        with open(image_path, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode('utf-8')

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise ValueError(f"OpenCV could not read image file: {image_path}")
        original_height, original_width = image.shape[:2]
        image = cv2.resize(image, (800, 600))
        mask = np.zeros(image.shape, dtype=np.uint8)

        results = client.infer(encoded_image, model_id="closet/1")
        # The client may hand back a dict or a JSON string. Parse it once and use
        # the dict everywhere below. Calling .get() on the raw value would fail
        # for the string case.
        result_dict = results if isinstance(results, dict) else json.loads(results)
        predictions = result_dict.get('predictions', [])

        categories = []
        if predictions:
            # Polygon points are in the coordinate space of the image Roboflow
            # processed; map them onto the 800x600 working copy. If the response
            # omits its size, fall back to the original file's dimensions.
            image_meta = result_dict.get('image') or {}
            scale_x = image.shape[1] / (image_meta.get('width') or original_width)
            scale_y = image.shape[0] / (image_meta.get('height') or original_height)
            for prediction in predictions:
                points = prediction['points']
                # Scale in floating point, then truncate to pixel coordinates.
                pts = np.array([[p['x'], p['y']] for p in points], np.float64)
                pts = (pts * [scale_x, scale_y]).astype(np.int32).reshape((-1, 1, 2))
                cv2.fillPoly(mask, [pts], color=(255, 255, 255))
                category = prediction.get('class', 'Unknown')
                confidence = prediction.get('confidence', 0)
                categories.append(f"{category} ({confidence:.2f})")
            segmented_image = cv2.bitwise_and(image, mask)
        else:
            st.warning("No predictions found in the image. Returning original image.")
            segmented_image = image
        # OpenCV images are BGR; PIL expects RGB.
        return Image.fromarray(cv2.cvtColor(segmented_image, cv2.COLOR_BGR2RGB)), categories
    except Exception as e:
        st.error(f"Error in segmentation: {str(e)}")
        return Image.open(image_path), []
def get_image_embedding(image):
    """Embed a PIL image with the loaded model and L2-normalise the result.

    Relies on the module-level ``model``, ``preprocess_val`` and ``device``.
    The returned NumPy array keeps the leading batch dimension of size 1 that
    is added before encoding.
    """
    batch = preprocess_val(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.encode_image(batch)
        norms = features.norm(dim=-1, keepdim=True)
        features /= norms
    return features.cpu().numpy()
@st.cache_data
def process_database_cached(data):
database_info = []
for item in data:
image_url = item['์ด๋ฏธ์ง€ ๋งํฌ'][0]
product_id = item.get('\ufeff์ƒํ’ˆ ID') or item.get('์ƒํ’ˆ ID')
image = download_and_process_image(image_url)
if image is None:
continue
temp_path = f"temp_{product_id}.jpg"
image.save(temp_path, 'JPEG')
database_info.append({
'id': product_id,
'category': item['์นดํ…Œ๊ณ ๋ฆฌ'],
'brand': item['๋ธŒ๋žœ๋“œ๋ช…'],
'name': item['์ œํ’ˆ๋ช…'],
'price': item['์ •๊ฐ€'],
'discount': item['ํ• ์ธ์œจ'],
'image_url': image_url,
'temp_path': temp_path
})
return database_info
def _load_segmentation_cache(cache_file):
    """Return the pickled ``(segmented_image, categories)`` pair, or None if unusable."""
    # NOTE(security): pickle.load can execute arbitrary code. That is only
    # acceptable because this app writes these files itself; never point the
    # cache directory at untrusted data.
    try:
        with open(cache_file, 'rb') as f:
            return pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError):
        # Missing file, or one truncated by an interrupted write: recompute.
        return None


def process_database(client, data):
    """Segment and embed every catalogue image.

    Segmentations are cached on disk in ``segmentation_cache/``, keyed by
    product id, so the Roboflow API is only called for uncached products.

    Args:
        client: Roboflow ``InferenceHTTPClient`` used for uncached images.
        data: catalogue entries as returned by load_data().

    Returns:
        ``(embeddings, database_info)``.
        - ``embeddings`` is a NumPy array whose rows align by index with
          ``database_info``.
        - ``database_info`` is the metadata list from
          process_database_cached, with a ``categories`` list added to each item.

    Raises:
        ValueError: if no catalogue image could be processed.
    """
    database_info = process_database_cached(data)
    cache_dir = "segmentation_cache"
    os.makedirs(cache_dir, exist_ok=True)
    database_embeddings = []
    for item in database_info:
        cache_file = os.path.join(cache_dir, f"{item['id']}_segmented.pkl")
        cached = _load_segmentation_cache(cache_file)
        # A cached entry without categories may record a failed segmentation,
        # so it is treated as a miss and recomputed.
        if cached is not None and cached[1]:
            segmented_image, categories = cached
        else:
            segmented_image, categories = segment_image_and_get_categories(item['temp_path'], client)
            # segment_image_and_get_categories returns an empty category list
            # both for "no predictions" and for API/IO errors (e.g. a wrong API
            # key). Persist only real detections so a failure is not cached
            # forever as the product's segmented image.
            if categories:
                with open(cache_file, 'wb') as f:
                    pickle.dump((segmented_image, categories), f)
        database_embeddings.append(get_image_embedding(segmented_image))
        item['categories'] = categories
    if not database_embeddings:
        raise ValueError("No catalogue images could be processed; the product database is empty.")
    return np.vstack(database_embeddings), database_info
def find_similar_images(query_embedding, database_embeddings, database_info, top_k=5):
    """Return the ``top_k`` catalogue entries most similar to a query embedding.

    Similarity is the dot product of the query with each database row. This
    equals cosine similarity when both are L2-normalised, as
    get_image_embedding produces.

    Args:
        query_embedding: array of shape ``(1, d)`` or ``(d,)``.
        database_embeddings: array of shape ``(N, d)``, rows aligned with
            ``database_info``.
        database_info: sequence of N metadata entries.
        top_k: maximum number of results; values <= 0 give an empty list.

    Returns:
        A list of ``{'info': ..., 'similarity': ...}`` dicts, most similar first.
    """
    # reshape(-1) instead of squeeze(): squeeze() collapses a one-item database
    # to a 0-d scalar, which the reversing/indexing below cannot handle.
    similarities = np.dot(database_embeddings, query_embedding.T).reshape(-1)
    top_indices = np.argsort(similarities)[::-1][:max(top_k, 0)]
    return [
        {'info': database_info[idx], 'similarity': similarities[idx]}
        for idx in top_indices
    ]
# Streamlit app
st.title("Fashion Search App with Segmentation and Category Detection")

# API Key input (entered per session rather than stored in the source)
api_key = st.text_input("Enter your Roboflow API Key", type="password")
if api_key:
    CLIENT = setup_roboflow_client(api_key)
    # Initialize database_embeddings and database_info.
    # NOTE(review): Streamlit reruns this script on every interaction, so the
    # whole catalogue is re-embedded each time; consider caching this result.
    database_embeddings, database_info = process_database(CLIENT, data)

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        # Normalise to RGB: PNG uploads may be RGBA/palette, which JPEG cannot store.
        image = Image.open(uploaded_file).convert('RGB')
        st.image(image, caption='Uploaded Image', use_column_width=True)

        if st.button('Find Similar Items'):
            with st.spinner('Processing...'):
                # Use a unique temp file per request. Streamlit serves every
                # session from one process, so a fixed filename could be
                # overwritten by a concurrent user's upload.
                with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
                    temp_path = tmp_file.name
                try:
                    image.save(temp_path, 'JPEG')
                    segmented_image, input_categories = segment_image_and_get_categories(temp_path, CLIENT)
                    st.image(segmented_image, caption='Segmented Image', use_column_width=True)
                    st.subheader("Detected Categories in Input Image:")
                    for category in input_categories:
                        st.write(category)
                    query_embedding = get_image_embedding(segmented_image)
                finally:
                    # Best-effort cleanup; the file is only needed for this request.
                    with contextlib.suppress(OSError):
                        os.remove(temp_path)

                similar_images = find_similar_images(query_embedding, database_embeddings, database_info)
                st.subheader("Similar Items:")
                for match in similar_images:
                    info = match['info']
                    col1, col2 = st.columns(2)
                    with col1:
                        st.image(info['image_url'], use_column_width=True)
                    with col2:
                        st.write(f"Name: {info['name']}")
                        st.write(f"Brand: {info['brand']}")
                        st.write(f"Category: {info['category']}")
                        st.write(f"Price: {info['price']}")
                        st.write(f"Discount: {info['discount']}%")
                        st.write(f"Similarity: {match['similarity']:.2f}")
                        st.write("Detected Categories:")
                        for category in info['categories']:
                            st.write(category)
else:
    st.warning("Please enter your Roboflow API Key to use the app.")