kaushalya committed on
Commit
0f2db82
·
1 Parent(s): 01973e8

Fix: Load embeddings from hdf files

Browse files
.gitattributes CHANGED
@@ -15,3 +15,4 @@
15
  *.pt filter=lfs diff=lfs merge=lfs -text
16
  *.pth filter=lfs diff=lfs merge=lfs -text
17
  *.pkl filter=lfs diff=lfs merge=lfs -text
 
 
15
  *.pt filter=lfs diff=lfs merge=lfs -text
16
  *.pth filter=lfs diff=lfs merge=lfs -text
17
  *.pkl filter=lfs diff=lfs merge=lfs -text
18
+ *.hdf filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -3,9 +3,10 @@ import pandas as pd
3
  import numpy as np
4
  import os
5
  import matplotlib.pyplot as plt
6
- from transformers import AutoTokenizer, CLIPProcessor
7
  from medclip.modeling_hybrid_clip import FlaxHybridCLIP
8
 
 
9
  @st.cache(allow_output_mutation=True)
10
  def load_model():
11
  model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
@@ -14,7 +15,7 @@ def load_model():
14
 
15
  @st.cache(allow_output_mutation=True)
16
  def load_image_embeddings():
17
- embeddings_df = pd.read_pickle('feature_store/image_embeddings_large.pkl')
18
  image_embeds = np.stack(embeddings_df['image_embedding'])
19
  image_files = np.asarray(embeddings_df['files'].tolist())
20
  return image_files, image_embeds
@@ -66,6 +67,9 @@ model, processor = load_model()
66
  query = st.text_input("Enter your query here:", value=text_value)
67
  dot_prod = None
68
 
 
 
 
69
  if st.button("Search") or k_slider:
70
  if len(query)==0:
71
  st.write("Please enter a valid search query")
 
3
  import numpy as np
4
  import os
5
  import matplotlib.pyplot as plt
6
+ from transformers import CLIPProcessor
7
  from medclip.modeling_hybrid_clip import FlaxHybridCLIP
8
 
9
+
10
  @st.cache(allow_output_mutation=True)
11
  def load_model():
12
  model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
 
15
 
16
  @st.cache(allow_output_mutation=True)
17
  def load_image_embeddings():
18
+ embeddings_df = pd.read_hdf('feature_store/image_embeddings_large.hdf', key='emb')
19
  image_embeds = np.stack(embeddings_df['image_embedding'])
20
  image_files = np.asarray(embeddings_df['files'].tolist())
21
  return image_files, image_embeds
 
67
  query = st.text_input("Enter your query here:", value=text_value)
68
  dot_prod = None
69
 
70
+ if len(query)==0:
71
+ query = text_value
72
+
73
  if st.button("Search") or k_slider:
74
  if len(query)==0:
75
  st.write("Please enter a valid search query")
feature_store/image_embeddings_large.hdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f64ee3e4916a8289fb2352ccfb856c472213b0465c6809c08225b26aef15d13a
3
+ size 15159216
feature_store/image_embeddings_small.hdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1aa88cf89bdf714e7ed9e48d44aa2271f2b89764cb6f81769fdd4e21667f1434
3
+ size 1988616