leedoming commited on
Commit
ed40fd9
·
verified ·
1 Parent(s): a0e438d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -66,6 +66,7 @@ def get_text_embedding(text):
66
  text_features = clip_model.encode_text(text_tokens)
67
  text_features /= text_features.norm(dim=-1, keepdim=True)
68
  return text_features.cpu().numpy()
 
69
  def get_average_embedding(main_image_url, additional_image_urls):
70
  embeddings = []
71
 
@@ -82,7 +83,7 @@ def get_average_embedding(main_image_url, additional_image_urls):
82
 
83
  if embeddings:
84
  avg_embedding = np.mean(embeddings, axis=0)
85
- return avg_embedding.tolist() # numpy 배열을 파이썬 리스트로 변환
86
  else:
87
  return None
88
 
@@ -105,7 +106,7 @@ def update_collection_embeddings():
105
  try:
106
  avg_embedding = get_average_embedding(main_image_url, additional_image_urls)
107
  if avg_embedding is not None:
108
- batch_embeddings.append(avg_embedding) # 이미 리스트 형태임
109
  valid_ids.append(id)
110
  else:
111
  st.warning(f"Failed to generate embedding for item {id}")
@@ -123,9 +124,12 @@ def update_collection_embeddings():
123
  st.error(f"Error updating embeddings: {str(e)}")
124
  st.error(f"First embedding type: {type(batch_embeddings[0])}")
125
  st.error(f"First embedding length: {len(batch_embeddings[0])}")
 
126
 
127
  # 진행 상황 표시
128
  st.progress((i + batch_size) / len(all_ids))
 
 
129
  def find_similar_images(query_embedding, collection, top_k=5):
130
  results = collection.query(
131
  query_embeddings=[query_embedding.squeeze().tolist()],
 
66
  text_features = clip_model.encode_text(text_tokens)
67
  text_features /= text_features.norm(dim=-1, keepdim=True)
68
  return text_features.cpu().numpy()
69
+
70
  def get_average_embedding(main_image_url, additional_image_urls):
71
  embeddings = []
72
 
 
83
 
84
  if embeddings:
85
  avg_embedding = np.mean(embeddings, axis=0)
86
+ return avg_embedding.tolist() if isinstance(avg_embedding, np.ndarray) else avg_embedding
87
  else:
88
  return None
89
 
 
106
  try:
107
  avg_embedding = get_average_embedding(main_image_url, additional_image_urls)
108
  if avg_embedding is not None:
109
+ batch_embeddings.append(avg_embedding)
110
  valid_ids.append(id)
111
  else:
112
  st.warning(f"Failed to generate embedding for item {id}")
 
124
  st.error(f"Error updating embeddings: {str(e)}")
125
  st.error(f"First embedding type: {type(batch_embeddings[0])}")
126
  st.error(f"First embedding length: {len(batch_embeddings[0])}")
127
+ st.error(f"First embedding: {batch_embeddings[0][:10]}...") # 처음 10개 요소만 출력
128
 
129
  # 진행 상황 표시
130
  st.progress((i + batch_size) / len(all_ids))
131
+
132
+
133
  def find_similar_images(query_embedding, collection, top_k=5):
134
  results = collection.query(
135
  query_embeddings=[query_embedding.squeeze().tolist()],