Vivien committed on
Commit
aae8769
·
1 Parent(s): 5b1c1bd

Add side-by-side comparison of the ViT models

Browse files
app.py CHANGED
@@ -5,38 +5,40 @@ import pandas as pd, numpy as np
5
  from transformers import CLIPProcessor, CLIPModel
6
  from st_clickable_images import clickable_images
7
 
 
8
 
9
- @st.cache(
10
- show_spinner=False,
11
- hash_funcs={
12
- CLIPModel: lambda _: None,
13
- CLIPProcessor: lambda _: None,
14
- dict: lambda _: None,
15
- },
16
- )
17
  def load():
18
- model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
19
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
20
  df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
21
- embeddings = {0: np.load("embeddings.npy"), 1: np.load("embeddings2.npy")}
22
- for k in [0, 1]:
23
- embeddings[k] = embeddings[k] / np.linalg.norm(
24
- embeddings[k], axis=1, keepdims=True
25
- )
26
- return model, processor, df, embeddings
27
-
28
-
29
- model, processor, df, embeddings = load()
 
 
 
 
 
 
 
 
 
30
  source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}
31
 
32
 
33
- def compute_text_embeddings(list_of_strings):
34
- inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
35
- result = model.get_text_features(**inputs).detach().numpy()
36
  return result / np.linalg.norm(result, axis=1, keepdims=True)
37
 
38
 
39
- def image_search(query, corpus, n_results=24):
40
  positive_embeddings = None
41
 
42
  def concatenate_embeddings(e1, e2):
@@ -57,25 +59,25 @@ def image_search(query, corpus, n_results=24):
57
  idx, remainder = int(idx), remainder.strip()
58
  k2 = 0 if corpus2 == "Unsplash" else 1
59
  positive_embeddings = concatenate_embeddings(
60
- positive_embeddings, embeddings[k2][idx : idx + 1, :]
61
  )
62
  if len(remainder) > 0:
63
  positive_embeddings = concatenate_embeddings(
64
- positive_embeddings, compute_text_embeddings([remainder])
65
  )
66
  else:
67
  positive_embeddings = concatenate_embeddings(
68
- positive_embeddings, compute_text_embeddings([positive_query])
69
  )
70
- dot_product = embeddings[k] @ positive_embeddings.T
71
  dot_product = dot_product - np.median(dot_product, axis=0)
72
  dot_product = dot_product / np.max(dot_product, axis=0, keepdims=True)
73
  dot_product = np.min(dot_product, axis=1)
74
 
75
  if len(splitted_query) > 1:
76
  negative_queries = (" ".join(splitted_query[1:])).split(";")
77
- negative_embeddings = compute_text_embeddings(negative_queries)
78
- dot_product2 = embeddings[k] @ negative_embeddings.T
79
  dot_product2 = dot_product2 - np.median(dot_product2, axis=0)
80
  dot_product2 = dot_product2 / np.max(dot_product2, axis=0, keepdims=True)
81
  dot_product -= np.max(np.maximum(dot_product2, 0), axis=1)
@@ -96,7 +98,7 @@ description = """
96
 
97
  **Enter your query and hit enter**
98
 
99
- *Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
100
 
101
  *Inspired by [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) from Vladimir Haltakov and [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river) from Travis Hoppe*
102
  """
@@ -107,6 +109,12 @@ howto = """
107
  - If the input includes "**EXCLUDING**", the part right of it will be used as a negative query
108
  """
109
 
 
 
 
 
 
 
110
 
111
  def main():
112
  st.markdown(
@@ -124,10 +132,10 @@ def main():
124
  margin-left: 5px;
125
  margin-right: 5px;
126
  }
127
- section.main>div:first-child {
128
- padding-top: 0px;
129
  }
130
- section:not(.main)>div:first-child {
131
  padding-top: 30px;
132
  }
133
  div.reportview-container > section:first-child{
@@ -145,6 +153,9 @@ def main():
145
  st.sidebar.markdown(description)
146
  with st.sidebar.expander("Advanced use"):
147
  st.markdown(howto)
 
 
 
148
 
149
  _, c, _ = st.columns((1, 3, 1))
150
  if "query" in st.session_state:
@@ -152,27 +163,65 @@ def main():
152
  else:
153
  query = c.text_input("", value="clouds at sunset")
154
  corpus = st.radio("", ["Unsplash", "Movies"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  if len(query) > 0:
156
- results = image_search(query, corpus)
157
- clicked = clickable_images(
158
- [result[0] for result in results],
159
- titles=[result[1] for result in results],
160
- div_style={
161
- "display": "flex",
162
- "justify-content": "center",
163
- "flex-wrap": "wrap",
164
- },
165
- img_style={"margin": "2px", "height": "200px"},
166
- )
167
- if clicked >= 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  change_query = False
169
  if "last_clicked" not in st.session_state:
170
  change_query = True
171
  else:
172
- if clicked != st.session_state["last_clicked"]:
173
  change_query = True
174
  if change_query:
175
- st.session_state["query"] = f"[{corpus}:{results[clicked][2]}]"
 
 
 
176
  st.experimental_rerun()
177
 
178
 
 
5
  from transformers import CLIPProcessor, CLIPModel
6
  from st_clickable_images import clickable_images
7
 
8
+ MODEL_NAMES = ["base-patch32", "base-patch16", "large-patch14", "large-patch14-336"]
9
 
10
+
11
+ @st.cache(show_spinner=False, hash_funcs={dict: lambda _: None})
 
 
 
 
 
 
12
  def load():
 
 
13
  df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
14
+ models = {}
15
+ processors = {}
16
+ embeddings = {}
17
+ for name in MODEL_NAMES:
18
+ models[name] = CLIPModel.from_pretrained(f"openai/clip-vit-{name}")
19
+ processors[name] = CLIPProcessor.from_pretrained(f"openai/clip-vit-{name}")
20
+ embeddings[name] = {
21
+ 0: np.load(f"embeddings-vit-{name}.npy"),
22
+ 1: np.load(f"embeddings2-vit-{name}.npy"),
23
+ }
24
+ for k in [0, 1]:
25
+ embeddings[name][k] = embeddings[name][k] / np.linalg.norm(
26
+ embeddings[name][k], axis=1, keepdims=True
27
+ )
28
+ return models, processors, df, embeddings
29
+
30
+
31
+ models, processors, df, embeddings = load()
32
  source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}
33
 
34
 
35
+ def compute_text_embeddings(list_of_strings, name):
36
+ inputs = processors[name](text=list_of_strings, return_tensors="pt", padding=True)
37
+ result = models[name].get_text_features(**inputs).detach().numpy()
38
  return result / np.linalg.norm(result, axis=1, keepdims=True)
39
 
40
 
41
+ def image_search(query, corpus, name, n_results=24):
42
  positive_embeddings = None
43
 
44
  def concatenate_embeddings(e1, e2):
 
59
  idx, remainder = int(idx), remainder.strip()
60
  k2 = 0 if corpus2 == "Unsplash" else 1
61
  positive_embeddings = concatenate_embeddings(
62
+ positive_embeddings, embeddings[name][k2][idx : idx + 1, :]
63
  )
64
  if len(remainder) > 0:
65
  positive_embeddings = concatenate_embeddings(
66
+ positive_embeddings, compute_text_embeddings([remainder], name)
67
  )
68
  else:
69
  positive_embeddings = concatenate_embeddings(
70
+ positive_embeddings, compute_text_embeddings([positive_query], name)
71
  )
72
+ dot_product = embeddings[name][k] @ positive_embeddings.T
73
  dot_product = dot_product - np.median(dot_product, axis=0)
74
  dot_product = dot_product / np.max(dot_product, axis=0, keepdims=True)
75
  dot_product = np.min(dot_product, axis=1)
76
 
77
  if len(splitted_query) > 1:
78
  negative_queries = (" ".join(splitted_query[1:])).split(";")
79
+ negative_embeddings = compute_text_embeddings(negative_queries, name)
80
+ dot_product2 = embeddings[name][k] @ negative_embeddings.T
81
  dot_product2 = dot_product2 - np.median(dot_product2, axis=0)
82
  dot_product2 = dot_product2 / np.max(dot_product2, axis=0, keepdims=True)
83
  dot_product -= np.max(np.maximum(dot_product2, 0), axis=1)
 
98
 
99
  **Enter your query and hit enter**
100
 
101
+ *Built with OpenAI's [CLIP](https://openai.com/blog/clip/) models, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
102
 
103
  *Inspired by [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) from Vladimir Haltakov and [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river) from Travis Hoppe*
104
  """
 
109
  - If the input includes "**EXCLUDING**", the part right of it will be used as a negative query
110
  """
111
 
112
+ div_style = {
113
+ "display": "flex",
114
+ "justify-content": "center",
115
+ "flex-wrap": "wrap",
116
+ }
117
+
118
 
119
  def main():
120
  st.markdown(
 
132
  margin-left: 5px;
133
  margin-right: 5px;
134
  }
135
+ .row-widget {
136
+ margin-top: -25px;
137
  }
138
+ section>div:first-child {
139
  padding-top: 30px;
140
  }
141
  div.reportview-container > section:first-child{
 
153
  st.sidebar.markdown(description)
154
  with st.sidebar.expander("Advanced use"):
155
  st.markdown(howto)
156
+ mode = st.sidebar.selectbox(
157
+ "", ["Results for ViT-L/14@336px", "Comparison of 2 models"], index=0
158
+ )
159
 
160
  _, c, _ = st.columns((1, 3, 1))
161
  if "query" in st.session_state:
 
163
  else:
164
  query = c.text_input("", value="clouds at sunset")
165
  corpus = st.radio("", ["Unsplash", "Movies"])
166
+
167
+ models_dict = {
168
+ "ViT-B/32 (quickest)": "base-patch32",
169
+ "ViT-B/16 (quick)": "base-patch16",
170
+ "ViT-L/14 (slow)": "large-patch14",
171
+ "ViT-L/14@336px (slowest)": "large-patch14-336",
172
+ }
173
+
174
+ if "Comparison" in mode:
175
+ c1, c2 = st.columns((1, 1))
176
+ selection1 = c1.selectbox("", models_dict.keys(), index=0)
177
+ selection2 = c2.selectbox("", models_dict.keys(), index=3)
178
+ name1 = models_dict[selection1]
179
+ name2 = models_dict[selection2]
180
+ else:
181
+ name1 = MODEL_NAMES[-1]
182
+
183
  if len(query) > 0:
184
+ results1 = image_search(query, corpus, name1)
185
+ if "Comparison" in mode:
186
+ with c1:
187
+ clicked1 = clickable_images(
188
+ [result[0] for result in results1],
189
+ titles=[result[1] for result in results1],
190
+ div_style=div_style,
191
+ img_style={"margin": "2px", "height": "150px"},
192
+ key=query + corpus + name1 + "1",
193
+ )
194
+ results2 = image_search(query, corpus, name2)
195
+ with c2:
196
+ clicked2 = clickable_images(
197
+ [result[0] for result in results2],
198
+ titles=[result[1] for result in results2],
199
+ div_style=div_style,
200
+ img_style={"margin": "2px", "height": "150px"},
201
+ key=query + corpus + name2 + "2",
202
+ )
203
+ else:
204
+ clicked1 = clickable_images(
205
+ [result[0] for result in results1],
206
+ titles=[result[1] for result in results1],
207
+ div_style=div_style,
208
+ img_style={"margin": "2px", "height": "200px"},
209
+ key=query + corpus + name1 + "1",
210
+ )
211
+ clicked2 = -1
212
+
213
+ if clicked2 >= 0 or clicked1 >= 0:
214
  change_query = False
215
  if "last_clicked" not in st.session_state:
216
  change_query = True
217
  else:
218
+ if max(clicked2, clicked1) != st.session_state["last_clicked"]:
219
  change_query = True
220
  if change_query:
221
+ if clicked1 >= 0:
222
+ st.session_state["query"] = f"[{corpus}:{results1[clicked1][2]}]"
223
+ elif clicked2 >= 0:
224
+ st.session_state["query"] = f"[{corpus}:{results2[clicked2][2]}]"
225
  st.experimental_rerun()
226
 
227
 
embeddings-vit-base-patch16.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:125430e11a4a415ec0c0fc5339f97544f0447e4b0a24c20f2e59f8852e706afc
3
+ size 51200128
embeddings-vit-base-patch32.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f7ebdff24079665faf58d07045056a63b5499753e3ffbda479691d53de3ab38
3
+ size 51200128
embeddings-vit-large-patch14-336.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f79f10ebe267b4ee7acd553dfe0ee31df846123630058a6d58c04bf22e0ad068
3
+ size 76800128
embeddings.npy → embeddings-vit-large-patch14.npy RENAMED
File without changes
embeddings2-vit-base-patch16.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:153cf3fae2385d51fe8729d3a1c059f611ca47a3fc501049708114d1bbf79049
3
+ size 16732288
embeddings2-vit-base-patch32.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7d545bed86121dac1cedcc1de61ea5295f5840c1eb751637e6628ac54faef81
3
+ size 16732288
embeddings2-vit-large-patch14-336.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e66eb377465fbfaa56cec079aa3e214533ceac43646f2ca78028ae4d8ad6d03
3
+ size 25098368
embeddings2.npy → embeddings2-vit-large-patch14.npy RENAMED
File without changes