Spaces:
Build error
Build error
Trent
commited on
Commit
ยท
cf349fd
1
Parent(s):
2cd1913
Text to image Search Engine demo
Browse files- app.py +1 -1
- requirements.txt +3 -1
- text2image.py +34 -4
- utils.py +17 -0
app.py
CHANGED
@@ -10,4 +10,4 @@ st.sidebar.title("Navigation")
|
|
10 |
model = st.sidebar.selectbox("Choose a model", ["koclip", "koclip-large"])
|
11 |
page = st.sidebar.selectbox("Choose a task", list(PAGES.keys()))
|
12 |
|
13 |
-
PAGES[page].app(
|
|
|
10 |
model = st.sidebar.selectbox("Choose a model", ["koclip", "koclip-large"])
|
11 |
page = st.sidebar.selectbox("Choose a task", list(PAGES.keys()))
|
12 |
|
13 |
+
PAGES[page].app(model)
|
requirements.txt
CHANGED
@@ -3,4 +3,6 @@ jaxlib
|
|
3 |
flax
|
4 |
transformers
|
5 |
streamlit
|
6 |
-
tqdm
|
|
|
|
|
|
3 |
flax
|
4 |
transformers
|
5 |
streamlit
|
6 |
+
tqdm
|
7 |
+
nmslib
|
8 |
+
matplotlib
|
text2image.py
CHANGED
@@ -1,14 +1,44 @@
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
|
3 |
-
from utils import load_model
|
|
|
|
|
4 |
|
5 |
|
6 |
def app(model_name):
|
7 |
-
|
|
|
8 |
|
|
|
|
|
9 |
|
10 |
-
st.title("Text to Image")
|
11 |
st.markdown("""
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
""")
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
import streamlit as st
|
4 |
|
5 |
+
from utils import load_model, load_index
|
6 |
+
import numpy as np
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
|
9 |
|
10 |
def app(model_name):
|
11 |
+
images_directory = 'images/val2017'
|
12 |
+
features_directory = f'features/val2017/{model_name}.tsv'
|
13 |
|
14 |
+
files, index = load_index(features_directory)
|
15 |
+
model, processor = load_model(f'koclip/{model_name}')
|
16 |
|
17 |
+
st.title("Text to Image Search Engine")
|
18 |
st.markdown("""
|
19 |
+
This demonstration explores capability of KoCLIP as a Korean-language Image search engine. Embeddings for each of
|
20 |
+
5000 images from [MSCOCO](https://cocodataset.org/#home) 2017 validation set was generated using trained KoCLIP
|
21 |
+
vision model. They are ranked based on cosine similarity distance from input Text query embeddings and top 10 images
|
22 |
+
are displayed below.
|
23 |
+
|
24 |
+
KoCLIP is a retraining of OpenAI's CLIP model using 82,783 images from [MSCOCO](https://cocodataset.org/#home) dataset and
|
25 |
+
Korean caption annotations. Korean translation of caption annotations were obtained from [AI Hub](https://aihub.or.kr/keti_data_board/visual_intelligence).
|
26 |
+
|
27 |
+
Example Queries : ์ํํธ(Apartment), ์๋์ฐจ(Car), ์ปดํจํฐ(Computer)
|
28 |
""")
|
29 |
|
30 |
+
query = st.text_input("ํ๊ธ ์ง๋ฌธ์ ์ ์ด์ฃผ์ธ์ (Korean Text Query) :", value="์ํํธ")
|
31 |
+
if st.button("์ง๋ฌธ (Query)"):
|
32 |
+
proc = processor(text=[query], images=None, return_tensors="jax", padding=True)
|
33 |
+
vec = np.asarray(model.get_text_features(**proc))
|
34 |
+
ids, dists = index.knnQuery(vec, k=10)
|
35 |
+
result_files = map(lambda id: files[id], ids)
|
36 |
+
result_imgs, result_captions = [], []
|
37 |
+
for file, dist in zip(result_files, dists):
|
38 |
+
result_imgs.append(plt.imread(os.path.join(images_directory, file)))
|
39 |
+
result_captions.append("{:s} (์ ์ฌ๋: {:.3f})".format(file, 1.0 - dist))
|
40 |
+
|
41 |
+
st.image(result_imgs[:3], caption=result_captions[:3], width=200)
|
42 |
+
st.image(result_imgs[3:6], caption=result_captions[3:6], width=200)
|
43 |
+
st.image(result_imgs[6:9], caption=result_captions[6:9], width=200)
|
44 |
+
st.image(result_imgs[9:], caption=result_captions[9:], width=200)
|
utils.py
CHANGED
@@ -1,8 +1,25 @@
|
|
|
|
1 |
import streamlit as st
|
2 |
from transformers import CLIPProcessor, AutoTokenizer, ViTFeatureExtractor
|
|
|
3 |
|
4 |
from koclip import FlaxHybridCLIP
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
@st.cache(allow_output_mutation=True)
|
8 |
def load_model(model_name="koclip/koclip"):
|
|
|
1 |
+
import nmslib
|
2 |
import streamlit as st
|
3 |
from transformers import CLIPProcessor, AutoTokenizer, ViTFeatureExtractor
|
4 |
+
import numpy as np
|
5 |
|
6 |
from koclip import FlaxHybridCLIP
|
7 |
|
8 |
+
@st.cache(allow_output_mutation=True)
|
9 |
+
def load_index(img_file):
|
10 |
+
filenames, embeddings = [], []
|
11 |
+
lines = open(img_file, "r")
|
12 |
+
for line in lines:
|
13 |
+
cols = line.strip().split('\t')
|
14 |
+
filename = cols[0]
|
15 |
+
embedding = np.array([float(x) for x in cols[1].split(',')])
|
16 |
+
filenames.append(filename)
|
17 |
+
embeddings.append(embedding)
|
18 |
+
embeddings = np.array(embeddings)
|
19 |
+
index = nmslib.init(method='hnsw', space='cosinesimil')
|
20 |
+
index.addDataPointBatch(embeddings)
|
21 |
+
index.createIndex({'post': 2}, print_progress=True)
|
22 |
+
return filenames, index
|
23 |
|
24 |
@st.cache(allow_output_mutation=True)
|
25 |
def load_model(model_name="koclip/koclip"):
|