import gradio as gr
import numpy as np
from sklearn.svm import LinearSVC
from sklearn import preprocessing
import pandas as pd
from diffusers import LCMScheduler, AutoencoderTiny, EulerDiscreteScheduler, UNet2DConditionModel, AutoPipelineForText2Image
from diffusers.models import ImageProjection
import torch
import random
import time
from urllib.request import urlopen
from PIL import Image
import requests
from io import BytesIO, StringIO
from transformers import CLIPVisionModelWithProjection
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces

DEVICE = 'cuda'

# Candidate text prompts used while roaming: the distinct values of the CSV's
# second column, keeping only string cells (drops empty/NaN cells).
prompt_list = [
    p for p in set(pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())
    if isinstance(p, str)
]

start_time = time.time()

####################### Setup Model
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
sdxl_lightening = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_2step_unet.safetensors"

# Build a UNet from the SDXL base config, then load the SDXL-Lightning
# 2-step checkpoint weights into it.
unet = UNet2DConditionModel.from_config(model_id, subfolder="unet").to(DEVICE, torch.float16)
unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt), device=DEVICE))

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter",
    subfolder="models/image_encoder",
    torch_dtype=torch.float16,
).to(DEVICE)
pipe = AutoPipelineForText2Image.from_pretrained(
    model_id,
    unet=unet,
    torch_dtype=torch.float16,
    variant="fp16",
    image_encoder=image_encoder,
).to(DEVICE)
# NOTE(review): the same IP-Adapter checkpoint (sdxl_models/ip-adapter_sdxl_vit-h.bin)
# is loaded twice — once via the private `_load_ip_adapter_weights` and again via
# the public `load_ip_adapter`. The first call looks redundant; confirm against the
# pinned diffusers version before removing it.
pipe.unet._load_ip_adapter_weights(
    torch.load(hf_hub_download('h94/IP-Adapter', 'sdxl_models/ip-adapter_sdxl_vit-h.bin'))
)
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl_vit-h.bin")
pipe.register_modules(image_encoder=image_encoder)
# Swap in the tiny SDXL autoencoder and an Euler scheduler with trailing
# timestep spacing.
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
pipe.to(device=DEVICE)

# Fourth positional argument passed to `pipe.encode_image` in `predict`.
output_hidden_state = False
#######################


@spaces.GPU
def predict(
    prompt,
    im_emb=None,
    progress=gr.Progress(track_tqdm=True)
):
    """Run a single prediction on the model.

    Generates one 1024x1024 image in 2 steps with guidance_scale=0, conditioned
    on ``im_emb`` through the IP-Adapter. Then returns it together with the
    embedding ``pipe.encode_image`` computes for the generated image.

    Args:
        prompt: Text prompt. The empty string skips text encoding. All-zero
            ``prompt_embeds`` / ``pooled_prompt_embeds`` are passed instead.
        im_emb: Image embedding tensor for the IP-Adapter. ``None`` means a
            zero embedding of shape (1, 1024). Callers presumably pass the same
            shape — TODO confirm.
        progress: Gradio progress tracker (mirrors tqdm bars).

    Returns:
        ``(image, embedding)``; ``embedding`` is moved to ``DEVICE``.
    """
    with torch.no_grad():
        if im_emb is None:
            im_emb = torch.zeros(1, 1024, dtype=torch.float16, device=DEVICE)
        im_emb = [im_emb.to(DEVICE).unsqueeze(0)]

        # Sampling arguments shared by the prompt and no-prompt branches.
        common = dict(
            ip_adapter_image_embeds=im_emb,
            height=1024,
            width=1024,
            num_inference_steps=2,
            guidance_scale=0,
        )
        if prompt == '':
            image = pipe(
                prompt_embeds=torch.zeros(1, 1, 2048, dtype=torch.float16, device=DEVICE),
                pooled_prompt_embeds=torch.zeros(1, 1280, dtype=torch.float16, device=DEVICE),
                **common,
            ).images[0]
        else:
            image = pipe(prompt=prompt, **common).images[0]

        im_emb, _ = pipe.encode_image(image, DEVICE, 1, output_hidden_state)
        return image, im_emb.to(DEVICE)


# NOTE: module-level counter shared by every user session.
# TODO add to state instead of shared across all
glob_idx = 0


def _training_indices(ys, n_embs):
    """Return the row indices used to fit the preference classifier.

    Samples 80% of the rated embeddings (at least two) for some stochasticity.
    Then it also appends the index of the most recent 0-rated and the most
    recent 1-rated item, if present. Duplicates are allowed.
    """
    n_to_choose = max(int(n_embs * .8), 2)
    indices = random.sample(range(n_embs), n_to_choose)

    has_0 = False
    has_1 = False
    for i in reversed(range(len(ys))):
        if ys[i] == 0 and not has_0:
            indices.append(i)
            has_0 = True
        elif ys[i] == 1 and not has_1:
            indices.append(i)
            has_1 = True
        if has_0 and has_1:
            break
    return indices


def _preference_direction(embs, ys, indices):
    """Fit a balanced linear SVM on ``embs[indices]`` and return its unit-norm weights.

    Returns a float64 CPU tensor of shape (1, D), where D is the embedding
    width. If the fitted weights are all zero, the zero vector is returned
    unchanged instead of dividing by zero.
    """
    # Move every embedding to CPU/float32 *before* concatenating. The seed
    # embeddings appended in `next_image` are CPU tensors, while embeddings
    # returned by `predict` are on DEVICE, and torch.cat rejects mixed devices.
    feature_embs = torch.cat(
        [embs[i].to('cpu', torch.float32) for i in indices]
    ).numpy()
    scaler = preprocessing.StandardScaler().fit(feature_embs)
    feature_embs = scaler.transform(feature_embs)
    lin_class = LinearSVC(max_iter=50000, dual='auto', class_weight='balanced').fit(
        feature_embs, np.array([ys[i] for i in indices])
    )

    coef = torch.tensor(lin_class.coef_, dtype=torch.double).flatten()
    norm = coef.norm()
    if norm > 0:
        coef = coef / norm
    return coef.unsqueeze(0)


def next_image(embs, ys, calibrate_prompts):
    """Produce the next image to rate, mutating the session state lists in place.

    While ``calibrate_prompts`` is non-empty, its first prompt is popped and
    rendered, and the resulting embedding is appended to ``embs``. After that
    ("roaming"), a linear SVM is fit on the rated embeddings, and its normalized
    weight vector is the IP-Adapter embedding for the next image. The text
    prompt is empty when ``glob_idx`` is even, otherwise a random entry of
    ``prompt_list``.

    Returns:
        ``(image, embs, ys, calibrate_prompts)`` — the same list objects.
    """
    global glob_idx
    glob_idx = glob_idx + 1

    # handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike':
    # add one small random embedding per label so the classifier sees both classes.
    if len(calibrate_prompts) == 0 and len(set(ys)) <= 1:
        embs.append(.01 * torch.randn(1, 1024))
        embs.append(.01 * torch.randn(1, 1024))
        ys.append(0)
        ys.append(1)

    with torch.no_grad():
        if len(calibrate_prompts) > 0:
            print('######### Calibrating with sample prompts #########')
            prompt = calibrate_prompts.pop(0)
            print(prompt)
            image, img_emb = predict(prompt)
            embs.append(img_emb)
            return image, embs, ys, calibrate_prompts

        print('######### Roaming #########')
        indices = _training_indices(ys, len(embs))
        direction = _preference_direction(embs, ys, indices)

        rng_prompt = random.choice(prompt_list)
        im_emb = direction.to(device=DEVICE, dtype=torch.float16)
        prompt = '' if glob_idx % 2 == 0 else rng_prompt
        print(prompt, len(ys))

        image, im_emb = predict(prompt, im_emb)
        embs.append(im_emb)
        # Keep at most 100 embeddings; ys is trimmed in the same step.
        if len(embs) > 100:
            embs.pop(0)
            ys.pop(0)
        return image, embs, ys, calibrate_prompts


def start(_, embs, ys, calibrate_prompts):
    """Show the first image, enable the rating buttons and disable Start."""
    image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
    return [
        gr.Button(value='Like (L)', interactive=True),
        gr.Button(value='Neither (Space)', interactive=True),
        gr.Button(value='Dislike (A)', interactive=True),
        gr.Button(value='Start', interactive=False),
        image,
        embs,
        ys,
        calibrate_prompts,
    ]


def choose(choice, embs, ys, calibrate_prompts):
    """Record the rating for the current image, then produce the next image.

    ``choice`` is the clicked button's label:

    - 'Like (L)' appends label 1.
    - 'Neither (Space)' drops the current image's embedding without a label.
    - Any other label appends 0.
    """
    if choice == 'Neither (Space)':
        embs.pop(-1)
    else:
        ys.append(1 if choice == 'Like (L)' else 0)
    return next_image(embs, ys, calibrate_prompts)


css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
#description h1, #description h3{display: block}
#description p{margin-top: 0}
.fade-in-out {animation: fadeInOut 3s forwards}
@keyframes fadeInOut {
    0% {
        background: var(--bg-color);
    }
    100% {
        background: var(--button-secondary-background-fill);
    }
}
'''

# NOTE(review): empty. The button labels advertise keyboard shortcuts
# (L / Space / A), and `.fade-in-out` is unused in the visible code.
# A script may belong here — confirm.
js_head = ''' '''

with gr.Blocks(css=css, head=js_head) as demo:
    gr.Markdown('''### Zahir: Generative Recommenders for Unprompted, Scalable Exploration

Explore the latent space without text prompts, based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
''', elem_id="description")
    # Per-session state: rated embeddings, their labels, and the remaining
    # calibration prompts (consumed front-to-back by `next_image`).
    embs = gr.State([])
    ys = gr.State([])
    calibrate_prompts = gr.State([
        "4k photo",
        'surrealist art',
        # 'a psychedelic, fractal view',
        'a beautiful collage',
        'abstract art',
        'an eldritch image',
        'a sketch',
        # 'a city full of darkness and graffiti',
        '',
    ])

    with gr.Row(elem_id='output-image'):
        img = gr.Image(interactive=False, elem_id='output-image', width=700)
    with gr.Row(equal_height=True):
        b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
        b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
    # Each rating button passes its own label to `choose`.
    for button in (b1, b2, b3):
        button.click(
            choose,
            [button, embs, ys, calibrate_prompts],
            [img, embs, ys, calibrate_prompts],
        )
    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start, [b4, embs, ys, calibrate_prompts], [b1, b2, b3, b4, img, embs, ys, calibrate_prompts])
    with gr.Row():
        html = gr.HTML('''
You will calibrate for several prompts and then roam.''')

demo.launch()  # Share your demo with just 1 extra parameter 🚀