# Hugging Face Space status at capture time: Runtime error
import os | |
from io import BytesIO | |
import requests | |
from datetime import datetime | |
import random | |
# Interface utilities | |
import gradio as gr | |
# Data utilities | |
import numpy as np | |
import pandas as pd | |
# Image utilities | |
from PIL import Image | |
import cv2 | |
# FLAVA Model | |
import torch | |
from transformers import BertTokenizer, FlavaModel | |
# Style Transfer Model | |
import paddlehub as hub | |
# --- Style transfer model -------------------------------------------------
# Install the pinned PaddleHub module version before loading it by name.
# The exit status of the install command is not checked here.
os.system("hub install stylepro_artistic==1.0.1")
stylepro_artistic = hub.Module(name="stylepro_artistic")

# --- FLAVA model ----------------------------------------------------------
# Hugging Face checkpoint shared by the model and its tokenizer.
_FLAVA_CHECKPOINT = "facebook/flava-full"

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = BertTokenizer.from_pretrained(_FLAVA_CHECKPOINT)
# nn.Module.to() returns the module itself, so the load and the move chain.
model = FlavaModel.from_pretrained(_FLAVA_CHECKPOINT).to(device)

# --- Unsplash search index -------------------------------------------------
# Precomputed photo embeddings plus per-photo metadata (including a "path"
# column used as the download URL). image_from_text indexes both with the
# same row position, so the two files are assumed to be row-aligned.
photo_features = np.load("unsplash-dataset/features.npy")
photo_data = pd.read_csv("unsplash-dataset/photos.csv")
def _best_photo_index(text_input):
    """Return the row index of the photo whose embedding best matches the text."""
    with torch.no_grad():
        inputs = tokenizer([text_input], padding=True, return_tensors="pt").to(device)
        # get_text_features yields one projected vector per token; keep the
        # first token's vector ([CLS] for a BERT tokenizer) as the sentence
        # embedding -> shape (1, dim).
        text_features = model.get_text_features(**inputs)[:, 0, :].cpu().numpy()
    # (1, dim) @ (dim, n_photos) -> (n_photos,) dot-product scores.
    similarities = (text_features @ photo_features.T).squeeze(0)
    # np.argmax returns the first maximum, i.e. the lowest index on ties.
    return int(np.argmax(similarities))


def _download_bgr(url, timeout=30):
    """Download the image at ``url`` and return it as a BGR uint8 array.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On timeout or connection failure.
    """
    # A timeout keeps a stalled request from hanging the worker forever.
    response = requests.get(url, timeout=timeout)
    # Fail with a clear HTTP error instead of letting PIL choke on an error page.
    response.raise_for_status()
    pil_image = Image.open(BytesIO(response.content)).convert("RGB")
    # PIL yields RGB; reverse the channel axis to get OpenCV's BGR order.
    # copy() turns the negatively-strided view into a contiguous array.
    return np.array(pil_image)[:, :, ::-1].copy()


def image_from_text(text_input):
    """Retrieve and download the Unsplash photo that best matches a description.

    The text is embedded with FLAVA and scored by dot product against the
    precomputed ``photo_features``. The highest-scoring photo (lowest index
    on ties) is then downloaded from the ``"path"`` column of ``photo_data``.
    Time spent in each stage is printed.

    Args:
        text_input: Free-text description of the desired image content.

    Returns:
        The photo as a uint8 NumPy array of shape (height, width, 3) in BGR
        channel order (the OpenCV convention).

    Raises:
        requests.HTTPError: If the photo download returns an error status.
        requests.RequestException: On download timeout or connection failure.
    """
    start = datetime.now()
    idx = _best_photo_index(text_input)
    photo = photo_data.iloc[idx]
    print(f"Time spent at FLAVA: {datetime.now()-start}")

    start = datetime.now()
    open_cv_image = _download_bgr(photo["path"])
    print(f"Time spent at Image request: {datetime.now()-start}")
    return open_cv_image
def inference(content, style):
    """Render a text-retrieved photo in the style of an uploaded image.

    Args:
        content: Text description used to retrieve the content photo via
            ``image_from_text``.
        style: The style image. Either an object whose ``.name`` attribute
            holds a file path (e.g. the temp-file object from a
            ``type="file"`` Gradio image input) or a plain path string.

    Returns:
        The stylised image as an RGB ``PIL.Image.Image``.

    Raises:
        ValueError: If no style image was supplied or it cannot be decoded.
    """
    # Validate the cheap input first, before the FLAVA search and download.
    if style is None:
        raise ValueError("Please upload a style image.")
    style_path = getattr(style, "name", style)
    # cv2.imread returns None (instead of raising) for unreadable files.
    style_image = cv2.imread(style_path)
    if style_image is None:
        raise ValueError(f"Could not read style image: {style_path}")

    content_image = image_from_text(content)
    start = datetime.now()
    result = stylepro_artistic.style_transfer(
        images=[{"content": content_image, "styles": [style_image]}]
    )
    print(f"Time spent at Style Transfer: {datetime.now()-start}")
    # Inputs are BGR (OpenCV order); the output "data" is presumably BGR too,
    # so flip the channel axis to RGB before building the PIL image.
    return Image.fromarray(np.uint8(result[0]["data"])[:, :, ::-1]).convert("RGB")
if __name__ == "__main__":
    # Text shown above (title/description) and below (article) the interface.
    title = "FLAVA Neural Style Transfer"
    description = "Gradio demo for Neural Style Transfer. Inspired from <a href='https://huggingface.co/spaces/WaterKnight/neural-style-transfer'>this demo for CLIP</a>. To use it, simply enter the text for image content and upload style image. Read more at the links below."
    # Rendered as HTML: attributes need whitespace between them, and a line
    # break is <br> (not the invalid </br>).
    article = (
        "<p style='text-align: center'>"
        "<a href='https://arxiv.org/abs/2003.07694' target='_blank'>Parameter-Free Style Projection for Arbitrary Style Transfer</a>"
        " | <a href='https://github.com/PaddlePaddle/PaddleHub' target='_blank'>Github Repo</a>"
        "<br><a href='https://arxiv.org/abs/2112.04482' target='_blank'>FLAVA paper</a>"
        " | <a href='https://huggingface.co/docs/transformers/model_doc/flava' target='_blank'>Hugging Face FLAVA Implementation</a>"
        "</p>"
    )
    # Each example is [content text, style image path].
    examples = [
        ["a cute kangaroo", "styles/starry.jpeg"],
        ["man holding beer", "styles/mona1.jpeg"],
    ]
    demo = gr.Interface(
        inference,
        inputs=[
            gr.inputs.Textbox(lines=1, placeholder="Describe the content of the image", default="a modern city with neon lights", label="Describe the image to which the style will be applied"),
            gr.inputs.Image(type="file", label="Style to be applied"),
        ],
        outputs=gr.outputs.Image(type="pil"),
        enable_queue=True,
        title=title,
        description=description,
        article=article,
        examples=examples,
    )
    demo.launch()