File size: 3,697 Bytes
657fc3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a76f18
657fc3f
 
 
 
 
 
 
 
 
d99e0ff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
from io import BytesIO
import requests
from datetime import datetime
import random

# Interface utilities
import gradio as gr

# Data utilities
import numpy as np
import pandas as pd

# Image utilities
from PIL import Image
import cv2 

# FLAVA Model
import torch
from transformers import BertTokenizer, FlavaModel

# Style Transfer Model
import paddlehub as hub



# Install the PaddleHub style-transfer module at import time via the `hub`
# CLI (the command's exit status is not checked), then load it for use in
# inference().
os.system("hub install stylepro_artistic==1.0.1")
stylepro_artistic = hub.Module(name="stylepro_artistic")



# FLAVA Model
# Prefer the GPU when available; image_from_text moves its tokenized inputs
# to this same device before running the text encoder.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = FlavaModel.from_pretrained("facebook/flava-full")
tokenizer = BertTokenizer.from_pretrained("facebook/flava-full")
model = model.to(device)

# Load Data
# Precomputed photo embeddings; image_from_text treats this as an (N, D)
# matrix (one row per photo) and dot-products it with the text embedding.
# NOTE(review): rows are presumably aligned with photos.csv rows — confirm.
photo_features = np.load("unsplash-dataset/features.npy")
# Photo metadata; image_from_text reads the "path" column as a download URL.
photo_data = pd.read_csv("unsplash-dataset/photos.csv")

def image_from_text(text_input, timeout=30):
    """Retrieve the dataset photo that best matches a text prompt.

    The prompt is embedded with the FLAVA text encoder and scored against
    every row of ``photo_features`` by dot product. The highest-scoring
    photo is then downloaded from the URL in its ``path`` column.

    Args:
        text_input: Free-text description of the desired image content.
        timeout: Seconds to wait for the image download before giving up
            (forwarded to ``requests.get``).

    Returns:
        numpy.ndarray: The downloaded image as an (H, W, 3) uint8 array in
        BGR channel order, as OpenCV-style consumers expect.

    Raises:
        requests.HTTPError: If the image URL responds with an error status.
        requests.Timeout: If the download exceeds ``timeout`` seconds.
    """
    start = datetime.now()

    ## Inference: only the embedding is needed, so skip autograd bookkeeping.
    with torch.no_grad():
        inputs = tokenizer([text_input], padding=True, return_tensors="pt").to(device)
        # Keep position 0 of the sequence axis (the first token) as the
        # prompt's single feature vector -> shape (1, D).
        text_features = model.get_text_features(**inputs)[:, 0, :].cpu().numpy()

    ## Find similarity: (1, D) @ (D, N) -> (1, N), squeezed to (N,).
    similarities = (text_features @ photo_features.T).squeeze(0)

    ## Return best image :)
    # np.argmax returns the first index of the maximum, which matches the
    # tie-breaking of a stable descending sort, without sorting all N scores.
    idx = int(np.argmax(similarities))
    photo = photo_data.iloc[idx]

    print(f"Time spent at FLAVA: {datetime.now()-start}")

    start = datetime.now()
    # Download image. Fail loudly on HTTP errors instead of handing an error
    # page to PIL, and bound the wait so a stalled server cannot hang the app.
    response = requests.get(photo["path"], timeout=timeout)
    response.raise_for_status()
    pil_image = Image.open(BytesIO(response.content)).convert("RGB")
    # Convert RGB to BGR by reversing the channel axis. .copy() turns the
    # negative-stride view into a contiguous array.
    open_cv_image = np.array(pil_image)[:, :, ::-1].copy()

    print(f"Time spent at Image request: {datetime.now()-start}")

    return open_cv_image

def inference(content, style):
    """Render the photo matching a text prompt in the style of an image.

    Args:
        content: Text prompt. ``image_from_text`` uses it to pick and
            download the content photo.
        style: Uploaded style-image file object. Only its ``.name``
            attribute (a filesystem path) is used.

    Returns:
        PIL.Image.Image: The stylized result in RGB mode.

    Raises:
        ValueError: If OpenCV cannot read the style image file.
    """
    # Read and validate the style image first, so a bad upload fails fast
    # before the FLAVA search and the network download run.
    style_image = cv2.imread(style.name)
    if style_image is None:
        # cv2.imread reports unreadable or unsupported files by returning
        # None rather than raising; surface that as a clear error.
        raise ValueError(f"Could not read style image: {style.name!r}")

    content_image = image_from_text(content)
    start = datetime.now()

    result = stylepro_artistic.style_transfer(
        images=[{
            "content": content_image,
            "styles": [style_image],
        }])

    print(f"Time spent at Style Transfer: {datetime.now()-start}")
    # The output channels are reversed before building the PIL image.
    # NOTE(review): this presumes the module returns BGR, like its cv2
    # inputs; confirm.
    return Image.fromarray(np.uint8(result[0]["data"])[:, :, ::-1]).convert("RGB")

if __name__ == "__main__":
    # Text shown around the demo in the Gradio UI.
    title = "FLAVA Neural Style Transfer"
    description = (
        "Gradio demo for Neural Style Transfer. Inspired by "
        "<a href='https://huggingface.co/spaces/WaterKnight/neural-style-transfer'>this demo for CLIP</a>. "
        "To use it, simply enter the text for image content and upload a style image. "
        "Read more at the links below."
    )
    # Footer links. Attributes are whitespace-separated and the line break
    # uses a valid <br> tag (the previous markup had '...'target= and </br>).
    article = (
        "<p style='text-align: center'>"
        "<a href='https://arxiv.org/abs/2003.07694' target='_blank'>Parameter-Free Style Projection for Arbitrary Style Transfer</a>"
        " | <a href='https://github.com/PaddlePaddle/PaddleHub' target='_blank'>Github Repo</a>"
        "<br>"
        "<a href='https://arxiv.org/abs/2112.04482' target='_blank'>FLAVA paper</a>"
        " | <a href='https://huggingface.co/transformers/model_doc/flava.html' target='_blank'>Hugging Face FLAVA Implementation</a>"
        "</p>"
    )
    # Each example is [content prompt, path to a style image].
    examples = [
        ["a cute kangaroo", "styles/starry.jpeg"],
        ["man holding beer", "styles/mona1.jpeg"],
    ]
    demo = gr.Interface(
        inference,
        inputs=[
            gr.inputs.Textbox(lines=1, placeholder="Describe the content of the image", default="a modern city with neon lights", label="Describe the image to which the style will be applied"),
            # type="file" hands inference() a file object; it reads style.name.
            gr.inputs.Image(type="file", label="Style to be applied"),
        ],
        outputs=gr.outputs.Image(type="pil"),
        enable_queue=True,
        title=title,
        description=description,
        article=article,
        examples=examples,
    )
    demo.launch()