clip / app.py
visualizingjp's picture
Update app.py
5639350 verified
from html import escape
import re
import streamlit as st
import pandas as pd, numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel
from st_clickable_images import clickable_images
# CLIP ViT checkpoints to load, as suffixes of "openai/clip-vit-<name>".
# Smaller variants ("base-patch32", "base-patch16", "large-patch14") are
# currently disabled; the UI searches with the last entry of this list.
MODEL_NAMES = ["large-patch14-336"]
# Cache decorator for the heavy, unhashable objects returned by ``load``.
# ``st.cache`` is deprecated in favour of ``st.cache_resource`` (which, like
# ``allow_output_mutation=True``, returns the cached object without hashing or
# copying it); fall back to the old API on Streamlit releases that predate it.
_cache_resource = getattr(st, "cache_resource", None) or st.cache(
    allow_output_mutation=True
)


@_cache_resource
def load():
    """Load image metadata, CLIP models/processors and precomputed image embeddings.

    Returns:
        ``(models, processors, df, embeddings)``:
        ``df`` maps corpus id to a DataFrame (0 -> ``data.csv``,
        1 -> ``data2.csv``); ``models`` and ``processors`` map each entry of
        ``MODEL_NAMES`` to its CLIP model (in eval mode) and processor;
        ``embeddings[name][corpus_id]`` is the array loaded from the matching
        ``.npy`` file with every row scaled to unit L2 norm.
    """
    df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
    models = {}
    processors = {}
    embeddings = {}
    for name in MODEL_NAMES:
        hub_id = f"openai/clip-vit-{name}"
        models[name] = CLIPModel.from_pretrained(hub_id).eval()
        processors[name] = CLIPProcessor.from_pretrained(hub_id)
        raw = {
            0: np.load(f"embeddings-vit-{name}.npy"),
            1: np.load(f"embeddings2-vit-{name}.npy"),
        }
        # Unit-norm rows make a dot product with a unit-norm query equal to
        # the cosine similarity.
        embeddings[name] = {
            corpus_id: emb / np.linalg.norm(emb, axis=1, keepdims=True)
            for corpus_id, emb in raw.items()
        }
    return models, processors, df, embeddings
# Load models, metadata and embeddings once; ``load`` is cached by Streamlit
# so reruns of the script reuse the same objects.
(models, processors, df, embeddings) = load()
# Attribution suffix appended to each result tooltip, keyed by corpus id
# (0 = Unsplash, 1 = TMDB).
source = dict(enumerate(["\nSource: Unsplash", "\nSource: The Movie Database (TMDB)"]))
def compute_text_embeddings(list_of_strings, name):
    """Embed text queries with CLIP model ``name`` and L2-normalise each row.

    Args:
        list_of_strings: queries to embed together as one padded batch.
        name: model key, one of ``MODEL_NAMES``.

    Returns:
        A 2-D numpy array with one unit-norm embedding per input string.
    """
    # truncation=True: CLIP's text encoder only has positional embeddings up to
    # the tokenizer's model_max_length (77 tokens for the OpenAI checkpoints);
    # without it a long query raises inside the model instead of being clipped.
    inputs = processors[name](
        text=list_of_strings, return_tensors="pt", padding=True, truncation=True
    )
    with torch.no_grad():
        result = models[name].get_text_features(**inputs).detach().numpy()
    return result / np.linalg.norm(result, axis=1, keepdims=True)
# "[Unsplash:123] optional text" / "[Movies:45] optional text" at the start of a
# (stripped) sub-query refers to the stored embedding of that image.
_IMAGE_REF = re.compile(r"\[(Movies|Unsplash):(\d{1,5})\](.*)")


def _relative_similarity(corpus_embeddings, query_embeddings):
    """Return image-vs-query similarities, median-centred and max-scaled per query.

    Both inputs hold one embedding per row; the result has one row per corpus
    image and one column per query.
    """
    similarity = corpus_embeddings @ query_embeddings.T
    similarity = similarity - np.median(similarity, axis=0)
    # Dividing each column by its own maximum puts every query on a comparable
    # scale (its best match becomes 1) before the caller combines columns.
    return similarity / np.max(similarity, axis=0, keepdims=True)


def _positive_embeddings(positive_part, name):
    """Embed the ``;``-separated positive sub-queries; return ``None`` if none are usable.

    Blank sub-queries are ignored. A sub-query starting with an image tag
    contributes that image's stored embedding (skipped when the index is out of
    range) plus, if text follows the tag, the embedding of that text.
    """
    rows = []
    for sub_query in positive_part.split(";"):
        sub_query = sub_query.strip()
        if not sub_query:
            continue
        match = _IMAGE_REF.match(sub_query)
        if match is None:
            rows.append(compute_text_embeddings([sub_query], name))
            continue
        ref_corpus, idx, remainder = match.groups()
        idx = int(idx)
        ref_embeddings = embeddings[name][0 if ref_corpus == "Unsplash" else 1]
        # The tag can be typed by hand; an out-of-range index would give an
        # empty slice, and the caller's min-reduction over it would raise.
        if idx < ref_embeddings.shape[0]:
            rows.append(ref_embeddings[idx : idx + 1, :])
        remainder = remainder.strip()
        if remainder:
            rows.append(compute_text_embeddings([remainder], name))
    return np.concatenate(rows, axis=0) if rows else None


def image_search(query, corpus, name, n_results=24):
    """Rank the images of ``corpus`` against ``query`` and return the best matches.

    Query syntax:
      * ``;`` separates positive sub-queries; an image's score is the minimum
        of its per-sub-query scores, so it must match all of them.
      * a sub-query may start with ``[Unsplash:<idx>]`` / ``[Movies:<idx>]`` to
        use that image's embedding as a query (see ``_positive_embeddings``).
      * text after ``"EXCLUDING "`` is a ``;``-separated list of negative text
        queries; the largest positive part of their scores is subtracted.

    Args:
        query: raw query string typed by the user or built from a click.
        corpus: ``"Unsplash"`` selects corpus 0; any other value selects corpus 1.
        name: model key, one of ``MODEL_NAMES``.
        n_results: maximum number of results to return.

    Returns:
        A list of ``(path, tooltip, row_index)`` tuples, best match first; empty
        when the query contains no usable sub-query.
    """
    k = 0 if corpus == "Unsplash" else 1
    corpus_embeddings = embeddings[name][k]

    positive_part, *negative_parts = query.split("EXCLUDING ")
    positive_embeddings = _positive_embeddings(positive_part, name)
    negative_queries = [
        q.strip() for q in " ".join(negative_parts).split(";") if q.strip()
    ]
    if positive_embeddings is None and not negative_queries:
        return []

    scores = np.zeros(corpus_embeddings.shape[0])
    if positive_embeddings is not None:
        scores = np.min(
            _relative_similarity(corpus_embeddings, positive_embeddings), axis=1
        )
    if negative_queries:
        negative_embeddings = compute_text_embeddings(negative_queries, name)
        penalty = _relative_similarity(corpus_embeddings, negative_embeddings)
        # Only similarity above the median counts against an image.
        scores = scores - np.max(np.maximum(penalty, 0), axis=1)

    best = np.argsort(scores)[-1 : -n_results - 1 : -1]
    return [
        (
            df[k].iloc[i]["path"],
            df[k].iloc[i]["tooltip"] + source[k],
            i,
        )
        for i in best
    ]
# Sidebar introduction (Japanese Markdown, shown with st.sidebar.markdown):
# heading "Image search by meaning", a prompt to type search terms and press
# Enter, credits for CLIP, transformers, Streamlit, Unsplash (25k images) and
# TMDB (8k images), and the projects that inspired the app.
description = """
# 意味による画像検索
**検索語を入力してから Enter キーを押してください**
*OpenAI の [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), [Unsplash](https://unsplash.com/) の 25k images と [The Movie Database (TMDB)](https://www.themoviedb.org/) の 8k images を使用して構築しています。*
*Vladimir Haltakov の [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) と Travis Hoppe の [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river)  に触発されました。*
"""
# Help text for the "advanced usage" sidebar expander (Japanese Markdown):
# clicking an image searches for similar images; ";" combines several search
# terms; text to the right of "EXCLUDING" is used as a negative query.
howto = """
- 画像をクリックすると、それをクエリとして使用し、類似画像を検索できます。
- 複数の検索語を組み合わせることができます(区切り文字として「**;**」を使用します)。
- 検索語に 「**EXCLUDING**」 が含まれている場合、その右側の部分が否定クエリとして使用されます。
"""
# CSS for the result-grid container: a centred flexbox whose rows wrap.
div_style = dict(
    [("display", "flex"), ("justify-content", "center"), ("flex-wrap", "wrap")]
)
def main():
    """Render the search page: sidebar help, query box, corpus picker and result grid.

    Clicking a result replaces the query with an image reference such as
    ``[Unsplash:123]`` and reruns the script, so the grid then shows images
    similar to the clicked one.
    """
    st.markdown(
        """
<style>
.block-container{
max-width: 1200px;
}
div.row-widget.stRadio > div{
flex-direction:row;
display: flex;
justify-content: center;
}
div.row-widget.stRadio > div > label{
margin-left: 5px;
margin-right: 5px;
}
.row-widget {
margin-top: -25px;
}
section>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>""",
        unsafe_allow_html=True,
    )
    st.sidebar.markdown(description)
    with st.sidebar.expander("高度な使用方法"):
        st.markdown(howto)

    _, c, _ = st.columns((1, 3, 1))
    # A query set by a previous click takes precedence over the default.
    query = c.text_input("", value=st.session_state.get("query", "clouds at sunset"))
    corpus = st.radio("", ["Unsplash", "Movies"])
    name = MODEL_NAMES[-1]
    if len(query) == 0:
        return

    results = image_search(query, corpus, name)
    grid_key = query + corpus + name + "1"
    clicked = clickable_images(
        [result[0] for result in results],
        titles=[result[1] for result in results],
        div_style=div_style,
        img_style={"margin": "2px", "height": "200px"},
        key=grid_key,
    )
    if clicked < 0:
        return

    # The component keeps returning the same index for the same key on every
    # rerun. Remember which (grid, index) click was already handled; otherwise a
    # click that reproduces the current query (e.g. the top hit of
    # "[Unsplash:123]" is image 123 itself) would rerun the script forever.
    click_id = (grid_key, clicked)
    if st.session_state.get("last_clicked") == click_id:
        return
    st.session_state["last_clicked"] = click_id
    st.session_state["query"] = f"[{corpus}:{results[clicked][2]}]"
    # ``st.experimental_rerun`` is deprecated in favour of ``st.rerun``.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


if __name__ == "__main__":
    main()