File size: 3,790 Bytes
b0b9e1f
 
 
47cfe13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0b9e1f
a3be375
b0b9e1f
 
 
9fbe234
 
47cfe13
 
9fbe234
47cfe13
 
9fbe234
47cfe13
9fbe234
 
b0b9e1f
9fbe234
 
 
 
 
 
 
 
 
 
b0b9e1f
47cfe13
 
9fbe234
b0b9e1f
 
 
 
47cfe13
 
b0b9e1f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
from datasets import load_dataset
from PIL import Image
import numpy as np
import paddlehub as hub
import random
from PIL import ImageDraw,ImageFont

import streamlit as st

@st.experimental_singleton
def load_bg_model():
    """Create the PaddleHub U2NetP module used for background removal.

    The ``st.experimental_singleton`` decorator makes Streamlit build this
    object once and hand the same instance back on later calls/reruns.
    """
    return hub.Module(name='U2NetP', directory='assets/models/')


# Background-removal model shared by remove_bg(); created at import time via
# load_bg_model(), whose st.experimental_singleton decorator caches the instance.
bg_model = load_bg_model()
def remove_bg(img):
    """Cut the foreground out of *img* with the U2NetP model and return it as RGBA.

    Args:
        img: a PIL image or an image array. PIL images are converted to RGB
            first so the model always receives exactly three channels.

    Returns:
        A PIL RGBA image built from the model's ``front`` output, with the
        model's ``mask`` output installed as the alpha channel.
    """
    if isinstance(img, Image.Image):
        # RGBA / L / P images would otherwise hand the model 4 or 1 channels,
        # or a 2-D array that cannot be indexed with [:, :, ::-1].
        img = img.convert("RGB")
    rgb = np.asarray(img)
    if rgb.ndim == 2:
        # Greyscale array: replicate into three identical channels.
        rgb = np.stack([rgb] * 3, axis=-1)
    # Keep only the first three channels, then reverse their order (RGB -> BGR)
    # as the original code fed the model.
    bgr = rgb[:, :, :3][:, :, ::-1]

    result = bg_model.Segmentation(
        images=[bgr],
        paths=None,
        batch_size=1,
        input_size=320,
        output_dir=None,
        visualization=False)
    output = result[0]
    mask = Image.fromarray(output['mask'])
    # Flip the model's foreground back to RGB order before building the image.
    front = Image.fromarray(output['front'][:, :, ::-1]).convert("RGBA")
    front.putalpha(mask)
    return front

# Meme background, read once at import time and converted to RGBA so other
# images can be alpha-composited onto (copies of) it.
meme_template=Image.open("./assets/pigeon_meme.jpg").convert("RGBA")
def _text_width(draw, text, font):
    """Return the rendered pixel width of *text* drawn with *font*."""
    # ImageDraw.textsize was deprecated in Pillow 9.2 and removed in 10.0;
    # textbbox (Pillow >= 8.0) is the replacement. Fall back for old Pillow.
    if hasattr(draw, "textbbox"):
        left, _, right, _ = draw.textbbox((0, 0), text, font=font)
        return right - left
    return draw.textsize(text, font)[0]


def _draw_text_with_outline(draw, text, x, y, font, offset=2):
    """Draw white *text* at (x, y) over four black copies shifted diagonally by *offset*."""
    # ref: https://blog.lipsumarium.com/caption-memes-in-python/
    for dx, dy in ((-offset, -offset), (offset, -offset), (offset, offset), (-offset, offset)):
        draw.text((x + dx, y + dy), text, fill=(0, 0, 0), font=font)
    draw.text((x, y), text, fill=(255, 255, 255), font=font)


def make_meme(pigeon,text="Is this a pigeon?",show_text=True,remove_background=True):
    """Paste *pigeon* onto the meme template and optionally caption it.

    Args:
        pigeon: PIL image to paste onto the template.
        text: caption drawn near the bottom edge.
        show_text: whether to draw the caption at all.
        remove_background: if True, strip *pigeon*'s background with remove_bg() first.

    Returns:
        A new RGB PIL image; the shared ``meme_template`` is never modified.
    """
    # convert() always returns a new image, so this is also the per-call copy.
    meme = meme_template.convert("RGBA")
    # Top-left corner of the paste (alpha_composite's dest is a corner, not a centre).
    paste_xy = (850, 30)

    if remove_background:
        pigeon = remove_bg(pigeon)
    # alpha_composite requires RGBA on both sides; without this an RGB input on
    # the remove_background=False path raises ValueError. RGBA also keeps the
    # corners added by rotate(expand=True) transparent instead of black.
    pigeon = pigeon.convert("RGBA")

    # Same draw order as before (angle, then size) so seeded runs are unchanged.
    angle = random.randint(-30, 30)
    size = random.randint(150, 200)
    pigeon = pigeon.resize((size, size)).rotate(angle, expand=True)

    meme.alpha_composite(pigeon, paste_xy)

    if show_text:
        draw = ImageDraw.Draw(meme)
        font_size = 52
        font = ImageFont.truetype("assets/impact.ttf", font_size)
        text_w = _text_width(draw, text, font)
        # Integer coordinates: centre horizontally, two font-heights above the bottom.
        x = (meme.width - text_w) // 2
        y = meme.height - font_size * 2
        _draw_text_with_outline(draw, text, x, y, font)
    return meme.convert("RGB")

def get_train_data(dataset_name="huggan/smithsonian_butterflies_subset"):
    """Load the training split of *dataset_name*, sorted by its ``sim_score`` column.

    Only the ``train`` split is requested and sorted. The previous version sorted
    every split of the DatasetDict and then threw away all but ``train``.
    """
    train = load_dataset(dataset_name, split="train")
    return train.sort("sim_score")

from transformers import BeitFeatureExtractor, BeitForImageClassification
# BEiT image-classification checkpoint reused as an image embedder: embed() feeds
# images through it and pools the last hidden state instead of using the logits.
# Both objects are loaded (and downloaded if not cached) at import time.
emb_feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
emb_model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')
def embed(images):
    """Return BEiT embeddings for *images* as a NumPy array.

    Args:
        images: whatever ``emb_feature_extractor`` accepts; build_index() passes a
            batch of dataset images.

    Returns:
        The BEiT pooler applied to the model's last hidden state, as NumPy
        (presumably shaped (batch, hidden_size) — confirm against the model config).
    """
    inputs = emb_feature_extractor(images=images, return_tensors="pt")
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        outputs = emb_model(**inputs, output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1]
        pooled = emb_model.base_model.pooler(last_hidden)
    # .cpu() keeps .numpy() working even if the model is moved to a GPU.
    return pooled.cpu().numpy()
    
def build_index(index_path='beit_index.faiss'):
    """Embed the training set with BEiT and save a FAISS index over the embeddings.

    Args:
        index_path: file the FAISS index is written to (defaults to
            'beit_index.faiss').

    Returns:
        The training dataset with the ``beit_embeddings`` column and its FAISS
        index attached (previously discarded; callers ignoring it are unaffected).
    """
    dataset = get_train_data()
    ds_with_embeddings = dataset.map(
        lambda batch: {"beit_embeddings": embed(batch["image"])},
        batched=True,
        batch_size=20,
    )
    ds_with_embeddings.add_faiss_index(column='beit_embeddings')
    ds_with_embeddings.save_faiss_index('beit_embeddings', index_path)
    return ds_with_embeddings

def get_dataset(index_path='beit_index.faiss'):
    """Return the sorted training set with a saved BEiT FAISS index attached.

    Args:
        index_path: FAISS index file to load into the ``beit_embeddings`` index
            slot (defaults to 'beit_index.faiss').
    """
    dataset = get_train_data()
    dataset.load_faiss_index('beit_embeddings', index_path)
    return dataset

def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512',
               model_version="95a9596a1e47e2419c9bd5252d809eecb14fdcf4"):
    """Fetch a pretrained LightweightGAN at a pinned revision, switched to eval mode."""
    model = LightweightGAN.from_pretrained(model_name, version=model_version)
    model.eval()
    return model
    
def generate(gan,batch_size=1):
    """Sample *batch_size* images from ``gan.G`` and return them as uint8 pixels.

    Args:
        gan: object exposing a generator module ``G`` and an integer ``latent_dim``.
        batch_size: number of images to sample.

    Returns:
        NumPy uint8 array with the generator output's dims 1..3 moved last via
        permute(0, 2, 3, 1) — presumably (batch, H, W, C) for a (B, C, H, W) output.
    """
    generator = gan.G
    # Sample noise on the generator's device (CPU if it has no parameters) so a
    # GPU-resident model does not fail with a device mismatch.
    device = next((p.device for p in generator.parameters()), torch.device("cpu"))
    with torch.no_grad():
        noise = torch.randn(batch_size, gan.latent_dim, device=device)
        images = generator(noise).clamp(0.0, 1.0)
        # Round instead of truncating: a bare uint8 cast floors every value.
        images = (images * 255).round().to(torch.uint8)
        images = images.permute(0, 2, 3, 1).cpu().numpy()
    return images

def interpolate():
    """Placeholder (presumably for interpolating between generated samples).

    Not implemented: currently does nothing and returns None.
    """