Spaces:
Runtime error
Runtime error
import torch | |
from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN | |
from datasets import load_dataset | |
from PIL import Image | |
import numpy as np | |
import paddlehub as hub | |
import random | |
from PIL import ImageDraw,ImageFont | |
import streamlit as st | |
def load_bg_model():
    """Create the PaddleHub ``U2NetP`` module, using ``assets/models/`` as its directory."""
    return hub.Module(name='U2NetP', directory='assets/models/')


# Instantiated once at import time; remove_bg() reads this module-level object.
bg_model = load_bg_model()
def remove_bg(img):
    """Return the segmented foreground of *img* as an RGBA image.

    The image is run through the module-level ``bg_model``; the returned
    ``mask`` becomes the alpha channel of the returned ``front`` picture.

    Args:
        img: Anything ``np.array`` turns into a 3-D array (indexed as
            ``[:, :, ::-1]``), e.g. a PIL image.

    Returns:
        A PIL image in ``RGBA`` mode.
    """
    # The last axis is reversed on the way in and again on the way out
    # (presumably RGB <-> BGR for the segmentation model -- TODO confirm).
    flipped = np.array(img)[:, :, ::-1]
    segmented = bg_model.Segmentation(
        images=[flipped],
        paths=None,
        batch_size=1,
        input_size=320,
        output_dir=None,
        visualization=False,
    )[0]
    alpha = Image.fromarray(segmented['mask'])
    cutout = Image.fromarray(segmented['front'][:, :, ::-1]).convert("RGBA")
    cutout.putalpha(alpha)
    return cutout
# Base picture for every meme; loaded once and copied per call in make_meme().
meme_template = Image.open("./assets/pigeon_meme.jpg").convert("RGBA")


def make_meme(pigeon, text="Is this a pigeon?", show_text=True, remove_background=True):
    """Paste *pigeon* onto the meme template and optionally caption it.

    The pasted image is resized to a random square (150-200 px), rotated by a
    random angle in [-30, 30] degrees and alpha-composited at a fixed offset
    on the template.

    Args:
        pigeon: Image to paste. When ``remove_background`` is true it is
            passed through ``remove_bg``; otherwise it must be a PIL image
            (it is converted to ``RGBA`` before compositing).
        text: Caption drawn horizontally centred near the bottom edge.
        show_text: Draw the caption when true.
        remove_background: Strip the background of *pigeon* first when true.

    Returns:
        The finished meme as an ``RGB`` PIL image.
    """
    # convert() always returns a new image, so the shared template is never
    # modified by the in-place compositing below.
    meme = meme_template.convert("RGBA")
    approx_butterfly_center = (850, 30)

    if remove_background:
        pigeon = remove_bg(pigeon)
    else:
        # alpha_composite() rejects non-RGBA overlays.
        pigeon = pigeon.convert("RGBA")

    random_rotate = random.randint(-30, 30)
    random_size = random.randint(150, 200)
    pigeon = pigeon.resize((random_size, random_size)).rotate(random_rotate, expand=True)
    meme.alpha_composite(pigeon, approx_butterfly_center)

    if show_text:
        draw = ImageDraw.Draw(meme)
        font_size = 52
        font = ImageFont.truetype("assets/impact.ttf", font_size)

        # ref: https://blog.lipsumarium.com/caption-memes-in-python/
        def draw_text_with_outline(x, y):
            # Four black copies offset by 2 px form the outline ...
            for dx, dy in ((-2, -2), (2, -2), (2, 2), (-2, 2)):
                draw.text((x + dx, y + dy), text, (0, 0, 0), font=font)
            # ... and the white caption goes on top.
            draw.text((x, y), text, (255, 255, 255), font=font)

        # ImageDraw.textsize() was removed in Pillow 10; prefer textbbox()
        # and keep textsize() only as a fallback for old Pillow releases.
        if hasattr(draw, "textbbox"):
            left, _, right, _ = draw.textbbox((0, 0), text, font=font)
            text_width = right - left
        else:
            text_width, _ = draw.textsize(text, font)

        draw_text_with_outline(meme.width / 2 - text_width / 2, meme.height - font_size * 2)

    return meme.convert("RGB")
def get_train_data(dataset_name="huggan/smithsonian_butterflies_subset"):
    """Load *dataset_name*, sort it by its ``sim_score`` column and return the ``train`` split."""
    return load_dataset(dataset_name).sort("sim_score")["train"]
from transformers import BeitFeatureExtractor, BeitForImageClassification | |
# BEiT feature extractor and model used by embed(); loaded once at import time.
emb_feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
emb_model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')


def embed(images):
    """Compute pooled BEiT embeddings for *images*.

    The images are preprocessed with ``emb_feature_extractor``, run through
    ``emb_model`` with hidden states enabled, and the last hidden state is
    passed through the base model's pooler.

    Args:
        images: Image or batch of images accepted by the feature extractor.

    Returns:
        A NumPy array holding the pooled embeddings (presumably one row per
        input image -- verify against the pooler's output shape).
    """
    inputs = emb_feature_extractor(images=images, return_tensors="pt")
    # Inference only: skip autograd bookkeeping instead of building a graph
    # for every batch and detaching afterwards.
    with torch.no_grad():
        outputs = emb_model(**inputs, output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1]
        pooler = emb_model.base_model.pooler
        final_emb = pooler(last_hidden).numpy()
    return final_emb
def build_index():
    """Embed the training images and save a FAISS index to ``beit_index.faiss``.

    Embeddings are computed in batches of 20 and stored in a
    ``beit_embeddings`` column, which is then indexed and written to disk.
    """
    def add_embeddings(batch):
        return {"beit_embeddings": embed(batch["image"])}

    embedded = get_train_data().map(add_embeddings, batched=True, batch_size=20)
    embedded.add_faiss_index(column='beit_embeddings')
    embedded.save_faiss_index('beit_embeddings', 'beit_index.faiss')
def get_dataset():
    """Return the training data with the saved ``beit_embeddings`` FAISS index attached.

    The index is read from ``beit_index.faiss``, the file ``build_index`` writes.
    """
    train_split = get_train_data()
    train_split.load_faiss_index('beit_embeddings', 'beit_index.faiss')
    return train_split
def load_model(model_name='ceyda/butterfly_cropped_uniq1K_512',model_version="95a9596a1e47e2419c9bd5252d809eecb14fdcf4"):
    """Load a pretrained ``LightweightGAN`` and switch it to eval mode.

    Args:
        model_name: Identifier passed to ``LightweightGAN.from_pretrained``.
        model_version: Passed on as the ``version`` argument.

    Returns:
        The loaded model, after ``eval()`` has been called on it.
    """
    model = LightweightGAN.from_pretrained(model_name, version=model_version)
    model.eval()
    return model
def generate(gan, batch_size=1):
    """Sample *batch_size* images from the generator ``gan.G``.

    Args:
        gan: Object exposing ``latent_dim`` and a callable generator ``G``.
        batch_size: Number of random latent vectors to draw.

    Returns:
        A ``uint8`` NumPy array: the generator output clamped to [0, 1],
        scaled by 255 and with axis 1 moved to the end (``permute(0, 2, 3, 1)``).
    """
    with torch.no_grad():
        latents = torch.randn(batch_size, gan.latent_dim)
        scaled = gan.G(latents).clamp_(0., 1.) * 255
        channels_last = scaled.permute(0, 2, 3, 1)
        ims = channels_last.detach().cpu().numpy().astype(np.uint8)
    return ims
def interpolate():
    """Placeholder with no implementation yet; does nothing and returns ``None``."""