leedoming committed on
Commit
9153745
·
verified ·
1 Parent(s): 165b6cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -3
app.py CHANGED
@@ -52,12 +52,13 @@ def load_image_from_url(url, max_retries=3):
52
  else:
53
  return None
54
 
 
55
  def get_image_embedding(image):
56
  image_tensor = preprocess_val(image).unsqueeze(0).to(device)
57
  with torch.no_grad():
58
  image_features = clip_model.encode_image(image_tensor)
59
  image_features /= image_features.norm(dim=-1, keepdim=True)
60
- return image_features.cpu().numpy()
61
 
62
  def get_text_embedding(text):
63
  text_tokens = tokenizer([text]).to(device)
@@ -65,7 +66,6 @@ def get_text_embedding(text):
65
  text_features = clip_model.encode_text(text_tokens)
66
  text_features /= text_features.norm(dim=-1, keepdim=True)
67
  return text_features.cpu().numpy()
68
-
69
  def get_average_embedding(main_image_url, additional_image_urls):
70
  embeddings = []
71
 
@@ -81,10 +81,52 @@ def get_average_embedding(main_image_url, additional_image_urls):
81
  embeddings.append(get_image_embedding(img))
82
 
83
  if embeddings:
84
- return np.mean(embeddings, axis=0)
 
85
  else:
86
  return None
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  def find_similar_images(query_embedding, collection, top_k=5):
89
  results = collection.query(
90
  query_embeddings=[query_embedding.squeeze().tolist()],
 
52
  else:
53
  return None
54
 
55
def get_image_embedding(image):
    """Encode an image with the CLIP model and return its normalized embedding.

    The image is run through ``preprocess_val``, given a leading batch axis,
    moved to ``device`` and encoded with ``clip_model.encode_image`` with
    gradient tracking disabled. The features are divided by their norm along
    the last dimension.

    Args:
        image: The image to embed, in whatever form ``preprocess_val`` accepts.

    Returns:
        The normalized features converted to a NumPy array, squeezed, and
        turned into a plain Python list (not a NumPy array).
    """
    batch = preprocess_val(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = clip_model.encode_image(batch)
        features /= features.norm(dim=-1, keepdim=True)
    # Drop the size-1 axes, then hand back a plain list instead of an ndarray.
    vector = features.cpu().numpy().squeeze()
    return vector.tolist()
62
 
63
  def get_text_embedding(text):
64
  text_tokens = tokenizer([text]).to(device)
 
66
  text_features = clip_model.encode_text(text_tokens)
67
  text_features /= text_features.norm(dim=-1, keepdim=True)
68
  return text_features.cpu().numpy()
 
69
  def get_average_embedding(main_image_url, additional_image_urls):
70
  embeddings = []
71
 
 
81
  embeddings.append(get_image_embedding(img))
82
 
83
  if embeddings:
84
+ avg_embedding = np.mean(embeddings, axis=0)
85
+ return avg_embedding.squeeze().tolist() # numpy 배열을 파이썬 리스트로 변환
86
  else:
87
  return None
88
 
89
def update_collection_embeddings(batch_size=100):
    """Recompute the averaged image embedding of every item in ``collection``.

    Ids and metadata are read from the module-level ``collection``. For each
    item, the embedding is rebuilt by ``get_average_embedding`` from the
    ``image_url`` metadata entry plus the optional ``additional_images``
    entry. Items are written back with ``collection.update`` one batch at a
    time. Per-item and per-batch failures are reported through Streamlit and
    do not stop the remaining items from being processed.

    Args:
        batch_size: Number of items to process per ``collection.update`` call.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Fetch once so ids and metadata are guaranteed to line up with each other.
    snapshot = collection.get(include=['metadatas'])
    all_ids = snapshot['ids']
    all_metadata = snapshot['metadatas']

    total = len(all_ids)
    if total == 0:
        return

    # A single progress bar that is updated in place after every batch.
    progress_bar = st.progress(0.0)

    for start in range(0, total, batch_size):
        end = start + batch_size
        batch_embeddings = []
        valid_ids = []

        for item_id, metadata in zip(all_ids[start:end], all_metadata[start:end]):
            try:
                # Metadata access stays inside the try so that a record with a
                # missing 'image_url' (or no metadata) is reported, not fatal.
                main_image_url = metadata['image_url']
                additional_image_urls = metadata.get('additional_images', [])
                avg_embedding = get_average_embedding(main_image_url, additional_image_urls)
            except Exception as e:
                st.error(f"Error processing item {item_id}: {str(e)}")
                continue

            if avg_embedding is None:
                st.warning(f"Failed to generate embedding for item {item_id}")
                continue

            # get_average_embedding already returns a plain list.
            batch_embeddings.append(avg_embedding)
            valid_ids.append(item_id)

        if valid_ids:
            try:
                collection.update(
                    ids=valid_ids,
                    embeddings=batch_embeddings,
                )
                st.success(f"Updated embeddings for {len(valid_ids)} items")
            except Exception as e:
                st.error(f"Error updating embeddings: {str(e)}")
                st.error(f"First embedding type: {type(batch_embeddings[0])}")
                st.error(f"First embedding shape: {len(batch_embeddings[0])}")

        # Clamp to 1.0: the last batch is usually partial, and a float above
        # 1.0 is rejected by st.progress.
        progress_bar.progress(min(end / total, 1.0))
129
+
130
  def find_similar_images(query_embedding, collection, top_k=5):
131
  results = collection.query(
132
  query_embeddings=[query_embedding.squeeze().tolist()],