from typing import Union

import gradio as gr
import open_clip
import PIL.Image as Image
import torch

# Set device to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"PyTorch device: {device}")

# Load the OpenCLIP model and the necessary preprocessors
# openclip_model = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
# openclip_model = 'laion/CLIP-ViT-B-16-laion2B-s34B-b88K'
openclip_model_name = "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
openclip_model = "hf-hub:" + openclip_model_name
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    model_name=openclip_model, device=device
)
# Switch to eval mode for inference (open_clip models are created in train mode)
model.eval()

# Define function to generate text embeddings
# @spaces.GPU
def generate_text_embedding(text_data: Union[str, tuple[str, ...]]) -> list:
"""
Generate embeddings for text data using the OpenCLIP model.
Parameters
----------
text_data : str or tuple of str
Text data to embed.
Returns
-------
    text_embeddings : list
        List of text embeddings, one per input string. Each embedding is a list
        of floats; empty input strings are returned as "" at the same index.
"""
# Embed text data
text_embeddings = []
empty_data_indices = []
if text_data:
# If text_data is a string, convert to list of strings
if isinstance(text_data, str):
text_data = [text_data]
# If text_data is a tuple of strings, convert to list of strings
if isinstance(text_data, tuple):
text_data = list(text_data)
# If text_data is not a list of strings, raise error
if not isinstance(text_data, list):
raise TypeError("text_data must be a string or a tuple of strings.")
# Keep track of indices of empty text strings
empty_data_indices = [i for i, text in enumerate(text_data) if text == ""]
# Remove empty text strings
text_data = [text for text in text_data if text != ""]
if text_data:
# Tokenize text_data and convert to tensor
text_data = open_clip.tokenize(text_data).to(device)
# Generate text embeddings
with torch.no_grad():
text_embeddings = model.encode_text(text_data)
            # Convert each embedding tensor to a plain Python list of floats
text_embeddings = [
embedding.detach().cpu().numpy().tolist()
for embedding in text_embeddings
]
# Insert empty strings at indices of empty text strings
for i in empty_data_indices:
text_embeddings.insert(i, "")
return text_embeddings
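
# Illustrative usage sketch (kept as comments so nothing extra runs at import
# time; the 768-dimension figure assumes the ViT-L-14 checkpoint loaded above):
#
#   embeddings = generate_text_embedding(("a photo of a dog", "", "a cat"))
#   # -> three entries: embeddings[0] and embeddings[2] are 768-element float
#   #    lists, and embeddings[1] is "" because the input string was empty.
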
# Define function to generate image embeddings
def generate_image_embedding(
    image_data: Union[Image.Image, tuple[Image.Image, ...]]
) -> list:
"""
Generate embeddings for image data using the OpenCLIP model.
Parameters
----------
image_data : PIL.Image.Image or tuple of PIL.Image.Image
Image data to embed.
Returns
-------
    image_embeddings : list
        List of image embeddings, one per input image. Each embedding is a list
        of floats; None images are returned as "" at the same index.
"""
# Embed image data
image_embeddings = []
empty_data_indices = []
if image_data:
# If image_data is a single PIL image, convert to list of PIL images
if isinstance(image_data, Image.Image):
image_data = [image_data]
# If image_data is a tuple of images, convert to list of images
if isinstance(image_data, tuple):
image_data = list(image_data)
# Keep track of indices of None images
empty_data_indices = [i for i, img in enumerate(image_data) if img is None]
# Remove None images
image_data = [img for img in image_data if img is not None]
if image_data:
# Preprocess image_data and convert to tensor
image_data = [preprocess_val(img).unsqueeze(0) for img in image_data]
image_data = torch.stack(image_data).squeeze(1).to(device)
# Generate image embeddings
with torch.no_grad():
image_embeddings = model.encode_image(image_data)
            # Convert each embedding tensor to a plain Python list of floats
image_embeddings = [
embedding.detach().cpu().numpy().tolist()
for embedding in image_embeddings
]
# Insert empty strings at indices of empty images
for i in empty_data_indices:
image_embeddings.insert(i, "")
return image_embeddings
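
# Illustrative usage sketch (kept as comments; preprocess_val resizes and
# normalizes each PIL image to the model's expected input before encoding;
# "example.jpg" is a hypothetical local file):
#
#   img = Image.open("example.jpg")
#   embeddings = generate_image_embedding((img, None))
#   # -> two entries: a 768-element float list for img, and "" for the None slot.
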
# Define function to generate embeddings
def generate_embedding(
    text_data: Union[str, tuple[str, ...]],
    image_data: Union[Image.Image, tuple[Image.Image, ...]],
) -> tuple[list, list, list[str], str]:
"""
Generate embeddings for text and image data using the OpenCLIP model.
Parameters
----------
text_data : str or tuple of str
Text data to embed.
image_data : PIL.Image.Image or tuple of PIL.Image.Image
Image data to embed.
Returns
-------
    text_embeddings : list
        List of text embeddings (each a list of floats, or "" for empty inputs).
    image_embeddings : list
        List of image embeddings (each a list of floats, or "" where no image was given).
    similarity : list of str
        Cosine similarity between each text/image embedding pair, formatted as a
        percentage string ("" where either embedding is missing).
    model_name : str
        Name of the OpenCLIP model used to generate the embeddings.
"""
# Embed text data
text_embeddings = generate_text_embedding(text_data)
# Embed image data
image_embeddings = generate_image_embedding(image_data)
# Calculate cosine similarity between text and image embeddings
similarity = []
empty_data_indices = []
if text_embeddings and image_embeddings:
# Filter out embedding pairs with either empty text or image embeddings, tracking indices of empty embeddings
text_embeddings_filtered = []
image_embeddings_filtered = []
for i, (text_embedding, image_embedding) in enumerate(
zip(text_embeddings, image_embeddings)
):
if text_embedding != "" and image_embedding != "":
text_embeddings_filtered.append(text_embedding)
image_embeddings_filtered.append(image_embedding)
else:
empty_data_indices.append(i)
# Calculate cosine similarity if there are any non-empty embedding pairs
if image_embeddings_filtered and text_embeddings_filtered:
# Convert lists back to tensors for processing
text_embeddings_tensor = torch.tensor(text_embeddings_filtered)
image_embeddings_tensor = torch.tensor(image_embeddings_filtered)
# Normalize the embeddings
text_embedding_norm = text_embeddings_tensor / text_embeddings_tensor.norm(
dim=-1, keepdim=True
)
image_embedding_norm = (
image_embeddings_tensor
/ image_embeddings_tensor.norm(dim=-1, keepdim=True)
)
# Calculate cosine similarity
similarity = torch.nn.functional.cosine_similarity(
text_embedding_norm, image_embedding_norm, dim=-1
)
# Convert to percentage as text
similarity = [f"{sim.item() * 100:.2f}%" for sim in similarity]
# Insert empty text strings in similarity
for i in empty_data_indices:
similarity.insert(i, "")
return (text_embeddings, image_embeddings, similarity, openclip_model_name)
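
# Illustrative usage sketch (kept as comments; mirrors what the Gradio interface
# below does with one text box and one image; "dog.jpg" is a hypothetical file):
#
#   text_emb, image_emb, similarity, model_name = generate_embedding(
#       "a photo of a dog", Image.open("dog.jpg")
#   )
#   # similarity is a one-element list with a percentage string, and model_name
#   # is "laion/CLIP-ViT-L-14-laion2B-s32B-b82K".
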
# Define Gradio interface
demo = gr.Interface(
fn=generate_embedding,
inputs=[
gr.Textbox(
lines=5,
max_lines=5,
placeholder="Enter Text Here...",
label="Text to Embed",
),
gr.Image(height=512, type="pil", label="Image to Embed"),
],
outputs=[
gr.Textbox(lines=5, max_lines=5, label="Text Embedding", autoscroll=False),
gr.Textbox(lines=5, max_lines=5, label="Image Embedding", autoscroll=False),
gr.Textbox(label="Cosine Similarity"),
gr.Textbox(label="Embedding Model"),
],
title="OpenCLIP Embedding Generator",
    description="Generate text and image embeddings with an OpenCLIP model and compare them with cosine similarity.",
allow_flagging="never",
batch=False,
api_name="embed",
)

# Enable queueing and launch the app
if __name__ == "__main__":
demo.queue(api_open=True).launch(show_api=True)