Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -52,12 +52,13 @@ def load_image_from_url(url, max_retries=3):
|
|
52 |
else:
|
53 |
return None
|
54 |
|
|
|
55 |
def get_image_embedding(image):
    """Encode ``image`` with ``clip_model`` and return normalised features.

    The image is passed through ``preprocess_val``, given a leading batch
    axis with ``unsqueeze(0)``, and moved to ``device`` before encoding.

    Returns:
        The encoded features, divided by their norm along the last axis,
        as a NumPy array. No squeeze is applied in this version.
    """
    batch = preprocess_val(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = clip_model.encode_image(batch)
        # In-place scaling by the norm taken over the last dimension.
        scale = features.norm(dim=-1, keepdim=True)
        features /= scale
    return features.cpu().numpy()
|
61 |
|
62 |
def get_text_embedding(text):
|
63 |
text_tokens = tokenizer([text]).to(device)
|
@@ -65,7 +66,6 @@ def get_text_embedding(text):
|
|
65 |
text_features = clip_model.encode_text(text_tokens)
|
66 |
text_features /= text_features.norm(dim=-1, keepdim=True)
|
67 |
return text_features.cpu().numpy()
|
68 |
-
|
69 |
def get_average_embedding(main_image_url, additional_image_urls):
|
70 |
embeddings = []
|
71 |
|
@@ -81,10 +81,52 @@ def get_average_embedding(main_image_url, additional_image_urls):
|
|
81 |
embeddings.append(get_image_embedding(img))
|
82 |
|
83 |
if embeddings:
|
84 |
-
|
|
|
85 |
else:
|
86 |
return None
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
def find_similar_images(query_embedding, collection, top_k=5):
|
89 |
results = collection.query(
|
90 |
query_embeddings=[query_embedding.squeeze().tolist()],
|
|
|
52 |
else:
|
53 |
return None
|
54 |
|
55 |
+
# 기존의 get_image_embedding 함수도 수정
|
56 |
def get_image_embedding(image):
    """Return the normalised ``clip_model`` features of ``image`` as a list.

    The image is preprocessed with ``preprocess_val``, given a leading batch
    axis, moved to ``device`` and encoded without gradient tracking. The
    features are divided by their norm along the last axis.

    Returns:
        The features with size-1 axes squeezed out, converted to a plain
        Python list (not a NumPy array).
    """
    pixels = preprocess_val(image)
    model_input = pixels.unsqueeze(0).to(device)
    with torch.no_grad():
        encoded = clip_model.encode_image(model_input)
        encoded /= encoded.norm(dim=-1, keepdim=True)
    as_array = encoded.cpu().numpy()
    # Convert the NumPy array to a Python list.
    return as_array.squeeze().tolist()
|
62 |
|
63 |
def get_text_embedding(text):
|
64 |
text_tokens = tokenizer([text]).to(device)
|
|
|
66 |
text_features = clip_model.encode_text(text_tokens)
|
67 |
text_features /= text_features.norm(dim=-1, keepdim=True)
|
68 |
return text_features.cpu().numpy()
|
|
|
69 |
def get_average_embedding(main_image_url, additional_image_urls):
|
70 |
embeddings = []
|
71 |
|
|
|
81 |
embeddings.append(get_image_embedding(img))
|
82 |
|
83 |
if embeddings:
|
84 |
+
avg_embedding = np.mean(embeddings, axis=0)
|
85 |
+
return avg_embedding.squeeze().tolist() # numpy 배열을 파이썬 리스트로 변환
|
86 |
else:
|
87 |
return None
|
88 |
|
89 |
+
def update_collection_embeddings():
    """Recompute and store the averaged image embedding of every item.

    Reads all ids and metadatas from the module-level ``collection`` and
    processes them in batches. For each item an embedding is built with
    ``get_average_embedding`` from the ``image_url`` metadata entry and the
    optional ``additional_images`` entry; the successful results of each
    batch are written back with ``collection.update``.

    Failures are reported in the Streamlit UI (``st.warning`` / ``st.error``)
    and never abort the remaining items or batches. Returns ``None``.
    """
    # Query once so ids and metadatas come from the same snapshot.
    snapshot = collection.get(include=['metadatas'])
    all_ids = snapshot['ids']
    all_metadata = snapshot['metadatas']
    if not all_ids:
        return

    batch_size = 100  # number of items handled per update call
    total = len(all_ids)
    # One bar that is advanced per batch, instead of a new bar every batch.
    progress_bar = st.progress(0.0)

    for start in range(0, total, batch_size):
        batch_ids = all_ids[start:start + batch_size]
        batch_metadata = all_metadata[start:start + batch_size]

        batch_embeddings = []
        valid_ids = []

        for item_id, metadata in zip(batch_ids, batch_metadata):
            try:
                # The metadata lookup sits inside the try so that a record
                # without 'image_url' is reported and skipped like any other
                # per-item failure rather than stopping the whole run.
                main_image_url = metadata['image_url']
                additional_image_urls = metadata.get('additional_images', [])
                avg_embedding = get_average_embedding(main_image_url, additional_image_urls)
            except Exception as e:
                st.error(f"Error processing item {item_id}: {str(e)}")
                continue

            if avg_embedding is not None:
                # get_average_embedding already returns a plain list.
                batch_embeddings.append(avg_embedding)
                valid_ids.append(item_id)
            else:
                st.warning(f"Failed to generate embedding for item {item_id}")

        if valid_ids:
            try:
                collection.update(
                    ids=valid_ids,
                    embeddings=batch_embeddings,
                )
                st.success(f"Updated embeddings for {len(valid_ids)} items")
            except Exception as e:
                st.error(f"Error updating embeddings: {str(e)}")
                st.error(f"First embedding type: {type(batch_embeddings[0])}")
                st.error(f"First embedding shape: {len(batch_embeddings[0])}")

        # Show progress. The last batch is usually partial, so the raw ratio
        # can exceed 1.0, which st.progress rejects; clamp it.
        progress_bar.progress(min((start + batch_size) / total, 1.0))
|
129 |
+
|
130 |
def find_similar_images(query_embedding, collection, top_k=5):
|
131 |
results = collection.query(
|
132 |
query_embeddings=[query_embedding.squeeze().tolist()],
|