import streamlit as st | |
from annoy import AnnoyIndex | |
from sentence_transformers import SentenceTransformer | |
import json | |
from PIL import Image | |
import os | |
import urllib | |
# Page-wide settings (browser tab title/icon, wide layout). Kept ahead of every
# other st.* call: Streamlit documents that set_page_config must be the first
# Streamlit command executed on a page.
st.set_page_config(
    layout="wide",
    page_icon="🖼️",
    page_title="BHL Flickr Image Search",
)
def _resource_cache(func):
    """Cache *func*'s return value across Streamlit reruns and sessions.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the CLIP model, Annoy index and metadata would be reloaded
    from disk each time. ``st.cache_resource`` was added in Streamlit 1.18;
    older releases only provide ``st.cache``, where
    ``allow_output_mutation=True`` avoids re-hashing the large return value on
    every call. The cached objects are shared and must be treated as read-only.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_resource_cache
def load_clip_model():
    """Load the sentence-transformers CLIP ViT-B/32 model.

    The same model encodes both text and image queries into the index's
    embedding space.
    """
    return SentenceTransformer('clip-ViT-B-32')


@_resource_cache
def load_annoy_index(index_path='bhl_index.annoy', n_dims=512):
    """Load a prebuilt Annoy index from disk.

    Args:
        index_path: Path of the saved Annoy index file.
        n_dims: Vector dimensionality; must match the index on disk
            (512 is the CLIP ViT-B/32 embedding size).

    Returns:
        The loaded ``AnnoyIndex`` using the ``'angular'`` metric.
    """
    annoy_index = AnnoyIndex(n_dims, metric='angular')
    annoy_index.load(index_path)
    return annoy_index


@_resource_cache
def load_flickr_data(json_path='bhl_flickr_list.json'):
    """Load the per-image metadata list aligned with the Annoy item ids.

    Args:
        json_path: Path of the JSON metadata file.

    Returns:
        The decoded JSON list. Entry *i* describes Annoy item *i*; callers
        read its ``'flickr_id'``, ``'server'`` and ``'secret'`` keys.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(json_path, encoding='utf-8') as json_in:
        return json.load(json_in)
def bhl_annoy_search(mode, query, k=5, *, search_index=None, encoder=None,
                     id_list=None):
    """Return the *k* nearest neighbours of *query* in the BHL Annoy index.

    Args:
        mode: ``'id'`` to find neighbours of an already-indexed photo given its
            Flickr ID, ``'text'`` to embed a text query, or ``'image'`` to
            embed an image query with the CLIP model.
        query: Flickr ID (str or int; surrounding whitespace ignored) in
            ``'id'`` mode, otherwise the text or image to encode.
        k: Number of neighbours to return.
        search_index: Annoy index to query; defaults to the module-level
            ``bhl_index``.
        encoder: Object with an ``encode`` method; defaults to the
            module-level ``model``.
        id_list: Metadata rows (each with a ``'flickr_id'`` key) whose list
            positions are the Annoy item ids; defaults to the module-level
            ``bhl_flickr_ids``.

    Returns:
        ``(item_ids, distances)`` as produced by Annoy with
        ``include_distances=True``, or ``([], [])`` when an ``'id'`` query is
        not present in the metadata.

    Raises:
        ValueError: If *mode* is not ``'id'``, ``'text'`` or ``'image'``.
    """
    if mode not in ('id', 'text', 'image'):
        raise ValueError(
            f"Unknown search mode {mode!r}; expected 'id', 'text' or 'image'")

    index = bhl_index if search_index is None else search_index

    if mode == 'id':
        rows = bhl_flickr_ids if id_list is None else id_list
        target = str(query).strip()
        # Annoy item ids are positions in the metadata list; use the first
        # row whose Flickr ID matches.
        matching_row = next(
            (pos for pos, row in enumerate(rows)
             if str(row['flickr_id']) == target),
            None)
        if matching_row is None:
            return [], []
        return index.get_nns_by_item(matching_row, k, include_distances=True)

    # Text and image queries share the same path: embed with CLIP, then query.
    clip = model if encoder is None else encoder
    query_emb = clip.encode([query], show_progress_bar=False)
    return index.get_nns_by_vector(query_emb[0], k, include_distances=True)
# Deployment target; selects the prefix used for the in-app "Neighbors" links.
# Must be one of the keys of BASE_URLS_BY_MODE below.
DEPLOY_MODE = 'hf_spaces'

# An empty prefix makes the "Neighbors" link relative ("?mode=...&query=..."),
# so the browser resolves it against the document the link is rendered in.
BASE_URLS_BY_MODE = {
    'localhost': 'http://localhost:8501/',
    'streamlit_share': '',
    'hf_spaces': '',
}

if DEPLOY_MODE not in BASE_URLS_BY_MODE:
    # Fail at startup instead of with a NameError on BASE_URL while rendering.
    raise ValueError(
        f"Unknown DEPLOY_MODE {DEPLOY_MODE!r}; "
        f"expected one of {sorted(BASE_URLS_BY_MODE)}")
BASE_URL = BASE_URLS_BY_MODE[DEPLOY_MODE]
if __name__ == "__main__":
    NUM_RESULTS = 100  # neighbours fetched per search
    NUM_COLUMNS = 5    # results are laid out round-robin across this many columns

    st.markdown("# BHL Flickr Image Search")
    with st.expander("How does this work?", expanded=False):
        # TODO: replace this placeholder with a description of the search pipeline.
        st.write('placeholder')

    st.sidebar.markdown('### Search Mode')

    # Read the URL query string once. Result "Neighbors" links come back as
    # ?mode=flickr_id&query=<id>; ?mode=text_search&query=<text> pre-fills a
    # text search.
    query_params = st.experimental_get_query_params()
    url_mode = query_params.get('mode', [None])[0]
    url_query = query_params.get('query', [None])[0]

    # Pre-select the sidebar option matching the URL's mode (default: text).
    mode_index = 2 if url_mode == 'flickr_id' else 0
    app_mode = st.sidebar.selectbox(
        "How would you like to search?",
        ['Text search', 'Upload Image', 'BHL Flickr ID'],
        index=mode_index)

    model = load_clip_model()
    bhl_index = load_annoy_index()
    bhl_flickr_ids = load_flickr_data()

    query = None
    if app_mode == 'Text search':
        search_mode = 'text'
        search_text = 'a watercolor illustration of an insect with flowers'
        if url_mode == 'text_search':
            if url_query is not None:
                search_text = url_query
        else:
            # Drop query params left over from a different search mode.
            st.experimental_set_query_params()
        query = st.text_input('Text query', search_text)
    elif app_mode == 'BHL Flickr ID':
        search_mode = 'id'
        search_id = '5974846748'
        if url_mode == 'flickr_id':
            if url_query is not None:
                search_id = url_query
        else:
            st.experimental_set_query_params()
        query = st.text_input('Query ID', search_id)
    else:  # 'Upload Image'
        search_mode = 'image'
        st.experimental_set_query_params()
        image_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
        if image_file is not None:
            # Normalise to RGB so RGBA / palette PNG uploads encode like photos.
            query = Image.open(image_file).convert("RGB")
            st.image(query, width=100, caption='Query image')

    # Images only need to be present; text/ID queries must be non-blank.
    if search_mode == 'image':
        has_query = query is not None
    else:
        has_query = bool(query and query.strip())

    if has_query:
        closest_k_idx, _ = bhl_annoy_search(search_mode, query, NUM_RESULTS)
        if not closest_k_idx:
            st.warning("No results found for this query.")
        col_list = st.columns(NUM_COLUMNS)
        for idx, annoy_idx in enumerate(closest_k_idx):
            # Annoy item ids are positions in the metadata list.
            bhl_ids = bhl_flickr_ids[annoy_idx]
            flickr_id = bhl_ids['flickr_id']
            col = col_list[idx % NUM_COLUMNS]
            # Flickr static image URL: https://live.staticflickr.com/{server}/{id}_{secret}.jpg
            bhl_url = (f"https://live.staticflickr.com/{bhl_ids['server']}/"
                       f"{flickr_id}_{bhl_ids['secret']}.jpg")
            col.image(bhl_url, use_column_width=True)
            flickr_url = f"https://www.flickr.com/photos/biodivlibrary/{flickr_id}/"
            neighbors_url = f"{BASE_URL}?mode=flickr_id&query={flickr_id}"
            link_html = (f'<a href="{flickr_url}" target="_blank">Flickr Link</a>'
                         f' | <a href="{neighbors_url}">Neighbors</a>')
            col.markdown(link_html, unsafe_allow_html=True)
            col.markdown("---")