Spaces:
Build error
Build error
# Import Libraries | |
from pathlib import Path | |
import pandas as pd | |
import numpy as np | |
import torch | |
import clip | |
from PIL import Image | |
from io import BytesIO | |
import requests | |
import gradio as gr | |
# Load OpenAI's CLIP model (module-level load disabled below; the model is loaded when a search runs)
#model, preprocess = clip.load("ViT-B/32", jit=False) | |
# Download the matched photos so they can be displayed
def show_output_image(matched_images, timeout=10):
    """Download the Unsplash photo behind each matched photo ID.

    Args:
        matched_images: Iterable of Unsplash photo IDs.
        timeout: Seconds to wait on each HTTP request before giving up, so a
            stalled download cannot block the search forever.

    Returns:
        A list of PIL images, one per photo ID, in the order the IDs were given.
        An empty input yields an empty list.

    Raises:
        requests.HTTPError: If the download endpoint answers with an error status.
        requests.Timeout: If a download does not respond within ``timeout``.
    """
    images = []
    for photo_id in matched_images:
        photo_image_url = f"https://unsplash.com/photos/{photo_id}/download?w=280"
        response = requests.get(photo_image_url, timeout=timeout)
        # Fail with a clear HTTP error instead of letting PIL choke on an
        # error page that is not image data.
        response.raise_for_status()
        images.append(Image.open(BytesIO(response.content)))
    return images
# Encode and normalize the search query using CLIP | |
def encode_search_query(search_query, model, device):
    """Embed a text query with CLIP and return the normalized vector as numpy.

    Args:
        search_query: Text passed to ``clip.tokenize``.
        model: Object exposing ``encode_text`` (presumably the loaded CLIP
            model; confirm against the caller).
        device: Device the token tensor is moved to before encoding.

    Returns:
        The encoded text divided by its norm along the last dimension, moved
        to the CPU and converted to a numpy array.
    """
    tokens = clip.tokenize(search_query).to(device)
    # Inference only: no gradients are needed for the embedding.
    with torch.no_grad():
        embedding = model.encode_text(tokens)
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)
    return embedding.cpu().numpy()
# Find all matched photos | |
def find_matches(text_features, photo_features, photo_ids, results_count=4):
    """Return the IDs of the photos most similar to the query features.

    Args:
        text_features: Query feature array with exactly one row (the
            ``squeeze`` on axis 1 requires it).
        photo_features: Feature matrix with one row per photo.
        photo_ids: Sequence of photo IDs indexed in the same order as the
            rows of ``photo_features``.
        results_count: Maximum number of IDs to return.

    Returns:
        A list of photo IDs ordered from the highest to the lowest score.
    """
    # Dot product of every photo with the query; this is the cosine
    # similarity when both sides are L2-normalized.
    scores = np.squeeze(photo_features @ text_features.T, axis=1)
    # Negate so that an ascending argsort ranks the highest score first.
    ranking = np.argsort(-scores)
    top_indices = ranking[:results_count]
    return [photo_ids[index] for index in top_indices]
# Search modes accepted by image_search.
_SEARCH_MODES = ("Text-To-Image", "Image-To-Image")

# Number of matches requested from find_matches for every search.
_RESULTS_COUNT = 4

# Filled on first use so the photo index and the CLIP model are loaded once
# per process instead of on every search request.
_SEARCH_RESOURCES = {}


def _load_search_resources():
    """Load the photo IDs, photo features and CLIP model once; return the cache."""
    if not _SEARCH_RESOURCES:
        photo_ids = list(pd.read_csv("./photo_ids.csv")["photo_id"])
        photo_features = np.load("./features.npy")
        # Use the GPU when one is available, otherwise fall back to the CPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
        _SEARCH_RESOURCES.update(
            photo_ids=photo_ids,
            photo_features=photo_features,
            device=device,
            model=model,
            preprocess=preprocess,
        )
    return _SEARCH_RESOURCES


def image_search(search_text, search_image, option):
    """Search the photo index by text or by image and return the matched photos.

    Args:
        search_text: Query text, used when ``option`` is "Text-To-Image".
        search_image: Query image, used when ``option`` is "Image-To-Image".
            It is passed to the CLIP ``preprocess`` callable; may be None
            when searching by text.
        option: Either "Text-To-Image" or "Image-To-Image".

    Returns:
        The list of images produced by ``show_output_image`` for the best
        matching photo IDs.

    Raises:
        ValueError: If ``option`` is not a supported mode, or if an
            image search is requested without an input image.
    """
    # Validate before touching disk or the model so bad input fails fast
    # with a clear message.
    if option not in _SEARCH_MODES:
        raise ValueError(
            f"Unknown search option {option!r}; expected one of {_SEARCH_MODES}."
        )
    if option == "Image-To-Image" and search_image is None:
        raise ValueError("Image-To-Image search requires an input image.")

    resources = _load_search_resources()
    photo_ids = resources["photo_ids"]
    photo_features = resources["photo_features"]
    device = resources["device"]
    model = resources["model"]
    preprocess = resources["preprocess"]

    if option == "Text-To-Image":
        # Extract the text features and rank the photos against them.
        text_features = encode_search_query(search_text, model, device)
        matched_images = find_matches(
            text_features, photo_features, photo_ids, _RESULTS_COUNT
        )
        # ---- debug purpose ------#
        if matched_images:
            top_photo_id = matched_images[0]
            print(top_photo_id)
            print(f"https://unsplash.com/photos/{top_photo_id}/download?w=280")
        # --------------------------#
        return show_output_image(matched_images)

    # Image-To-Image: embed the query image and normalize it the same way
    # the text query is normalized.
    with torch.no_grad():
        image_feature = model.encode_image(
            preprocess(search_image).unsqueeze(0).to(device)
        )
        image_feature = (
            image_feature / image_feature.norm(dim=-1, keepdim=True)
        ).cpu().numpy()
    matched_images = find_matches(
        image_feature, photo_features, photo_ids, _RESULTS_COUNT
    )
    return show_output_image(matched_images)
# Build the Gradio UI around image_search and start serving it.
gr.Interface(
    fn=image_search,
    inputs=[
        gr.inputs.Textbox(lines=7, label="Input Text"),
        gr.inputs.Image(type="pil", optional=True),
        gr.inputs.Dropdown(["Text-To-Image", "Image-To-Image"]),
    ],
    # Four separate PIL image components for the carousel.
    outputs=gr.outputs.Carousel(
        [gr.outputs.Image(type="pil") for _ in range(4)]
    ),
    enable_queue=True,
).launch(debug=True)