from typing import Union

import gradio as gr
import open_clip
import torch
import PIL.Image as Image

# Set device to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"PyTorch device: {device}")
# Load the OpenCLIP model and the necessary preprocessors
# openclip_model = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
# openclip_model = 'laion/CLIP-ViT-B-16-laion2B-s34B-b88K'
openclip_model_name = "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
openclip_model = "hf-hub:" + openclip_model_name
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
    model_name=openclip_model, device=device
)
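# Note (not part of the original app, just a hedged pointer): for "hf-hub:" checkpoints,
# open_clip can also return the tokenizer that matches the checkpoint, e.g.
#   tokenizer = open_clip.get_tokenizer(openclip_model)
# The laion2B CLIP ViT models use the standard CLIP tokenizer, so the module-level
# open_clip.tokenize call used below is expected to behave the same for this model.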
# Define function to generate text embeddings
# @spaces.GPU
def generate_text_embedding(text_data: Union[str, tuple[str, ...]]) -> list:
    """
    Generate embeddings for text data using the OpenCLIP model.

    Parameters
    ----------
    text_data : str or tuple of str
        Text data to embed.

    Returns
    -------
    text_embeddings : list
        One embedding (a list of floats) per non-empty input string;
        empty input strings are represented by "".
    """
    # Embed text data
    text_embeddings = []
    empty_data_indices = []
    if text_data:
        # If text_data is a string, convert to list of strings
        if isinstance(text_data, str):
            text_data = [text_data]
        # If text_data is a tuple of strings, convert to list of strings
        if isinstance(text_data, tuple):
            text_data = list(text_data)
        # If text_data is not a list of strings, raise error
        if not isinstance(text_data, list):
            raise TypeError("text_data must be a string or a tuple of strings.")
        # Keep track of indices of empty text strings
        empty_data_indices = [i for i, text in enumerate(text_data) if text == ""]
        # Remove empty text strings
        text_data = [text for text in text_data if text != ""]
        if text_data:
            # Tokenize text_data and convert to tensor
            text_data = open_clip.tokenize(text_data).to(device)
            # Generate text embeddings
            with torch.no_grad():
                text_embeddings = model.encode_text(text_data)
            # Convert embeddings to nested lists of floats
            text_embeddings = [
                embedding.detach().cpu().numpy().tolist()
                for embedding in text_embeddings
            ]
        # Insert empty strings at indices of empty text strings
        for i in empty_data_indices:
            text_embeddings.insert(i, "")
    return text_embeddings
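# Illustrative usage sketch (not executed at import time; values are hypothetical):
#   generate_text_embedding(("a photo of a cat", ""))
#   -> [[0.01, -0.27, ...], ""]   # one vector per non-empty string (768 values for
#                                 # this ViT-L/14 checkpoint), "" kept at the index
#                                 # of the empty input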
# Define function to generate image embeddings
def generate_image_embedding(
    image_data: Union[Image.Image, tuple[Image.Image, ...]]
) -> list:
    """
    Generate embeddings for image data using the OpenCLIP model.

    Parameters
    ----------
    image_data : PIL.Image.Image or tuple of PIL.Image.Image
        Image data to embed.

    Returns
    -------
    image_embeddings : list
        One embedding (a list of floats) per image;
        missing (None) images are represented by "".
    """
    # Embed image data
    image_embeddings = []
    empty_data_indices = []
    if image_data:
        # If image_data is a single PIL image, convert to list of PIL images
        if isinstance(image_data, Image.Image):
            image_data = [image_data]
        # If image_data is a tuple of images, convert to list of images
        if isinstance(image_data, tuple):
            image_data = list(image_data)
        # Keep track of indices of None images
        empty_data_indices = [i for i, img in enumerate(image_data) if img is None]
        # Remove None images
        image_data = [img for img in image_data if img is not None]
        if image_data:
            # Preprocess image_data and convert to a single batch tensor
            image_data = [preprocess_val(img).unsqueeze(0) for img in image_data]
            image_data = torch.stack(image_data).squeeze(1).to(device)
            # Generate image embeddings
            with torch.no_grad():
                image_embeddings = model.encode_image(image_data)
            # Convert embeddings to nested lists of floats
            image_embeddings = [
                embedding.detach().cpu().numpy().tolist()
                for embedding in image_embeddings
            ]
        # Insert empty strings at indices of missing images
        for i in empty_data_indices:
            image_embeddings.insert(i, "")
    return image_embeddings
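# Illustrative usage sketch (hypothetical path and values):
#   img = Image.open("cat.jpg")
#   generate_image_embedding((img, None))
#   -> [[0.11, 0.04, ...], ""]   # embedding for the image, "" where the input was None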
# Define function to generate embeddings
def generate_embedding(
    text_data: Union[str, tuple[str, ...]],
    image_data: Union[Image.Image, tuple[Image.Image, ...]],
) -> tuple[list, list, list, str]:
    """
    Generate embeddings for text and image data using the OpenCLIP model.

    Parameters
    ----------
    text_data : str or tuple of str
        Text data to embed.
    image_data : PIL.Image.Image or tuple of PIL.Image.Image
        Image data to embed.

    Returns
    -------
    text_embeddings : list
        List of text embeddings ("" for empty inputs).
    image_embeddings : list
        List of image embeddings ("" for missing images).
    similarity : list of str
        Cosine similarity for each text/image embedding pair, as a percentage string.
    model_name : str
        Name of the OpenCLIP model used to generate the embeddings.
    """
    # Embed text data
    text_embeddings = generate_text_embedding(text_data)
    # Embed image data
    image_embeddings = generate_image_embedding(image_data)

    # Calculate cosine similarity between text and image embeddings
    similarity = []
    empty_data_indices = []
    if text_embeddings and image_embeddings:
        # Filter out pairs with an empty text or image embedding,
        # tracking the indices of the skipped pairs
        text_embeddings_filtered = []
        image_embeddings_filtered = []
        for i, (text_embedding, image_embedding) in enumerate(
            zip(text_embeddings, image_embeddings)
        ):
            if text_embedding != "" and image_embedding != "":
                text_embeddings_filtered.append(text_embedding)
                image_embeddings_filtered.append(image_embedding)
            else:
                empty_data_indices.append(i)
        # Calculate cosine similarity if there are any non-empty embedding pairs
        if image_embeddings_filtered and text_embeddings_filtered:
            # Convert lists back to tensors for processing
            text_embeddings_tensor = torch.tensor(text_embeddings_filtered)
            image_embeddings_tensor = torch.tensor(image_embeddings_filtered)
            # Normalize the embeddings
            text_embedding_norm = text_embeddings_tensor / text_embeddings_tensor.norm(
                dim=-1, keepdim=True
            )
            image_embedding_norm = (
                image_embeddings_tensor
                / image_embeddings_tensor.norm(dim=-1, keepdim=True)
            )
            # Calculate cosine similarity
            similarity = torch.nn.functional.cosine_similarity(
                text_embedding_norm, image_embedding_norm, dim=-1
            )
            # Convert to percentage as text
            similarity = [f"{sim.item() * 100:.2f}%" for sim in similarity]
        # Insert empty strings in similarity for the skipped pairs
        for i in empty_data_indices:
            similarity.insert(i, "")
    return (text_embeddings, image_embeddings, similarity, openclip_model_name)
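# Illustrative usage sketch (hypothetical path and values):
#   text_embs, image_embs, similarity, model_name = generate_embedding(
#       "a photo of a cat", Image.open("cat.jpg")
#   )
#   # similarity might look like ["23.41%"]; model_name is the checkpoint id defined above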
# Define Gradio interface
demo = gr.Interface(
    fn=generate_embedding,
    inputs=[
        gr.Textbox(
            lines=5,
            max_lines=5,
            placeholder="Enter Text Here...",
            label="Text to Embed",
        ),
        gr.Image(height=512, type="pil", label="Image to Embed"),
    ],
    outputs=[
        gr.Textbox(lines=5, max_lines=5, label="Text Embedding", autoscroll=False),
        gr.Textbox(lines=5, max_lines=5, label="Image Embedding", autoscroll=False),
        gr.Textbox(label="Cosine Similarity"),
        gr.Textbox(label="Embedding Model"),
    ],
    title="OpenCLIP Embedding Generator",
    description="Generate embeddings for text and images using an OpenCLIP model.",
    allow_flagging="never",
    batch=False,
    api_name="embed",
)

# Enable queueing and launch the app
if __name__ == "__main__":
    demo.queue(api_open=True).launch(show_api=True)
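# Client-side sketch (assumptions: the Space id below is a placeholder, and the exact
# call style depends on the installed gradio_client version):
#
#   from gradio_client import Client
#   client = Client("your-username/openclip-embedding-generator")  # hypothetical Space id
#   text_emb, image_emb, similarity, model_name = client.predict(
#       "a photo of a cat",   # text_data
#       "cat.jpg",            # image_data (newer clients may require handle_file("cat.jpg"))
#       api_name="/embed",
#   )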