Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -66,6 +66,7 @@ def get_text_embedding(text):
|
|
66 |
text_features = clip_model.encode_text(text_tokens)
|
67 |
text_features /= text_features.norm(dim=-1, keepdim=True)
|
68 |
return text_features.cpu().numpy()
|
|
|
69 |
def get_average_embedding(main_image_url, additional_image_urls):
|
70 |
embeddings = []
|
71 |
|
@@ -82,7 +83,7 @@ def get_average_embedding(main_image_url, additional_image_urls):
|
|
82 |
|
83 |
if embeddings:
|
84 |
avg_embedding = np.mean(embeddings, axis=0)
|
85 |
-
return avg_embedding.tolist()
|
86 |
else:
|
87 |
return None
|
88 |
|
@@ -105,7 +106,7 @@ def update_collection_embeddings():
|
|
105 |
try:
|
106 |
avg_embedding = get_average_embedding(main_image_url, additional_image_urls)
|
107 |
if avg_embedding is not None:
|
108 |
-
batch_embeddings.append(avg_embedding)
|
109 |
valid_ids.append(id)
|
110 |
else:
|
111 |
st.warning(f"Failed to generate embedding for item {id}")
|
@@ -123,9 +124,12 @@ def update_collection_embeddings():
|
|
123 |
st.error(f"Error updating embeddings: {str(e)}")
|
124 |
st.error(f"First embedding type: {type(batch_embeddings[0])}")
|
125 |
st.error(f"First embedding length: {len(batch_embeddings[0])}")
|
|
|
126 |
|
127 |
# 진행 상황 표시
|
128 |
st.progress((i + batch_size) / len(all_ids))
|
|
|
|
|
129 |
def find_similar_images(query_embedding, collection, top_k=5):
|
130 |
results = collection.query(
|
131 |
query_embeddings=[query_embedding.squeeze().tolist()],
|
|
|
66 |
text_features = clip_model.encode_text(text_tokens)
|
67 |
text_features /= text_features.norm(dim=-1, keepdim=True)
|
68 |
return text_features.cpu().numpy()
|
69 |
+
|
70 |
def get_average_embedding(main_image_url, additional_image_urls):
|
71 |
embeddings = []
|
72 |
|
|
|
83 |
|
84 |
if embeddings:
|
85 |
avg_embedding = np.mean(embeddings, axis=0)
|
86 |
+
return avg_embedding.tolist() if isinstance(avg_embedding, np.ndarray) else avg_embedding
|
87 |
else:
|
88 |
return None
|
89 |
|
|
|
106 |
try:
|
107 |
avg_embedding = get_average_embedding(main_image_url, additional_image_urls)
|
108 |
if avg_embedding is not None:
|
109 |
+
batch_embeddings.append(avg_embedding)
|
110 |
valid_ids.append(id)
|
111 |
else:
|
112 |
st.warning(f"Failed to generate embedding for item {id}")
|
|
|
124 |
st.error(f"Error updating embeddings: {str(e)}")
|
125 |
st.error(f"First embedding type: {type(batch_embeddings[0])}")
|
126 |
st.error(f"First embedding length: {len(batch_embeddings[0])}")
|
127 |
+
st.error(f"First embedding: {batch_embeddings[0][:10]}...") # 처음 10개 요소만 출력
|
128 |
|
129 |
# 진행 상황 표시
|
130 |
st.progress((i + batch_size) / len(all_ids))
|
131 |
+
|
132 |
+
|
133 |
def find_similar_images(query_embedding, collection, top_k=5):
|
134 |
results = collection.query(
|
135 |
query_embeddings=[query_embedding.squeeze().tolist()],
|