leedoming commited on
Commit
5db0821
·
verified ·
1 Parent(s): 3d70c20

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -0
app.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import open_clip
3
+ import torch
4
+ import requests
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ import time
8
+ import json
9
+ import numpy as np
10
+
11
+ # Load model and tokenizer
12
+ @st.cache_resource
13
+ def load_model():
14
+ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('hf-hub:Marqo/marqo-fashionSigLIP')
15
+ tokenizer = open_clip.get_tokenizer('hf-hub:Marqo/marqo-fashionSigLIP')
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ model.to(device)
18
+ return model, preprocess_val, tokenizer, device
19
+
20
+ model, preprocess_val, tokenizer, device = load_model()
21
+
22
+ # Load and process data
23
+ @st.cache_data
24
+ def load_data():
25
+ with open('./musinsa-final.json', 'r', encoding='utf-8') as f:
26
+ return json.load(f)
27
+
28
+ data = load_data()
29
+
30
+ # Helper functions
31
+ def load_image_from_url(url, max_retries=3):
32
+ for attempt in range(max_retries):
33
+ try:
34
+ response = requests.get(url, timeout=10)
35
+ response.raise_for_status()
36
+ img = Image.open(BytesIO(response.content)).convert('RGB')
37
+ return img
38
+ except (requests.RequestException, Image.UnidentifiedImageError) as e:
39
+ #st.warning(f"Attempt {attempt + 1} failed: {str(e)}")
40
+ if attempt < max_retries - 1:
41
+ time.sleep(1)
42
+ else:
43
+ #st.error(f"Failed to load image from {url} after {max_retries} attempts")
44
+ return None
45
+
46
+ def get_image_embedding_from_url(image_url):
47
+ image = load_image_from_url(image_url)
48
+ if image is None:
49
+ return None
50
+
51
+ image_tensor = preprocess_val(image).unsqueeze(0).to(device)
52
+
53
+ with torch.no_grad():
54
+ image_features = model.encode_image(image_tensor)
55
+ image_features /= image_features.norm(dim=-1, keepdim=True)
56
+
57
+ return image_features.cpu().numpy()
58
+
59
+ @st.cache_data
60
+ def process_database():
61
+ database_embeddings = []
62
+ database_info = []
63
+
64
+ for item in data:
65
+ image_url = item['이미지 링크'][0]
66
+ embedding = get_image_embedding_from_url(image_url)
67
+
68
+ if embedding is not None:
69
+ database_embeddings.append(embedding)
70
+ database_info.append({
71
+ 'id': item['\ufeff상품 ID'],
72
+ 'category': item['카테고리'],
73
+ 'brand': item['브랜드명'],
74
+ 'name': item['제품명'],
75
+ 'price': item['정가'],
76
+ 'discount': item['할인율'],
77
+ 'image_url': image_url
78
+ })
79
+ else:
80
+ st.warning(f"Skipping item {item['상품 ID']} due to image loading failure")
81
+
82
+ if database_embeddings:
83
+ return np.vstack(database_embeddings), database_info
84
+ else:
85
+ st.error("No valid embeddings were generated.")
86
+ return None, None
87
+
88
+ database_embeddings, database_info = process_database()
89
+
90
+ def get_text_embedding(text):
91
+ text_tokens = tokenizer([text]).to(device)
92
+
93
+ with torch.no_grad():
94
+ text_features = model.encode_text(text_tokens)
95
+ text_features /= text_features.norm(dim=-1, keepdim=True)
96
+
97
+ return text_features.cpu().numpy()
98
+
99
+ def find_similar_images(query_embedding, top_k=5):
100
+ similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
101
+ top_indices = np.argsort(similarities)[::-1][:top_k]
102
+
103
+ results = []
104
+ for idx in top_indices:
105
+ results.append({
106
+ 'info': database_info[idx],
107
+ 'similarity': similarities[idx]
108
+ })
109
+
110
+ return results
111
+
112
+ # Streamlit app
113
+ st.title("Fashion Search App")
114
+
115
+ search_type = st.radio("Search by:", ("Image URL", "Text"))
116
+
117
+ if search_type == "Image URL":
118
+ query_image_url = st.text_input("Enter image URL:")
119
+ if st.button("Search by Image"):
120
+ if query_image_url:
121
+ query_embedding = get_image_embedding_from_url(query_image_url)
122
+ if query_embedding is not None:
123
+ similar_images = find_similar_images(query_embedding)
124
+ st.image(query_image_url, caption="Query Image", use_column_width=True)
125
+ st.subheader("Similar Items:")
126
+ for img in similar_images:
127
+ col1, col2 = st.columns(2)
128
+ with col1:
129
+ st.image(img['info']['image_url'], use_column_width=True)
130
+ with col2:
131
+ st.write(f"Name: {img['info']['name']}")
132
+ st.write(f"Brand: {img['info']['brand']}")
133
+ st.write(f"Category: {img['info']['category']}")
134
+ st.write(f"Price: {img['info']['price']}")
135
+ st.write(f"Discount: {img['info']['discount']}%")
136
+ st.write(f"Similarity: {img['similarity']:.2f}")
137
+ else:
138
+ st.error("Failed to process the image. Please try another URL.")
139
+ else:
140
+ st.warning("Please enter an image URL.")
141
+
142
+ else: # Text search
143
+ query_text = st.text_input("Enter search text:")
144
+ if st.button("Search by Text"):
145
+ if query_text:
146
+ text_embedding = get_text_embedding(query_text)
147
+ similar_images = find_similar_images(text_embedding)
148
+ st.subheader("Similar Items:")
149
+ for img in similar_images:
150
+ col1, col2 = st.columns(2)
151
+ with col1:
152
+ st.image(img['info']['image_url'], use_column_width=True)
153
+ with col2:
154
+ st.write(f"Name: {img['info']['name']}")
155
+ st.write(f"Brand: {img['info']['brand']}")
156
+ st.write(f"Category: {img['info']['category']}")
157
+ st.write(f"Price: {img['info']['price']}")
158
+ st.write(f"Discount: {img['info']['discount']}%")
159
+ st.write(f"Similarity: {img['similarity']:.2f}")
160
+ else:
161
+ st.warning("Please enter a search text.")