# PopYou / app.py
# Uploaded by AmitIsraeli — commit 10f4eaf ("remove tags").
import torch
from models import VQVAE, build_vae_var
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, SiglipTextModel
from peft import LoraConfig, get_peft_model
from torchvision.transforms import ToPILImage
import random
import gradio as gr
class SimpleAdapter(nn.Module):
    """Small MLP that projects an embedding into another feature width.

    Pipeline: LayerNorm(input_dim) -> Linear(input_dim, hidden_dim) -> GELU
    -> Linear(hidden_dim, out_dim) -> LayerNorm(out_dim).

    Args:
        input_dim: Size of the last dimension of the incoming tensor.
        hidden_dim: Width of the hidden layer.
        out_dim: Size of the last dimension of the produced tensor.
    """

    def __init__(self, input_dim=512, hidden_dim=1024, out_dim=1024):
        super().__init__()
        # NOTE: keep this creation order and these attribute names. Each
        # nn.Linear draws from the global RNG when constructed, and the names
        # are the state_dict keys that checkpoints are loaded by.
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.norm0 = nn.LayerNorm(input_dim)
        self.activation1 = nn.GELU()
        self.layer2 = nn.Linear(hidden_dim, out_dim)
        self.norm2 = nn.LayerNorm(out_dim)
        self._initialize_weights()

    def _initialize_weights(self):
        """Reset every Linear and LayerNorm submodule.

        Linear weights get Xavier-uniform init with a very small gain (0.001)
        and zero biases. LayerNorms get unit weights and zero biases.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=0.001)
                nn.init.zeros_(module.bias)
                continue
            if isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Map ``x`` (last dim ``input_dim``) to a tensor whose last dim is ``out_dim``."""
        hidden = self.activation1(self.layer1(self.norm0(x)))
        return self.norm2(self.layer2(hidden))
class InferenceTextVAR(nn.Module):
    """Text-conditioned VAR image generator for inference.

    Combines:
      * a VQ-VAE + VAR pair built by ``build_vae_var``, with LoRA injected into
        every ``nn.Linear`` of the VAR (see ``apply_lora_to_var``);
      * a SigLIP tokenizer and text encoder;
      * a ``SimpleAdapter`` that maps the pooled SigLIP text embedding to the
        VAR conditioning width (``self.var.C``).

    Args:
        pl_checkpoint: Optional checkpoint path. Its ``'state_dict'`` entry
            holds keys prefixed with ``var.``, ``vae.`` and ``adapter.``.
        start_class_id: Class label passed to the VAR as ``label_B``.
        hugging_face_token: Optional token forwarded to ``from_pretrained``.
        siglip_model: Hugging Face id of the SigLIP checkpoint to load.
        device: Device the submodules are created on.
        MODEL_DEPTH: Transformer depth passed to ``build_vae_var``.
    """

    def __init__(self, pl_checkpoint=None, start_class_id=578, hugging_face_token=None, siglip_model='google/siglip-base-patch16-224', device="cpu", MODEL_DEPTH=16):
        super(InferenceTextVAR, self).__init__()
        self.device = device
        self.class_id = start_class_id

        # Build the VQ-VAE and the VAR transformer.
        patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
        self.vae, self.var = build_vae_var(
            V=4096, Cvae=32, ch=160, share_quant_resi=4,
            device=device, patch_nums=patch_nums,
            num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
        )

        # SigLIP tokenizer and text tower used to embed prompts.
        self.text_processor = AutoTokenizer.from_pretrained(siglip_model, token=hugging_face_token)
        self.siglip_text_encoder = SiglipTextModel.from_pretrained(siglip_model, token=hugging_face_token).to(device)
        # Projects the SigLIP hidden size onto the VAR conditioning width.
        self.adapter = SimpleAdapter(
            input_dim=self.siglip_text_encoder.config.hidden_size,
            out_dim=self.var.C,
        ).to(device)

        # Wrap the VAR with LoRA before loading weights: the checkpoint's
        # ``var.*`` keys presumably follow the PEFT-wrapped module names.
        self.apply_lora_to_var()

        if pl_checkpoint is not None:
            # NOTE(review): torch.load unpickles arbitrary objects; only load
            # checkpoints from a trusted source.
            state_dict = torch.load(pl_checkpoint, map_location="cpu")['state_dict']
            var_state_dict = {k[len('var.'):]: v for k, v in state_dict.items() if k.startswith('var.')}
            vae_state_dict = {k[len('vae.'):]: v for k, v in state_dict.items() if k.startswith('vae.')}
            adapter_state_dict = {k[len('adapter.'):]: v for k, v in state_dict.items() if k.startswith('adapter.')}
            self.var.load_state_dict(var_state_dict)
            self.vae.load_state_dict(vae_state_dict)
            self.adapter.load_state_dict(adapter_state_dict)

        # generate_image never calls the VAE encoder directly. It is removed
        # (presumably to save memory), so saved state dicts of this module do
        # not contain encoder weights.
        del self.vae.encoder

        # Inference-only module: disable dropout, including the LoRA dropout
        # configured in apply_lora_to_var, so it does not perturb sampling.
        self.eval()

    def apply_lora_to_var(self):
        """Wrap ``self.var`` in a PEFT LoRA model.

        Every ``nn.Linear`` inside the VAR is targeted (r=8, lora_alpha=32,
        lora_dropout=0.05, bias="none"). ``self.var`` is replaced by the model
        returned from ``get_peft_model``.
        """
        linear_module_names = [
            name for name, module in self.var.named_modules()
            if isinstance(module, nn.Linear)
        ]
        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=linear_module_names,
            lora_dropout=0.05,
            bias="none",
        )
        self.var = get_peft_model(self.var, lora_config)

    @torch.no_grad()
    def generate_image(self, text, beta=1, seed=None, more_smooth=False, top_k=0, top_p=0.5):
        """Generate a single image for ``text``.

        Args:
            text: Prompt encoded with the SigLIP text model.
            beta: Forwarded to ``autoregressive_infer_cfg``.
            seed: Sampling seed. When ``None``, a random one in
                [0, 2**32 - 1] is drawn. Numeric non-int values (e.g. floats
                from a UI number field) are converted with ``int()``.
            more_smooth, top_k, top_p: Forwarded unchanged to
                ``autoregressive_infer_cfg``.

        Returns:
            PIL image built from the first generated sample.
        """
        if seed is None:
            seed = random.randint(0, 2**32 - 1)
        # The seed is used to seed a torch RNG, which expects an integer.
        seed = int(seed)

        # truncation=True cuts over-long prompts to the tokenizer's maximum
        # length instead of emitting a sequence longer than that maximum.
        inputs = self.text_processor(
            [text], padding="max_length", truncation=True, return_tensors="pt"
        ).to(self.device)
        outputs = self.siglip_text_encoder(**inputs)

        # Build the delta condition: unit-normalise the pooled text embedding,
        # project it with the adapter, then unit-normalise again.
        text_embedding = F.normalize(outputs.pooler_output, p=2, dim=-1)
        cond_delta = F.normalize(self.adapter(text_embedding), p=2, dim=-1)

        generated_images = self.var.autoregressive_infer_cfg(
            B=1,
            label_B=self.class_id,
            delta_condition=cond_delta[:1],
            beta=beta,
            alpha=1,
            top_k=top_k,
            top_p=top_p,
            more_smooth=more_smooth,
            g_seed=seed,
        )
        return ToPILImage()(generated_images[0].cpu())
if __name__ == '__main__':
    # Initialize the model
    checkpoint = 'VARtext_v1.pth'  # Replace with your actual checkpoint path
    device = 'cpu' if not torch.cuda.is_available() else 'cuda'
    model = InferenceTextVAR(device=device)
    # NOTE(review): torch.load unpickles arbitrary objects; only load a trusted checkpoint.
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.to(device)
    # Inference only: disable dropout (e.g. the LoRA dropout) during sampling.
    model.eval()

    def generate_image_gradio(text, beta=1.0, seed=None, more_smooth=False, top_k=0, top_p=0.5):
        """Gradio callback for the "Generate Image" tab; returns a PIL image."""
        # gr.Number delivers floats (or None when left empty); the sampler's
        # seed must be an integer.
        if seed is not None:
            seed = int(seed)
        print(f"Generating image for text: {text}\n"
              f"beta: {beta}\n"
              f"seed: {seed}\n"
              f"more_smooth: {more_smooth}\n"
              f"top_k: {top_k}\n"
              f"top_p: {top_p}\n")
        image = model.generate_image(text, beta=beta, seed=seed, more_smooth=more_smooth, top_k=int(top_k), top_p=top_p)
        return image

    with gr.Blocks(css="""
.project-item {margin-bottom: 30px;}
.project-description {margin-top: 20px;}
.github-button, .huggingface-button, .wandb-button {
display: inline-block; margin-left: 10px; text-decoration: none; font-size: 14px;
padding: 5px 10px; background-color: #f0f0f0; border-radius: 5px; color: black;
}
.project-content {display: flex; flex-direction: row;}
.project-description {flex: 2; padding-right: 20px;}
.project-options-image {flex: 1;}
.funko-image {width: 100%; max-width: 300px;}
""") as demo:
        gr.Markdown("""
# PopYou2 - VAR Text
<!-- Project Links -->
[![GitHub](https://img.shields.io/badge/GitHub-Repository-blue?logo=github)](https://github.com/amit154154/VAR_clip)
[![Weights & Biases](https://img.shields.io/badge/Weights%20%26%20Biases-Report-orange?logo=weightsandbiases)](https://api.wandb.ai/links/amit154154/cqccmfsl)
## Project Explanation
- **Dataset Generation:** Generated a comprehensive dataset of approximately 100,000 Funko Pop! images with detailed prompts using [SDXL Turbo](https://huggingface.co/stabilityai/sdxl-turbo) for high-quality data creation.
- **Model Fine-tuning:** Fine-tuned the [Visual AutoRegressive (VAR)](https://arxiv.org/abs/2404.02905) model, pretrained on ImageNet, to adapt it for Funko Pop! generation by injecting a custom embedding representing the "doll" class.
- **Adapter Training:** Trained an adapter with the frozen [SigLIP image encoder](https://arxiv.org/abs/2303.15343) and a lightweight LoRA module to map image embeddings to text representation in a large language model.
- **Text-to-Image Generation:** Enabled text-to-image generation by replacing the SigLIP image encoder with its text encoder, retaining frozen components such as the VAE and generator for efficiency and quality.
## Generate Your Own Funko Pop!
""")
        with gr.Tab("Generate Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    text_input = gr.Textbox(label="Input Text", placeholder="Enter a description for your Funko Pop!")
                    beta_input = gr.Slider(label="Beta", minimum=0.0, maximum=2.5, step=0.05, value=1.0)
                    seed_input = gr.Number(label="Seed", value=None)
                    more_smooth_input = gr.Checkbox(label="More Smooth", value=False)
                    top_k_input = gr.Number(label="Top K", value=0)
                    top_p_input = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.5)
                    generate_button = gr.Button("Generate Image")
                with gr.Column(scale=1):
                    image_output = gr.Image(label="Generated Image")
            generate_button.click(
                generate_image_gradio,
                inputs=[text_input, beta_input, seed_input, more_smooth_input, top_k_input, top_p_input],
                outputs=image_output
            )

            gr.Markdown("## Examples")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Example 1")
                    gr.Markdown("A Funko Pop figure of a yellow robot Tom Cruise with headphones on a white background")
                    example1_image = gr.Image(value="examples/tom_cruise_robot.png")  # Replace with the actual path
                with gr.Column():
                    gr.Markdown("### Example 2")
                    gr.Markdown("A Funko Pop figure of an alien Scarlett Johansson holding a shield on a white background")
                    example2_image = gr.Image(value="examples/alien_Scarlett_Johansson.png")  # Replace with the actual path
                with gr.Column():
                    gr.Markdown("### Example 3")
                    gr.Markdown("A Funko Pop figure of a woman with a hat and pink long hair and blue dress on a white background")
                    example3_image = gr.Image(value="examples/woman_pink.png")  # Replace with the actual path

            gr.Markdown("""
## Customize Your Funko Pop!
Build your own Funko Pop! by selecting options below and clicking "Generate Custom Funko Pop!".
""")

            def update_custom_image(famous_name, character, action):
                """Build a prompt from the dropdown selections and generate an image."""
                parts = []
                if famous_name != "None":
                    parts.append(f"a Funko Pop figure of {famous_name}")
                else:
                    parts.append("a Funko Pop figure")
                if character != "None":
                    parts.append(f"styled as a {character}")
                if action != "None":
                    parts.append(f"performing {action}")
                parts.append("on a white background")
                prompt = ", ".join(parts)
                image = model.generate_image(prompt)
                return image

            # "Oprah Winfrey" and "Lebron James" are separate choices (they were
            # previously fused into one string by a missing quote).
            famous_name_input = gr.Dropdown(choices=["None", "Donald Trump", "Johnny Depp", "Oprah Winfrey", "Lebron James"], label="Famous Name", value="None")
            character_input = gr.Dropdown(choices=["None", "Alien", "Robot"], label="Character", value="None")
            action_input = gr.Dropdown(choices=["None", "Playing the Guitar", "Holding the Sword", "wearing headphone"], label="Action", value="None")
            custom_generate_button = gr.Button("Generate Custom Funko Pop!")
            custom_image_output = gr.Image(label="Custom Funko Pop!")
            custom_generate_button.click(
                update_custom_image,
                inputs=[famous_name_input, character_input, action_input],
                outputs=custom_image_output
            )

    demo.launch()