Spaces:
Build error
Build error
# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
import os
import warnings
from typing import Optional, Tuple, Union

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from torch.cuda.amp import autocast
from torch.nn import functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR

from .stage1.vqgan import VQGAN
from .stage2.transformer import Transformer1d, iGPT
from .stage2.layers import Block
from .stage2.layers import CrossAttentionLayer
from .. import utils
from ..utils.config import get_base_config
from ..utils.sampling import sampling, sampling_igpt, get_positional_encoding, sampling_prefix, sampling_conditional
from ..utils.utils import save_image
from .tokenizer import build_tokenizer
_MODELS = { | |
'minDALL-E/1.3B': 'https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz' | |
} | |
class Dalle(pl.LightningModule):
    """Two-stage text-to-image model.

    ``stage1`` (a ``VQGAN``) turns images into discrete codes and decodes codes
    back into pixels; ``stage2`` (a ``Transformer1d``) produces logits over the
    image codes and the text tokens.
    """

    def __init__(self,
                 config: OmegaConf) -> None:
        """Build both stages from ``config``.

        Args:
            config: Configuration providing the ``stage1``, ``stage2`` and
                ``dataset`` sections read below.
        """
        super().__init__()
        # Filled in by from_pretrained(); stays None after a bare constructor call.
        self.tokenizer = None

        self.stage1 = VQGAN(n_embed=config.stage1.n_embed,
                            embed_dim=config.stage1.embed_dim,
                            hparams=config.stage1.hparams)
        self.stage2 = Transformer1d(vocab_size_txt=config.stage2.vocab_size_txt,
                                    vocab_size_img=config.stage2.vocab_size_img,
                                    hparams=config.stage2.hparams)
        self.config = config
        self.config_stage1 = config.stage1
        self.config_stage2 = config.stage2
        self.config_dataset = config.dataset

        # NOTE: stage 1 parameters are not frozen here (requires_grad is left
        # untouched); forward() computes the stage-1 codes under torch.no_grad().

    @classmethod
    def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
        """Create a model and load its weights from disk.

        Args:
            args: Namespace-like object. Reads ``model_name_or_path`` (directory
                holding ``config.yaml``, ``tokenizer`` and the stage checkpoints),
                ``do_train`` and, if present, ``dalle_path``. When training, any
                attribute whose name matches a key of the ``optimizer`` or
                ``experiment`` config sections overrides that key.

        Returns:
            ``(model, config)`` where ``config`` is the merged configuration.
        """
        path = args.model_name_or_path
        config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))

        if args.do_train:
            config_base = get_base_config('finetuning')
            config_update = OmegaConf.merge(config_base, config_new)
            for key, val in vars(args).items():
                if key in config_update.optimizer.keys():
                    OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
                if key in config_update.experiment.keys():
                    OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)
        else:
            config_base = get_base_config('default')
            config_update = OmegaConf.merge(config_base, config_new)

        model = cls(config_update)
        model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
                                          context_length=model.config_dataset.context_length,
                                          lowercase=True,
                                          dropout=None)

        print("Loading models from checkpoint %s" % path)
        dalle_path = getattr(args, 'dalle_path', None)
        if dalle_path and dalle_path.endswith('.pth'):
            # map_location keeps the load independent of the device the
            # checkpoint was saved from; load_state_dict copies into the model.
            checkpoint = torch.load(dalle_path, map_location='cpu')
            model.load_state_dict(checkpoint["model_state_dict"])
        else:
            model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
            model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))

        return model, config_update

    @torch.no_grad()
    def sampling(self,
                 prompt: Union[str, torch.LongTensor],
                 top_k: int = 256,
                 top_p: Optional[float] = None,
                 softmax_temperature: float = 1.0,
                 num_candidates: int = 96,
                 device: str = 'cuda:0',
                 use_fp16: bool = True) -> torch.FloatTensor:
        """Sample ``num_candidates`` images for one prompt.

        Args:
            prompt: Raw text (encoded with ``self.tokenizer``) or an already
                tokenized tensor for a single prompt.
            top_k, top_p, softmax_temperature, use_fp16: Forwarded to the
                module-level ``sampling`` helper.
            num_candidates: Number of copies of the prompt that are sampled.
            device: Device the token tensor is moved to.

        Returns:
            Decoded pixels, rescaled by ``* 0.5 + 0.5`` and clamped to ``[0, 1]``.
        """
        self.stage1.eval()
        self.stage2.eval()

        if isinstance(prompt, str):
            tokens = self.tokenizer.encode(prompt)
            tokens = torch.LongTensor(tokens.ids)
        else:
            tokens = prompt
        # One row per candidate.
        tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
        tokens = tokens.to(device)

        # `sampling` here resolves to the module-level helper, not this method.
        codes = sampling(self.stage2,
                         tokens,
                         top_k=top_k,
                         top_p=top_p,
                         softmax_temperature=softmax_temperature,
                         use_fp16=use_fp16)
        codes = codes.view(num_candidates, 16, 16)  # [B, 16, 16]
        pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1)
        return pixels

    def forward(self,
                images: torch.FloatTensor,
                texts: Optional[torch.LongTensor],
                past=None
                ) -> tuple:
        """Encode ``images`` to codes and run stage 2 on codes and ``texts``.

        Args:
            images: 4-D image batch (the shape is unpacked into four values).
            texts: Text tokens; passed to stage 2 and to the positional encoder.
            past: Forwarded unchanged as the fifth positional argument of stage 2.

        Returns:
            ``(logits_img, logits_txt, codes)``.
        """
        B, C, H, W = images.shape  # fails early unless `images` is 4-D
        # Stage 1 is only used as a fixed encoder here: no gradients, no autocast.
        with torch.no_grad():
            with autocast(enabled=False):
                codes = self.stage1.get_codes(images).detach()

        pos_enc_tokens = get_positional_encoding(texts, mode='1d')
        codes = codes.clone().detach()
        pos_enc_code = get_positional_encoding(codes, mode='1d')

        logits_img, logits_txt = self.stage2(codes, texts, pos_enc_code, pos_enc_tokens, past)
        return logits_img, logits_txt, codes

    def training_step(self, batch, batch_idx):
        """Return the summed image and text cross-entropy for one batch."""
        images, texts = batch
        logits_img, logits_txt, codes = self(images, texts)

        loss_img = F.cross_entropy(logits_img.view(-1, logits_img.shape[-1]), codes.view(-1))
        # Text targets skip the first token of every sequence.
        loss_txt = F.cross_entropy(logits_txt.view(-1, logits_txt.shape[-1]), texts[:, 1:].reshape(-1))
        self.log("train/loss_img", loss_img, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log("train/loss_txt", loss_txt, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        return loss_img + loss_txt

    def validation_step(self, batch, batch_idx):
        """Same losses as training_step, logged per epoch under ``val/``."""
        images, texts = batch
        logits_img, logits_txt, codes = self(images, texts)

        loss_img = F.cross_entropy(logits_img.view(-1, logits_img.shape[-1]), codes.view(-1))
        loss_txt = F.cross_entropy(logits_txt.view(-1, logits_txt.shape[-1]), texts[:, 1:].reshape(-1))
        self.log("val/loss_img", loss_img, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log("val/loss_txt", loss_txt, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return loss_img + loss_txt

    def configure_optimizers(self):
        """Create AdamW with a learning rate that decays linearly to zero.

        The multiplier is ``(max_steps - step) / max_steps``, floored at 0.
        """
        assert self.config.optimizer.opt_type == 'adamW'

        opt = torch.optim.AdamW(self.parameters(),
                                lr=self.config.optimizer.learning_rate,
                                betas=self.config.optimizer.betas,
                                weight_decay=self.config.optimizer.weight_decay)

        def lr_lambda(current_step: int):
            max_steps = self.config.optimizer.max_steps
            # max(1, ...) guards the division when max_steps is 0.
            return max(0.0, float(max_steps - current_step) / float(max(1, max_steps)))

        sched = {
            'scheduler': LambdaLR(opt, lr_lambda),
            'name': 'linear'
        }
        return [opt], [sched]

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
                       on_tpu=False, using_native_amp=False, using_lbfgs=False):
        """Step the optimizer, then the scheduler, and log the current LR."""
        optimizer.step(closure=optimizer_closure)
        self.lr_schedulers().step()
        self.log("lr", self.lr_schedulers().get_last_lr()[0], on_step=True, on_epoch=False, prog_bar=True, logger=True)

    def on_epoch_start(self):
        """Put stage 1 back into eval mode at the start of every epoch."""
        self.stage1.eval()
class ImageGPT(pl.LightningModule):
    """Image-only generative model: a frozen ``VQGAN`` plus an ``iGPT`` over its codes.

    When ``config.stage2.use_cls_cond`` is set, class labels are passed to
    stage 2 during training and validation.
    """

    def __init__(self,
                 config: OmegaConf) -> None:
        """Build both stages from ``config`` and freeze stage 1."""
        super().__init__()
        self.stage1 = VQGAN(n_embed=config.stage1.n_embed,
                            embed_dim=config.stage1.embed_dim,
                            hparams=config.stage1.hparams)
        self.stage2 = iGPT(vocab_size_img=config.stage2.vocab_size_img,
                           use_cls_cond=config.stage2.use_cls_cond,
                           hparams=config.stage2.hparams)
        self.config = config
        self.use_cls_cond = config.stage2.use_cls_cond

        # make the parameters in stage 1 not trainable
        self.stage1.eval()
        for p in self.stage1.parameters():
            p.requires_grad = False

    @classmethod
    def from_pretrained(cls,
                        path_upstream: str,
                        path_downstream: str) -> Tuple[nn.Module, OmegaConf]:
        """Build a model from a downstream config and upstream checkpoints.

        Args:
            path_upstream: Directory containing ``stage1_last.ckpt`` and
                ``stage2_last.ckpt``.
            path_downstream: Path of the config file merged over the base config.

        Returns:
            ``(model, config)`` where ``config`` is the merged configuration.
        """
        # NOTE(review): other classes in this file call get_base_config with a
        # mode string ('default', 'finetuning', ...); confirm that the
        # `use_default` keyword is still accepted by the current helper.
        config_base = get_base_config(use_default=False)
        config_down = OmegaConf.load(path_downstream)
        config_down = OmegaConf.merge(config_base, config_down)

        model = cls(config_down)
        model.stage1.from_ckpt(os.path.join(path_upstream, 'stage1_last.ckpt'), strict=True)
        # Stage 2 is loaded non-strictly.
        model.stage2.from_ckpt(os.path.join(path_upstream, 'stage2_last.ckpt'), strict=False)
        return model, config_down

    @torch.no_grad()
    def sample(self,
               cls_idx: Optional[int] = None,
               top_k: int = 256,
               top_p: Optional[float] = None,
               softmax_temperature: float = 1.0,
               num_candidates: int = 16,
               device: str = 'cuda:0',
               use_fp16: bool = True,
               is_tqdm: bool = True) -> torch.FloatTensor:
        """Sample ``num_candidates`` images, optionally for class ``cls_idx``.

        Args:
            cls_idx: Class index looked up through ``stage2.sos``; when None,
                ``stage2.sos`` itself is repeated for every candidate.
            top_k, top_p, softmax_temperature, use_fp16, is_tqdm: Forwarded to
                ``sampling_igpt``.
            num_candidates: Number of images to sample.
            device: Device on which the class-index tensor is created.

        Returns:
            Decoded pixels, rescaled by ``* 0.5 + 0.5`` and clamped to ``[0, 1]``.
        """
        self.stage1.eval()
        self.stage2.eval()

        if cls_idx is None:
            # `sos` is used as a tensor in this branch ...
            sos = self.stage2.sos.repeat(num_candidates, 1, 1)
        else:
            # ... and as a callable lookup in this one.
            sos = torch.LongTensor([cls_idx]).to(device=device)
            sos = sos.repeat(num_candidates)
            sos = self.stage2.sos(sos).unsqueeze(1)

        codes = sampling_igpt(self.stage2,
                              sos=sos,
                              top_k=top_k,
                              top_p=top_p,
                              softmax_temperature=softmax_temperature,
                              use_fp16=use_fp16,
                              is_tqdm=is_tqdm)
        codes = codes.view(num_candidates, 16, 16)  # [B, 16, 16]
        pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1)
        return pixels

    def forward(self,
                images: torch.FloatTensor,
                labels: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
        """Encode ``images`` to codes and return ``(logits, codes)`` from stage 2."""
        B, C, H, W = images.shape  # fails early unless `images` is 4-D
        with torch.no_grad():
            with autocast(enabled=False):
                codes = self.stage1.get_codes(images).detach()
        logits = self.stage2(codes, labels)
        return logits, codes

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy between stage-2 logits and the image codes."""
        images, labels = batch
        logits, codes = self(images, labels=labels if self.use_cls_cond else None)
        loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), codes.view(-1))
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Same loss as training_step, logged per epoch as ``val/loss``."""
        images, labels = batch
        logits, codes = self(images, labels=labels if self.use_cls_cond else None)
        loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), codes.view(-1))
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return loss

    def configure_optimizers(self):
        """Create AdamW with cosine annealing down to ``optimizer.min_lr``."""
        assert self.config.optimizer.opt_type == 'adamW'
        assert self.config.optimizer.sched_type == 'cosine'

        opt = torch.optim.AdamW(self.parameters(),
                                lr=self.config.optimizer.base_lr,
                                betas=self.config.optimizer.betas,
                                weight_decay=self.config.optimizer.weight_decay)
        sched = {
            'scheduler': CosineAnnealingLR(opt,
                                           T_max=self.config.optimizer.max_steps,
                                           eta_min=self.config.optimizer.min_lr),
            'name': 'cosine'
        }
        return [opt], [sched]

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
                       on_tpu=False, using_native_amp=False, using_lbfgs=False):
        """Step the optimizer, then the scheduler, and log the current LR."""
        optimizer.step(closure=optimizer_closure)
        self.lr_schedulers().step()
        self.log("lr", self.lr_schedulers().get_last_lr()[0], on_step=True, on_epoch=False, prog_bar=True, logger=True)

    def on_epoch_start(self):
        """Put stage 1 back into eval mode at the start of every epoch."""
        self.stage1.eval()
class PromptDalle(Dalle):
    """Dalle variant with a learned continuous prompt.

    ``preseqlen`` prompt positions are embedded (``wte``), passed through a
    two-layer MLP (``control_trans``) and handed to stage 2 through its
    ``prompt`` / ``pos_prompt`` keyword arguments.
    """

    def __init__(self, config):
        """Build the base model and the prompt parameters from ``config.prompt``."""
        super().__init__(config)
        print('Initializing the PromptTuning model')
        self.config = config
        self.n_embd = config.stage2.hparams.embed_dim
        self.preseqlen = config.prompt.preseqlen
        self.prefix_dropout = config.prompt.prefix_dropout

        # DIFFERENT PARAMETRIZATION:
        print('[Full prompt-tuning Setting :) ]')
        # Plain tensor (not a buffer): it is moved to self.device at every use.
        self.input_tokens = torch.arange(self.preseqlen).long()
        self.wte = nn.Embedding(self.preseqlen, self.n_embd)
        self.control_trans = nn.Sequential(
            nn.Linear(self.n_embd, self.n_embd),
            nn.Tanh(),
            nn.Linear(self.n_embd, self.n_embd))
        self.get_prompt = self.get_prompt_p5
        self.dropout = nn.Dropout(self.prefix_dropout)

        ###### NUM PARAMS #########
        total_param = sum(param.numel() for param in self.parameters())
        print('Total parameters is {}'.format(total_param))

    @classmethod
    def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
        """Create a prompt-tuning model and load its weights.

        Args:
            args: Namespace-like object. ``prefix_model_name_or_path`` names the
                base model (a ``_MODELS`` key, URL or path); ``model_name_or_path``,
                when truthy, is a full checkpoint to load instead of the
                per-stage checkpoints. Attributes matching keys of the
                ``prompt``, ``optimizer`` or ``experiment`` config sections
                override those keys.

        Returns:
            ``(model, config)`` where ``config`` is the merged configuration.
        """
        path = args.prefix_model_name_or_path
        path = _MODELS[path] if path in _MODELS else path
        path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))

        config_base = get_base_config('prompt_tuning')
        config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
        config_update = OmegaConf.merge(config_base, config_new)

        for key, val in vars(args).items():
            if key in config_update.prompt.keys():
                OmegaConf.update(config_update, "prompt.%s" % key, val, merge=False)
            if key in config_update.optimizer.keys():
                OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
            if key in config_update.experiment.keys():
                OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)

        model = cls(config_update)
        model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
                                          context_length=model.config_dataset.context_length,
                                          lowercase=True,
                                          dropout=None)

        if args.model_name_or_path:
            print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
            # Read the file once; the weights live under either of two keys.
            checkpoint = torch.load(args.model_name_or_path, map_location='cpu')
            if 'state_dict' in checkpoint:
                model.load_state_dict(checkpoint['state_dict'])
            else:
                model.load_state_dict(checkpoint['model_state_dict'])
        else:
            print("Loading models from checkpoint %s" % path)
            model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
            model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))

        return model, config_update

    def get_prompt_p5(self, bsz=None, eval=False):
        """Return the prompt embeddings for a batch of size ``bsz``.

        Args:
            bsz: Batch size the prompt indices are expanded to.
            eval: When False, ``self.dropout`` is applied to the result.
        """
        input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
        temp_control = self.wte(input_tokens)
        past_key_values = self.control_trans(temp_control)
        if not eval:
            past_key_values = self.dropout(past_key_values)
        return past_key_values

    def forward(self,
                images: torch.FloatTensor,
                texts: Optional[torch.LongTensor],
                **kwargs,
                ):
        """Run stage 2 on the image codes and texts with the learned prompt.

        Extra keyword arguments are accepted and ignored.

        Returns:
            ``(logits_img, logits_txt, codes)``.
        """
        B, C, H, W = images.shape
        prompt = self.get_prompt(bsz=B)
        pos_enc_prompt = get_positional_encoding(self.input_tokens.unsqueeze(0).expand(B, -1).to(self.device), mode='1d')

        # Stage 1 is only used as a fixed encoder: no gradients, no autocast.
        with torch.no_grad():
            with autocast(enabled=False):
                codes = self.stage1.get_codes(images).detach()

        pos_enc_tokens = get_positional_encoding(texts, mode='1d')
        codes = codes.clone().detach()
        pos_enc_code = get_positional_encoding(codes, mode='1d')

        logits_img, logits_txt = self.stage2(codes, texts, pos_enc_code, pos_enc_tokens, prompt=prompt, pos_prompt=pos_enc_prompt)
        return logits_img, logits_txt, codes

    @torch.no_grad()
    def sampling(self,
                 tokens: torch.LongTensor,
                 prompt: torch.FloatTensor,
                 top_k: int = 256,
                 top_p: Optional[float] = None,
                 softmax_temperature: float = 1.0,
                 num_candidates: int = 96,
                 device: str = 'cuda:0',
                 use_fp16: bool = True,
                 labels = None) -> torch.FloatTensor:
        """Sample images for ``tokens`` using the given prompt embeddings.

        Args:
            tokens: Text tokens; moved to ``device`` and used as given (they
                are not repeated per candidate here).
            prompt: Prompt embeddings, e.g. from ``get_prompt``.
            top_k, top_p, softmax_temperature, use_fp16: Forwarded to the
                module-level ``sampling`` helper.
            num_candidates: Batch size used for the prompt positional encoding.
            device: Device the tokens are moved to.
            labels: Accepted for signature compatibility; unused.

        Returns:
            Decoded pixels, rescaled by ``* 0.5 + 0.5`` and clamped to ``[0, 1]``.
        """
        self.stage1.eval()
        self.stage2.eval()

        tokens = tokens.to(device)
        pos_enc_prompt = get_positional_encoding(self.input_tokens.unsqueeze(0).expand(num_candidates, -1).to(self.device), mode='1d')

        # `sampling` here resolves to the module-level helper, not this method.
        codes = sampling(self.stage2,
                         tokens,
                         top_k=top_k,
                         top_p=top_p,
                         softmax_temperature=softmax_temperature,
                         use_fp16=use_fp16,
                         prompt=prompt,
                         pos_prompt=pos_enc_prompt)
        codes = codes.view(-1, 16, 16)  # [B, 16, 16]
        pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1)
        return pixels

    def predict_step(self, batch, batch_idx, return_images=False):
        """Generate 5 candidates per text and either return or save them.

        Args:
            batch: ``(orig_images, texts)``.
            batch_idx: Used to name the saved outputs.
            return_images: When True, return the list of sampled arrays instead
                of writing images to ``./out/images/pororo_prompt``.
        """
        orig_images, texts = batch

        # extra for checks: decode the argmax prediction of the forward pass
        logits_img, logits_txt, codes = self(orig_images, texts)
        pred = torch.argmax(logits_img.view(-1, logits_img.shape[-1]), dim=-1)
        batch_size = orig_images.shape[0]
        pred = pred.view(batch_size, 16, 16)  # [B, 16, 16]
        recon_pixels = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy()
        recon_pixels = np.transpose(recon_pixels, (0, 2, 3, 1))

        prompt = self.get_prompt(bsz=5, eval=True)
        images = []
        for i, t in enumerate(texts):
            # Kept in its own variable so the reconstruction above is not
            # overwritten before it is saved.
            sampled = self.sampling(t, prompt, top_k=16, num_candidates=5, labels=codes[i]).cpu().numpy()
            images.append(np.transpose(sampled, (0, 2, 3, 1)))

        if return_images:
            return images
        else:
            save_image(orig_images, recon_pixels, './out/images/pororo_prompt', batch_idx+10)
            save_image(orig_images, images, './out/images/pororo_prompt', batch_idx)
class PrefixTuningDalle(Dalle):
    """Dalle variant with a learned prefix handed to stage 2.

    ``preseqlen`` prefix positions are embedded (``wte``) and projected by
    ``control_trans`` to ``n_layers * 2 * embed_dim`` values per position,
    which ``get_prompt_p5`` reshapes and splits into chunks of two along the
    leading dimension before they are passed to stage 2.
    """

    def __init__(self, config):
        """Build the base model and the prefix parameters from ``config.prefix``."""
        super().__init__(config)
        print('Initializing the PrefixTuning model')
        self.config = config

        hparams = config.stage2.hparams
        self.match_n_layer = hparams.n_layers
        self.match_n_head = hparams.n_heads
        self.match_n_embd = hparams.embed_dim // hparams.n_heads  # per-head width
        self.n_embd = hparams.embed_dim

        prefix_cfg = config.prefix
        self.optim_prefix = prefix_cfg.optim_prefix
        self.preseqlen = prefix_cfg.preseqlen
        self.prefix_dropout = prefix_cfg.prefix_dropout
        self.init_random = prefix_cfg.init_random
        self.hidden_dim_prefix = prefix_cfg.hidden_dim_prefix
        self.lowdata_token = prefix_cfg.lowdata_token
        self.init_shallow = prefix_cfg.init_shallow
        self.init_shallow_word = prefix_cfg.init_shallow_word
        self.mode_para = 0

        print('PrefixTuning')
        print('preseqlen is {}, optimizing the prefix directly'.format(self.preseqlen))

        # DIFFERENT PARAMETRIZATION:
        print('[Full prefix-tuning Setting :) ]')
        # Plain tensor (not a buffer): it is moved to self.device at every use.
        self.input_tokens = torch.arange(self.preseqlen).long()
        self.wte = nn.Embedding(self.preseqlen, self.n_embd)
        self.control_trans = nn.Sequential(
            nn.Linear(self.n_embd, self.hidden_dim_prefix),
            nn.Tanh(),
            nn.Linear(self.hidden_dim_prefix, self.match_n_layer * 2 * self.n_embd))
        self.get_prompt = self.get_prompt_p5
        self.dropout = nn.Dropout(self.prefix_dropout)

        ###### NUM PARAMS #########
        total_param = sum(param.numel() for param in self.parameters())
        print('Total parameters is {}'.format(total_param))

    @classmethod
    def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
        """Create a prefix-tuning model and load its weights.

        Args:
            args: Namespace-like object. ``prefix_model_name_or_path`` names the
                base model (a ``_MODELS`` key, URL or path); ``model_name_or_path``,
                when truthy, is a full checkpoint to load instead of the
                per-stage checkpoints. Attributes matching keys of the
                ``prefix``, ``optimizer`` or ``experiment`` config sections
                override those keys.

        Returns:
            ``(model, config)`` where ``config`` is the merged configuration.
        """
        path = args.prefix_model_name_or_path
        path = _MODELS[path] if path in _MODELS else path
        path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))

        config_base = get_base_config('prefixtuning')
        config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
        config_update = OmegaConf.merge(config_base, config_new)

        for key, val in vars(args).items():
            if key in config_update.prefix.keys():
                OmegaConf.update(config_update, "prefix.%s" % key, val, merge=False)
            if key in config_update.optimizer.keys():
                OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
            if key in config_update.experiment.keys():
                OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)

        model = cls(config_update)
        model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
                                          context_length=model.config_dataset.context_length,
                                          lowercase=True,
                                          dropout=None)

        if args.model_name_or_path:
            print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
            # Read the file once; the weights live under either of two keys.
            checkpoint = torch.load(args.model_name_or_path, map_location='cpu')
            if 'state_dict' in checkpoint:
                model.load_state_dict(checkpoint['state_dict'])
            else:
                model.load_state_dict(checkpoint['model_state_dict'])
        else:
            print("Loading models from checkpoint %s" % path)
            model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
            model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))

        return model, config_update

    def get_prompt_p5(self, bsz=None, eval=False):
        """Return the prefix tensors for a batch of size ``bsz``.

        Args:
            bsz: Batch size the prefix indices are expanded to.
            eval: When False, ``self.dropout`` is applied before the permute.

        Returns:
            Tuple of tensors obtained by splitting a
            ``(n_layers * 2, bsz, n_heads, preseqlen, head_dim)`` tensor into
            chunks of size 2 along dimension 0.
        """
        input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
        temp_control = self.wte(input_tokens)
        past_key_values = self.control_trans(temp_control)  # bsz, seqlen, layer*2*emb
        bsz, seqlen, _ = past_key_values.shape
        past_key_values = past_key_values.view(bsz, seqlen, self.match_n_layer * 2, self.match_n_head,
                                               self.match_n_embd)
        if not eval:
            past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4])
        return past_key_values.split(2)

    def forward(self,
                images: torch.FloatTensor,
                texts: Optional[torch.LongTensor],
                **kwargs,
                ):
        """Run stage 2 on the image codes and texts with the learned prefix.

        Extra keyword arguments are accepted and ignored.

        Returns:
            ``(logits_img, logits_txt, codes)``.
        """
        B, C, H, W = images.shape
        # The prefix is built the same way regardless of self.mode_para.
        past_key_values_prompt = self.get_prompt(bsz=B)

        # Stage 1 is only used as a fixed encoder: no gradients, no autocast.
        with torch.no_grad():
            with autocast(enabled=False):
                codes = self.stage1.get_codes(images).detach()

        pos_enc_tokens = get_positional_encoding(texts, mode='1d')
        codes = codes.clone().detach()
        pos_enc_code = get_positional_encoding(codes, mode='1d')

        logits_img, logits_txt = self.stage2(codes, texts, pos_enc_code, pos_enc_tokens, past_key_values_prompt)
        return logits_img, logits_txt, codes

    @torch.no_grad()
    def sampling(self,
                 tokens: torch.LongTensor,
                 past: torch.FloatTensor,
                 top_k: int = 256,
                 top_p: Optional[float] = None,
                 softmax_temperature: float = 1.0,
                 num_candidates: int = 96,
                 device: str = 'cuda:0',
                 use_fp16: bool = True,
                 labels = None) -> torch.FloatTensor:
        """Sample ``num_candidates`` images for one token sequence with a prefix.

        Args:
            tokens: Text tokens of a single sample; repeated per candidate.
            past: Stacked prefix tensor. A 6-D input is flattened to 5-D by
                merging its third and fourth dimensions.
            top_k, top_p, softmax_temperature, use_fp16: Forwarded to
                ``sampling_prefix``.
            num_candidates: Number of images to sample.
            device: Device the tokens are moved to.
            labels: Optional tensor, flattened and forwarded to ``sampling_prefix``.

        Returns:
            Decoded pixels, rescaled by ``* 0.5 + 0.5`` and clamped to ``[0, 1]``.
        """
        self.stage1.eval()
        self.stage2.eval()

        if len(past.shape) == 6:
            n_layers, temp, bs, n_heads, seq_len, n_dim = past.shape
            past = past.view(n_layers, temp, bs*n_heads, seq_len, n_dim)

        tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
        tokens = tokens.to(device)

        codes = sampling_prefix(self.stage2,
                                tokens,
                                past,
                                top_k=top_k,
                                top_p=top_p,
                                softmax_temperature=softmax_temperature,
                                use_fp16=use_fp16,
                                labels = None if labels is None else labels.view(-1))
        codes = codes.view(num_candidates, 16, 16)  # [B, 16, 16]
        pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1)
        return pixels

    def training_step(self, batch, batch_idx):
        """Return the summed image and text cross-entropy for one batch."""
        images, texts = batch
        logits_img, logits_txt, codes = self(images, texts)

        loss_img = F.cross_entropy(logits_img.view(-1, logits_img.shape[-1]), codes.view(-1))
        # Text targets skip the first token of every sequence.
        loss_txt = F.cross_entropy(logits_txt.view(-1, logits_txt.shape[-1]), texts[:, 1:].reshape(-1))
        self.log("train/loss_img", loss_img, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        self.log("train/loss_txt", loss_txt, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        return loss_img + loss_txt

    def validation_step(self, batch, batch_idx):
        """Same losses as training_step, logged per epoch under ``val/``."""
        images, texts = batch
        logits_img, logits_txt, codes = self(images, texts)

        loss_img = F.cross_entropy(logits_img.view(-1, logits_img.shape[-1]), codes.view(-1))
        loss_txt = F.cross_entropy(logits_txt.view(-1, logits_txt.shape[-1]), texts[:, 1:].reshape(-1))
        self.log("val/loss_img", loss_img, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log("val/loss_txt", loss_txt, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return loss_img + loss_txt

    def predict_step(self, batch, batch_idx, return_images=False):
        """Generate 5 candidates per text and either return or save them.

        Args:
            batch: ``(orig_images, texts)``.
            batch_idx: Used to name the saved outputs.
            return_images: When True, return the list of sampled arrays instead
                of writing images to ``./out/images/pororo_prefix``.
        """
        orig_images, texts = batch

        # extra for checks: decode the argmax prediction of the forward pass
        logits_img, logits_txt, codes = self(orig_images, texts)
        pred = torch.argmax(logits_img.view(-1, logits_img.shape[-1]), dim=-1)
        batch_size = orig_images.shape[0]
        pred = pred.view(batch_size, 16, 16)  # [B, 16, 16]
        recon_pixels = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy()
        recon_pixels = np.transpose(recon_pixels, (0, 2, 3, 1))

        # Stack the per-chunk prefix tensors and merge the batch and head
        # dimensions for downstream processing.
        past_key_values_prompt = self.get_prompt(bsz=5, eval=True)
        past_key_values_prompt = torch.cat([x.unsqueeze(0) for x in past_key_values_prompt], dim=0)
        n_layers, temp, prefix_bs, n_heads, seq_len, n_dim = past_key_values_prompt.shape
        past_key_values_prompt = past_key_values_prompt.view(n_layers, temp, prefix_bs*n_heads, seq_len, n_dim)

        images = []
        for i, t in enumerate(texts):
            # Kept in its own variable so the reconstruction above is not
            # overwritten before it is saved.
            sampled = self.sampling(t, past_key_values_prompt, top_k=16, num_candidates=5, labels=codes[i]).cpu().numpy()
            images.append(np.transpose(sampled, (0, 2, 3, 1)))

        if return_images:
            return images
        else:
            save_image(orig_images, recon_pixels, './out/images/pororo_prefix', batch_idx+10)
            save_image(orig_images, images, './out/images/pororo_prefix', batch_idx)
class ConditionalDalle(Dalle): | |
"""Classification Head for transformer encoders""" | |
def __init__(self, config): | |
super().__init__(config) | |
print('Initializing the Conditional Dalle model') | |
self.config = config | |
print('Setting up Cross-attention Layers') | |
self.init_cross_attention(list(range(2,42,3)), config.stage2.hparams) | |
###### NUM PARAMS ######### | |
total_param = 0 | |
for name, param in self.named_parameters(): | |
# print(param.shape) | |
total_param += param.numel() | |
print('Total parameters is {}'.format(total_param)) | |
def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]: | |
# if not args.model_name_or_path: | |
# args.model_name_or_path = args.prefix_model_name_or_path | |
path = args.model_name_or_path | |
config_new = OmegaConf.load(os.path.join(path, 'config.yaml')) | |
if args.do_train: | |
config_base = get_base_config('finetuning') | |
config_update = OmegaConf.merge(config_base, config_new) | |
for key, val in vars(args).items(): | |
if key in config_update.optimizer.keys(): | |
OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False) | |
if key in config_update.experiment.keys(): | |
OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False) | |
else: | |
config_base = get_base_config('default') | |
config_update = OmegaConf.merge(config_base, config_new) | |
model = cls(config_update) | |
model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'), | |
context_length=model.config_dataset.context_length, | |
lowercase=True, | |
dropout=None) | |
print(model.cross_attention_idxs) | |
# print(next(model.cross_attention_layers[0].parameters()).is_cuda) | |
if args.dalle_path: | |
print("Loading model from pretrained checkpoint %s" % args.dalle_path) | |
# model.from_ckpt(args.model_name_or_path) | |
model.load_state_dict(torch.load(args.dalle_path)['model_state_dict']) | |
else: | |
print("Loading models from checkpoint %s" % path) | |
model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt')) | |
model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt')) | |
return model, config_update | |
def init_cross_attention(self, cross_attention_layers, hparams): | |
self.cross_attention_idxs = cross_attention_layers | |
self.cross_attention_layers = [CrossAttentionLayer(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt, | |
embed_dim=hparams.embed_dim, | |
n_heads=hparams.n_heads, | |
attn_bias=hparams.attn_bias, | |
resid_pdrop=hparams.resid_pdrop, | |
attn_pdrop=hparams.attn_pdrop) for i in cross_attention_layers] | |
def forward(self, | |
images: torch.FloatTensor, | |
src_images: Optional[torch.FloatTensor], | |
texts: Optional[torch.LongTensor], | |
**kwargs, | |
): | |
#{"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 'src':src} | |
# print(images.shape, src_images.shape, texts.shape) | |
with torch.no_grad(): | |
with autocast(enabled=False): | |
codes = self.stage1.get_codes(images).detach() | |
src_codes = self.stage1.get_codes(src_images).detach() | |
pos_enc_tokens = get_positional_encoding(texts, mode='1d') | |
codes = codes.clone().detach() | |
pos_enc_code = get_positional_encoding(codes, mode='1d') | |
src_codes = src_codes.clone().detach() | |
src_pos_enc_code = get_positional_encoding(src_codes, mode='1d') | |
# codes = codes.unsqueeze(-1) | |
# pos_enc_code = pos_enc_code.unsqueeze(-1) | |
# print(images.shape, codes.shape, texts.shape) | |
logits_img, logits_txt = self.stage2.forward_with_context(codes, texts, | |
pos_enc_code, pos_enc_tokens, src_codes, src_pos_enc_code, | |
self.cross_attention_idxs, self.cross_attention_layers) | |
# print(logits_img.shape, logits_txt.shape, codes.shape, texts.shape) | |
return logits_img, logits_txt, codes | |
def sampling(self, | |
prompt: torch.LongTensor, | |
source: torch.FloatTensor, | |
top_k: int = 256, | |
top_p: Optional[float] = None, | |
softmax_temperature: float = 1.0, | |
num_candidates: int = 96, | |
device: str = 'cuda:0', | |
use_fp16: bool = True) -> torch.FloatTensor: | |
self.stage1.eval() | |
self.stage2.eval() | |
if type(prompt) == str: | |
tokens = self.tokenizer.encode(prompt) | |
tokens = torch.LongTensor(tokens.ids) | |
else: | |
tokens = prompt | |
tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0) | |
# Check if the encoding works as intended | |
# print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0]) | |
tokens = tokens.to(device) | |
source = source.to(device) | |
with autocast(enabled=False): | |
src_codes = self.stage1.get_codes(source).detach() | |
src_codes = torch.repeat_interleave(src_codes, num_candidates, dim=0) | |
codes = sampling_conditional(self.stage2, | |
self.cross_attention_idxs, | |
self.cross_attention_layers, | |
tokens, | |
src_codes, | |
top_k=top_k, | |
top_p=top_p, | |
softmax_temperature=softmax_temperature, | |
use_fp16=use_fp16) | |
codes = codes.view(num_candidates, 16, 16) # [B, 16, 16] | |
pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256] | |
return pixels | |
def training_step(self, batch, batch_idx):
    """Run one training batch and return the summed image + text loss.

    ``batch`` is unpacked as ``(images, texts)`` and fed through ``self``.
    Both cross-entropy terms are logged per step and per epoch under
    ``train/loss_img`` and ``train/loss_txt``.
    """
    images, texts = batch
    logits_img, logits_txt, codes = self(images, texts)

    # Flatten everything to (num_positions, vocab) vs. (num_positions,).
    flat_logits_img = logits_img.view(-1, logits_img.shape[-1])
    flat_logits_txt = logits_txt.view(-1, logits_txt.shape[-1])
    img_targets = codes.view(-1)
    # Text targets skip the first token of every sequence.
    txt_targets = texts[:, 1:].reshape(-1)

    loss_img = F.cross_entropy(flat_logits_img, img_targets)
    loss_txt = F.cross_entropy(flat_logits_txt, txt_targets)

    log_options = dict(on_step=True, on_epoch=True, prog_bar=False, logger=True)
    self.log("train/loss_img", loss_img, **log_options)
    self.log("train/loss_txt", loss_txt, **log_options)
    return loss_img + loss_txt
def validation_step(self, batch, batch_idx):
    """Run one validation batch and return the summed image + text loss.

    ``batch`` is unpacked as ``(images, texts)`` and fed through ``self``.
    Both cross-entropy terms are logged once per epoch under
    ``val/loss_img`` and ``val/loss_txt``.
    """
    images, texts = batch
    logits_img, logits_txt, codes = self(images, texts)

    # Flatten everything to (num_positions, vocab) vs. (num_positions,).
    flat_logits_img = logits_img.view(-1, logits_img.shape[-1])
    flat_logits_txt = logits_txt.view(-1, logits_txt.shape[-1])
    img_targets = codes.view(-1)
    # Text targets skip the first token of every sequence.
    txt_targets = texts[:, 1:].reshape(-1)

    loss_img = F.cross_entropy(flat_logits_img, img_targets)
    loss_txt = F.cross_entropy(flat_logits_txt, txt_targets)

    log_options = dict(on_step=False, on_epoch=True, prog_bar=False, logger=True)
    self.log("val/loss_img", loss_img, **log_options)
    self.log("val/loss_txt", loss_txt, **log_options)
    return loss_img + loss_txt
def predict_step(self, batch, batch_idx):
    """Sample 5 candidate images for every text in the batch and save them.

    ``batch`` is unpacked as ``(orig_images, texts)``. The sampled images are
    converted to NumPy, moved to channel-last layout and written together
    with the originals via ``save_image`` under ``./out/images/``.
    """
    orig_images, texts = batch

    # Stack the pieces returned by get_prompt along a new leading axis
    # (presumably one piece per head -- TODO confirm against get_prompt).
    prompt_pieces = self.get_prompt(bsz=5)
    past_key_values_prompt = torch.stack(list(prompt_pieces), dim=0)

    images = [
        self.sampling(text, past_key_values_prompt, top_k=64, num_candidates=5)
            .cpu()
            .numpy()
            .transpose(0, 2, 3, 1)  # axis 1 is moved to the end
        for text in texts
    ]
    save_image(orig_images, images, './out/images/', batch_idx)
class PromptConditionalDalle(Dalle):
    """Dalle variant with cross-attention on a source image plus a learned prompt.

    On top of the base ``Dalle`` model this class creates

    * a list of ``CrossAttentionLayer`` objects (see ``init_cross_attention``)
      that is handed to ``stage2.forward_with_context`` together with the
      codes of a source image, and
    * a learned prompt made of an embedding table (``wte``) followed by a
      small MLP (``control_trans``), produced by ``get_prompt``.
    """

    def __init__(self, config):
        super().__init__(config)
        print('Initializing the Conditional Dalle model')

        self.config = config

        print('Setting up Cross-attention Layers')
        self.init_cross_attention(list(range(2, 42, 3)), config.stage2.hparams)

        self.n_embd = config.stage2.hparams.embed_dim
        self.preseqlen = config.story.preseqlen
        self.prefix_dropout = config.story.prefix_dropout

        # DIFFERENT PARAMETRIZATION:
        print('[Full prompt-tuning Setting :) ]')
        # Fixed index sequence 0..preseqlen-1 that is looked up in `wte`.
        self.input_tokens = torch.arange(self.preseqlen).long()
        self.wte = nn.Embedding(self.preseqlen, self.n_embd)
        self.control_trans = nn.Sequential(
            nn.Linear(self.n_embd, self.n_embd),
            nn.Tanh(),
            nn.Linear(self.n_embd, self.n_embd))
        self.get_prompt = self.get_prompt_p5
        self.dropout = nn.Dropout(self.prefix_dropout)

        ###### NUM PARAMS #########
        total_param = sum(param.numel() for param in self.parameters())
        print('Total parameters is {}'.format(total_param))

    @classmethod
    def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
        """Build a model from ``args`` and load its weights.

        ``args.prefix_model_name_or_path`` names either an entry of
        ``_MODELS`` or a path/URL; it provides ``config.yaml``, the tokenizer
        and the stage-1/stage-2 checkpoints. When ``args.model_name_or_path``
        is set, the full model state is loaded from that checkpoint instead
        of the separate stage checkpoints.

        Returns:
            ``(model, config)`` where ``config`` is the merged configuration
            the model was built with.

        Raises:
            KeyError: if the checkpoint at ``args.model_name_or_path`` has
                neither a ``state_dict`` nor a ``model_state_dict`` entry.
        """
        path = args.prefix_model_name_or_path
        path = _MODELS[path] if path in _MODELS else path
        path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))

        config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
        if args.do_train:
            config_base = get_base_config('story')
            config_update = OmegaConf.merge(config_base, config_new)
            # Command-line arguments override matching keys of these sections.
            for key, val in vars(args).items():
                if key in config_update.story.keys():
                    OmegaConf.update(config_update, "story.%s" % key, val, merge=False)
                if key in config_update.optimizer.keys():
                    OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
                if key in config_update.experiment.keys():
                    OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)
        else:
            config_base = get_base_config('default')
            config_update = OmegaConf.merge(config_base, config_new)

        model = cls(config_update)
        model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
                                          context_length=model.config_dataset.context_length,
                                          lowercase=True,
                                          dropout=None)

        print(model.cross_attention_idxs)

        if args.model_name_or_path:
            print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
            # Read the checkpoint once and pick whichever key it uses.
            checkpoint = torch.load(args.model_name_or_path)
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
            else:
                state_dict = checkpoint['model_state_dict']
            model.load_state_dict(state_dict)
        else:
            print("Loading models from checkpoint %s" % path)
            model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
            model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))

        return model, config_update

    def init_cross_attention(self, cross_attention_layers, hparams):
        """Create one ``CrossAttentionLayer`` per entry of ``cross_attention_layers``.

        The given indices are stored in ``self.cross_attention_idxs`` and the
        new layers in ``self.cross_attention_layers``.
        """
        self.cross_attention_idxs = cross_attention_layers
        # NOTE(review): this is a plain Python list, so nn.Module does not
        # register these layers as sub-modules. Confirm that their parameters
        # are moved to the device, optimized and checkpointed elsewhere
        # before switching to nn.ModuleList (that would change state_dict keys).
        self.cross_attention_layers = [CrossAttentionLayer(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt,
                                                           embed_dim=hparams.embed_dim,
                                                           n_heads=hparams.n_heads,
                                                           attn_bias=hparams.attn_bias,
                                                           resid_pdrop=hparams.resid_pdrop,
                                                           attn_pdrop=hparams.attn_pdrop)
                                       for _ in cross_attention_layers]

    def get_prompt_p5(self, bsz=None, eval=False):
        """Return the learned prompt for a batch of size ``bsz``.

        The fixed index sequence is embedded with ``wte`` and passed through
        ``control_trans``. ``self.dropout`` is applied unless ``eval`` is true.
        """
        input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
        temp_control = self.wte(input_tokens)
        past_key_values = self.control_trans(temp_control)  # bsz, seqlen, layer*emb
        if not eval:
            past_key_values = self.dropout(past_key_values)
        return past_key_values

    def forward(self,
                images: torch.FloatTensor,
                src_images: Optional[torch.FloatTensor],
                texts: Optional[torch.LongTensor],
                **kwargs,
                ):
        """Compute image and text logits for a batch.

        Args:
            images: Target images; unpacked as ``(B, C, H, W)``.
            src_images: Source images the model is conditioned on.
            texts: Token ids of the captions.

        Returns:
            ``(logits_img, logits_txt, codes)`` where ``codes`` are the
            stage-1 codes of ``images``.
        """
        # Stage 1 only provides targets/conditioning: no gradients, no autocast.
        with torch.no_grad():
            with autocast(enabled=False):
                codes = self.stage1.get_codes(images).detach()
                src_codes = self.stage1.get_codes(src_images).detach()

        B, C, H, W = images.shape
        prompt = self.get_prompt(bsz=B)
        pos_enc_prompt = get_positional_encoding(
            self.input_tokens.unsqueeze(0).expand(B, -1).to(self.device), mode='1d')

        pos_enc_tokens = get_positional_encoding(texts, mode='1d')
        codes = codes.clone().detach()
        pos_enc_code = get_positional_encoding(codes, mode='1d')
        src_codes = src_codes.clone().detach()
        src_pos_enc_code = get_positional_encoding(src_codes, mode='1d')

        logits_img, logits_txt = self.stage2.forward_with_context(codes, texts,
                                                                  pos_enc_code, pos_enc_tokens, src_codes, src_pos_enc_code,
                                                                  self.cross_attention_idxs, self.cross_attention_layers,
                                                                  prompt=prompt, pos_prompt=pos_enc_prompt)
        return logits_img, logits_txt, codes

    def sampling(self,
                 tokens: torch.LongTensor,
                 prompt: torch.LongTensor,
                 source: torch.FloatTensor,
                 top_k: int = 256,
                 top_p: Optional[float] = None,
                 softmax_temperature: float = 1.0,
                 num_candidates: int = 96,
                 device: str = 'cuda:0',
                 use_fp16: bool = True,
                 labels=None) -> torch.FloatTensor:
        """Sample ``num_candidates`` images for one caption and one source image.

        Args:
            tokens: Either a raw string, which is encoded with
                ``self.tokenizer``, or an already tokenized tensor.
            prompt: Prompt tensor forwarded unchanged to
                ``sampling_conditional`` (``predict_step`` passes the output
                of ``get_prompt``).
            source: Source image tensor; it is passed through
                ``self.stage1.get_codes``.
            top_k, top_p, softmax_temperature, use_fp16: Forwarded unchanged
                to ``sampling_conditional``.
            num_candidates: How many times tokens and source codes are
                repeated along dim 0.
            device: Device the tokens and the source image are moved to.
            labels: Accepted but not used by this method.

        Returns:
            The decoded images, rescaled with ``x * 0.5 + 0.5`` and clamped
            to ``[0, 1]``.
        """
        self.stage1.eval()
        self.stage2.eval()

        if type(tokens) == str:
            # Encode the caption text itself; `prompt` is the prompt tensor
            # and must not be handed to the tokenizer.
            tokens = self.tokenizer.encode(tokens)
            tokens = torch.LongTensor(tokens.ids)
        tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)

        tokens = tokens.to(device)
        source = source.to(device)

        pos_enc_prompt = get_positional_encoding(
            self.input_tokens.unsqueeze(0).expand(num_candidates, -1).to(self.device), mode='1d')

        with autocast(enabled=False):
            src_codes = self.stage1.get_codes(source).detach()
            src_codes = torch.repeat_interleave(src_codes, num_candidates, dim=0)

        codes = sampling_conditional(self.stage2,
                                     self.cross_attention_idxs,
                                     self.cross_attention_layers,
                                     tokens,
                                     src_codes,
                                     top_k=top_k,
                                     top_p=top_p,
                                     softmax_temperature=softmax_temperature,
                                     use_fp16=use_fp16,
                                     prompt=prompt,
                                     pos_prompt=pos_enc_prompt)

        codes = codes.view(num_candidates, 16, 16)  # [B, 16, 16]
        pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1)  # [B, 256, 256]
        return pixels

    def predict_step(self, batch, batch_idx, return_images=False):
        """Reconstruct the batch from arg-max logits and sample 5 candidates per text.

        Returns the list of sampled image arrays when ``return_images`` is
        true; otherwise the reconstruction and the samples are written with
        ``save_image`` and nothing is returned.
        """
        orig_images, texts = batch

        # extra for checks
        # NOTE(review): forward() is declared as (images, src_images, texts)
        # but is called with two positional arguments here -- confirm the
        # intended batch layout.
        logits_img, logits_txt, codes = self(orig_images, texts)
        pred = torch.argmax(logits_img.view(-1, logits_img.shape[-1]), dim=-1)
        bs = orig_images.shape[0]
        pred = pred.view(bs, 16, 16)  # [B, 16, 16]
        # Kept under its own name so the sampling loop below cannot overwrite it.
        recon_pixels = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy()  # [B, 256, 256]
        recon_pixels = np.transpose(recon_pixels, (0, 2, 3, 1))

        prompt = self.get_prompt(bsz=5, eval=True)

        images = []
        for i, t in enumerate(texts):
            # NOTE(review): sampling() is declared as (tokens, prompt, source, ...)
            # and no `source` is passed here -- confirm where the source
            # image should come from.
            pixels = self.sampling(t, prompt, top_k=64, num_candidates=5, labels=codes[i]).cpu().numpy()
            pixels = np.transpose(pixels, (0, 2, 3, 1))
            images.append(pixels)

        if return_images:
            return images
        else:
            save_image(orig_images, recon_pixels, './out/images/pororo_story', batch_idx+10)
            save_image(orig_images, images, './out/images/pororo_story', batch_idx)
class StoryDalle(Dalle): | |
"""Base model with story block""" | |
def __init__(self, config):
    """Build the story modules on top of the base model.

    Always creates ``story_linear`` and ``story_block``. The learned prompt
    (``wte``, ``control_trans``, ``dropout``) is only created when
    ``config.story.prompt`` is truthy, and the cross-attention layers only
    when ``config.story.condition`` is truthy. Prints the total parameter
    count at the end.
    """
    super().__init__(config)
    print('Initializing the Conditional Dalle model')

    self.config = config
    story_cfg = config.story
    hparams = config.stage2.hparams

    # Maps sentence embeddings to the transformer width, then mixes them
    # with one transformer block.
    self.story_linear = nn.Linear(story_cfg.sent_embed, hparams.embed_dim)
    self.story_block = Block(ctx_len=story_cfg.story_len,
                             embed_dim=hparams.embed_dim,
                             n_heads=hparams.n_heads,
                             mlp_bias=hparams.mlp_bias,
                             attn_bias=hparams.attn_bias,
                             resid_pdrop=hparams.resid_pdrop,
                             attn_pdrop=hparams.attn_pdrop,
                             gelu_use_approx=hparams.gelu_use_approx)

    if story_cfg.prompt:
        self.n_embd = hparams.embed_dim
        self.preseqlen = story_cfg.preseqlen
        self.prefix_dropout = story_cfg.prefix_dropout

        # DIFFERENT PARAMETRIZATION:
        print('[Full prompt-tuning Setting :) ]')
        # Fixed index sequence 0..preseqlen-1 that is looked up in `wte`.
        self.input_tokens = torch.arange(self.preseqlen).long()
        self.wte = nn.Embedding(self.preseqlen, self.n_embd)
        self.control_trans = nn.Sequential(
            nn.Linear(self.n_embd, self.n_embd),
            nn.Tanh(),
            nn.Linear(self.n_embd, self.n_embd))
        self.get_prompt = self.get_prompt_p5
        self.dropout = nn.Dropout(self.prefix_dropout)

    if story_cfg.condition:
        print('Setting up Cross-attention Layers')
        self.init_cross_attention(list(range(2, 42, 3)), hparams)

    ###### NUM PARAMS #########
    total_param = sum(param.numel() for _, param in self.named_parameters())
    print('Total parameters is {}'.format(total_param))
@classmethod
def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
    """Build a model from ``args`` and load its weights and tokenizer.

    ``args.prefix_model_name_or_path`` names either an entry of ``_MODELS``
    or a path/URL that provides ``config.yaml``. The 'story' base config is
    merged with it, and command-line arguments override matching keys of the
    ``story``, ``optimizer`` and ``experiment`` sections.

    When ``args.model_name_or_path`` is set, the tokenizer is read from that
    checkpoint's directory and the full model state is loaded from the
    checkpoint (on CPU). Otherwise the stage-1/stage-2 checkpoints and the
    tokenizer are taken from the prefix model directory.

    Returns:
        ``(model, config)`` where ``config`` is the merged configuration the
        model was built with.

    Raises:
        KeyError: if the checkpoint at ``args.model_name_or_path`` has
            neither a ``state_dict`` nor a ``model_state_dict`` entry.
    """
    path = args.prefix_model_name_or_path
    path = _MODELS[path] if path in _MODELS else path
    path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))

    config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
    config_base = get_base_config('story')
    config_update = OmegaConf.merge(config_base, config_new)
    for key, val in vars(args).items():
        if key in config_update.story.keys():
            OmegaConf.update(config_update, "story.%s" % key, val, merge=False)
        if key in config_update.optimizer.keys():
            OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
        if key in config_update.experiment.keys():
            OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)

    model = cls(config_update)
    # __init__ only sets cross_attention_idxs when config.story.condition is
    # truthy, so fall back to None instead of raising AttributeError.
    cross_attention_idxs = getattr(model, 'cross_attention_idxs', None)

    if args.model_name_or_path:
        model_dir = os.path.dirname(args.model_name_or_path)
        print(model_dir)
        model.tokenizer = build_tokenizer(model_dir,
                                          context_length=model.config_dataset.context_length,
                                          lowercase=True,
                                          dropout=None)
        print("Loaded tokenizer from finetuned checkpoint")
        print(cross_attention_idxs)
        print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
        # Read the checkpoint once and pick whichever key it uses.
        checkpoint = torch.load(args.model_name_or_path, map_location=torch.device('cpu'))
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint['model_state_dict']
        model.load_state_dict(state_dict)
    else:
        print(cross_attention_idxs)
        print("Loading models from checkpoint %s" % path)
        model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
        model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
        model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
                                          context_length=model.config_dataset.context_length,
                                          lowercase=True,
                                          dropout=None)

    return model, config_update
def init_cross_attention(self, cross_attention_layers, hparams):
    """Create one ``CrossAttentionLayer`` per entry of ``cross_attention_layers``.

    The given indices are stored in ``self.cross_attention_idxs`` and the new
    layers in ``self.cross_attention_layers``.
    """

    def build_layer():
        # Every layer is built from the same stage-2 hyper-parameters.
        return CrossAttentionLayer(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt,
                                   embed_dim=hparams.embed_dim,
                                   n_heads=hparams.n_heads,
                                   attn_bias=hparams.attn_bias,
                                   resid_pdrop=hparams.resid_pdrop,
                                   attn_pdrop=hparams.attn_pdrop)

    self.cross_attention_idxs = cross_attention_layers
    # NOTE(review): this is a plain Python list, so nn.Module does not
    # register these layers as sub-modules -- confirm their parameters are
    # handled (device, optimizer, checkpoint) elsewhere.
    self.cross_attention_layers = [build_layer() for _ in cross_attention_layers]
def get_prompt_p5(self, bsz=None, eval=False):
    """Return the learned prompt for a batch of size ``bsz``.

    The fixed index sequence ``self.input_tokens`` is expanded to ``bsz``
    rows, embedded with ``wte`` and passed through ``control_trans``.
    ``self.dropout`` is applied unless ``eval`` is true.
    """
    prefix_ids = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
    prefix = self.control_trans(self.wte(prefix_ids))  # bsz, seqlen, layer*emb
    if eval:
        return prefix
    return self.dropout(prefix)
def forward(self,
            images: torch.FloatTensor,
            src_images: Optional[torch.FloatTensor],
            texts: Optional[torch.LongTensor],
            sent_embeds: Optional[torch.FloatTensor],
            **kwargs,
            ):
    """Compute image and text logits for a batch of stories.

    Args:
        images: Target frames, unpacked as ``(B, L, C, H, W)`` and flattened
            to ``(B*L, C, H, W)``.
        src_images: One source image per story; it is repeated ``L`` times so
            every frame gets a copy.
        texts: Caption token ids, reshaped to ``(B*L, -1)``.
        sent_embeds: Sentence embeddings; they go through ``story_linear``
            and ``story_block`` and become one prompt position per frame.

    Returns:
        ``(logits_img, logits_txt, codes)`` where ``codes`` are the stage-1
        codes of the flattened ``images``.
    """
    n_stories, n_frames, C, H, W = images.shape
    flat = n_stories * n_frames

    images = images.view(flat, C, H, W)
    src_images = src_images.unsqueeze(1).expand(-1, n_frames, -1, -1, -1).reshape(flat, C, H, W)
    story_context = self.story_block(self.story_linear(sent_embeds))
    sent_embeds = story_context.view(flat, -1).unsqueeze(1)
    texts = texts.view(flat, -1)

    # Stage 1 only provides targets/conditioning: no gradients, no autocast.
    with torch.no_grad(), autocast(enabled=False):
        codes = self.stage1.get_codes(images).detach()
        src_codes = self.stage1.get_codes(src_images).detach()

    # From here on the batch size is the number of flattened frames.
    B = images.shape[0]

    if self.config.story.prompt:
        # Learned prompt first, story context last.
        prompt = torch.cat([self.get_prompt(bsz=B), sent_embeds], dim=1)
    else:
        prompt = sent_embeds

    # dim = 0 for full-model finetuning??
    prompt_positions = torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(B, -1).to(self.device)
    pos_enc_prompt = get_positional_encoding(prompt_positions, mode='1d')

    pos_enc_tokens = get_positional_encoding(texts, mode='1d')
    codes = codes.clone().detach()
    pos_enc_code = get_positional_encoding(codes, mode='1d')
    src_codes = src_codes.clone().detach()
    src_pos_enc_code = get_positional_encoding(src_codes, mode='1d')

    if not self.config.story.condition:
        logits_img, logits_txt = self.stage2(codes, texts, pos_enc_code, pos_enc_tokens,
                                             prompt=prompt, pos_prompt=pos_enc_prompt)
    else:
        logits_img, logits_txt = self.stage2.forward_with_context(
            codes, texts,
            pos_enc_code, pos_enc_tokens, src_codes, src_pos_enc_code,
            self.cross_attention_idxs, self.cross_attention_layers,
            prompt=prompt, pos_prompt=pos_enc_prompt)

    return logits_img, logits_txt, codes
def sampling(self,
             tokens: torch.LongTensor,
             source: torch.FloatTensor,
             sent_embeds: torch.FloatTensor,
             top_k: int = 256,
             top_p: Optional[float] = None,
             softmax_temperature: float = 1.0,
             num_candidates: int = 96,
             device: str = 'cuda:0',
             use_fp16: bool = True,
             labels=None,
             prompt = None) -> torch.FloatTensor:
    """Sample one image per frame of a story.

    Args:
        tokens: Either a raw string, which is encoded with
            ``self.tokenizer``, or an already tokenized tensor.
        source: Source image tensor; it is passed through
            ``self.stage1.get_codes`` and the codes are repeated
            ``config.story.story_len`` times.
        sent_embeds: Sentence embeddings, unpacked as ``(B, L, _)``; they go
            through ``story_linear`` and ``story_block``.
        top_k, top_p, softmax_temperature, use_fp16: Forwarded unchanged to
            the sampling function.
        num_candidates, labels: Accepted but not used by this method.
        device: Device the tokens and the source image are moved to.
        prompt: Optional prompt tensor; when given, the story context is
            appended to it along dim 1.

    Returns:
        The decoded images, rescaled with ``x * 0.5 + 0.5`` and clamped to
        ``[0, 1]``.
    """
    for stage in (self.stage1, self.stage2):
        stage.eval()

    if type(tokens) is str:
        tokens = torch.LongTensor(self.tokenizer.encode(tokens).ids)

    tokens = tokens.to(device)
    source = source.to(device)

    B, L, _ = sent_embeds.shape
    total_frames = B * L
    story_context = self.story_block(self.story_linear(sent_embeds))
    sent_embeds = story_context.view(total_frames, -1).unsqueeze(1)
    prompt = sent_embeds if prompt is None else torch.cat([prompt, sent_embeds], dim=1)

    prompt_positions = torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(total_frames, -1)
    pos_enc_prompt = get_positional_encoding(prompt_positions.to(self.device), mode='1d')

    story_len = self.config.story.story_len
    # The stage-1 encoder is run with autocast explicitly switched off.
    with autocast(enabled=False):
        src_codes = self.stage1.get_codes(source).detach()
    src_codes = torch.repeat_interleave(src_codes, story_len, dim=0)

    print(tokens.shape, src_codes.shape, prompt.shape)

    shared_options = dict(top_k=top_k,
                          top_p=top_p,
                          softmax_temperature=softmax_temperature,
                          use_fp16=use_fp16,
                          prompt=prompt,
                          pos_prompt=pos_enc_prompt)
    if self.config.story.condition:
        codes = sampling_conditional(self.stage2,
                                     self.cross_attention_idxs,
                                     self.cross_attention_layers,
                                     tokens,
                                     src_codes,
                                     **shared_options)
    else:
        # `sampling` here is the module-level helper, not this method.
        codes = sampling(self.stage2, tokens, **shared_options)

    code_grid = codes.view(story_len, 16, 16)  # [B, 16, 16]
    return torch.clamp(self.stage1.decode_code(code_grid) * 0.5 + 0.5, 0, 1)  # [B, 256, 256]
def sampling_batch(self,
                   tokens: torch.LongTensor,
                   source: torch.FloatTensor,
                   sent_embeds: torch.FloatTensor,
                   top_k: int = 256,
                   top_p: Optional[float] = None,
                   softmax_temperature: float = 1.0,
                   num_candidates: int = 96,
                   device: str = 'cuda:0',
                   use_fp16: bool = True,
                   labels=None,
                   prompt=None, n_candidates=1) -> torch.FloatTensor:
    """Sample ``n_candidates`` complete stories in one pass.

    Args:
        tokens: Either a raw string, which is encoded with
            ``self.tokenizer``, or an already tokenized tensor.
        source: Source image tensor; it is passed through
            ``self.stage1.get_codes``.
        sent_embeds: Sentence embeddings, unpacked as ``(B, L, _)``; they go
            through ``story_linear`` and ``story_block``.
        top_k, top_p, softmax_temperature, use_fp16: Forwarded unchanged to
            the sampling function.
        num_candidates, labels: Accepted but not used by this method.
        device: Device the tokens and the source image are moved to.
        prompt: Optional prompt tensor; when given, the story context is
            appended to it along dim 1.
        n_candidates: Number of stories to sample; all inputs are repeated
            this many times along dim 0.

    Returns:
        The decoded images, clamped to ``[0, 1]`` and reshaped to
        ``(n_candidates, story_len, C, H, W)``.
    """
    for stage in (self.stage1, self.stage2):
        stage.eval()

    if type(tokens) is str:
        tokens = torch.LongTensor(self.tokenizer.encode(tokens).ids)

    tokens = tokens.to(device)
    source = source.to(device)

    B, L, _ = sent_embeds.shape
    total_frames = B * L
    story_context = self.story_block(self.story_linear(sent_embeds))
    sent_embeds = story_context.view(total_frames, -1).unsqueeze(1)
    prompt = sent_embeds if prompt is None else torch.cat([prompt, sent_embeds], dim=1)

    prompt_positions = torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(total_frames, -1)
    pos_enc_prompt = get_positional_encoding(prompt_positions.to(self.device), mode='1d')

    story_len = self.config.story.story_len
    n_images = story_len * n_candidates

    # The stage-1 encoder is run with autocast explicitly switched off.
    with autocast(enabled=False):
        src_codes = self.stage1.get_codes(source).detach()

    # repeat inputs to adjust to n_candidates and story length
    src_codes = torch.repeat_interleave(src_codes, n_images, dim=0)
    prompt = prompt.repeat(n_candidates, 1, 1)
    pos_enc_prompt = pos_enc_prompt.repeat(n_candidates, 1)
    tokens = tokens.repeat(n_candidates, 1)
    print(tokens.shape, src_codes.shape, prompt.shape, pos_enc_prompt.shape)

    shared_options = dict(top_k=top_k,
                          top_p=top_p,
                          softmax_temperature=softmax_temperature,
                          use_fp16=use_fp16,
                          prompt=prompt,
                          pos_prompt=pos_enc_prompt)
    if self.config.story.condition:
        codes = sampling_conditional(self.stage2,
                                     self.cross_attention_idxs,
                                     self.cross_attention_layers,
                                     tokens,
                                     src_codes,
                                     **shared_options)
    else:
        # `sampling` here is the module-level helper, not a method.
        codes = sampling(self.stage2, tokens, **shared_options)

    codes = codes.view(n_images, 16, 16)  # [B, 16, 16]
    print(codes.shape)
    pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1)  # [B, 3, 256, 256]
    print(pixels.shape)
    channels, height, width = pixels.shape[-3], pixels.shape[-2], pixels.shape[-1]
    return pixels.view(n_candidates, story_len, channels, height, width)
def predict_step(self, batch, batch_idx, return_images=False):
    """Reconstruct the batch from arg-max logits and sample 5 candidates per text.

    Returns the list of sampled image arrays when ``return_images`` is true;
    otherwise the reconstruction and the samples are written with
    ``save_image`` and nothing is returned.

    NOTE(review): this method looks like it was copied from a class with a
    different interface. In this class ``forward`` is declared as
    ``(images, src_images, texts, sent_embeds)`` and ``sampling`` as
    ``(tokens, source, sent_embeds, ...)``, and ``get_prompt`` only exists
    when ``config.story.prompt`` is truthy. The calls below keep their
    original arguments; confirm the intended batch layout before relying on
    this method.
    """
    orig_images, texts = batch

    # extra for checks
    logits_img, logits_txt, codes = self(orig_images, texts)
    pred = torch.argmax(logits_img.view(-1, logits_img.shape[-1]), dim=-1)
    bs = orig_images.shape[0]
    pred = pred.view(bs, 16, 16)  # [B, 16, 16]
    # Kept under its own name so the sampling loop below cannot overwrite it.
    recon_pixels = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy()  # [B, 256, 256]
    recon_pixels = np.transpose(recon_pixels, (0, 2, 3, 1))

    prompt = self.get_prompt(bsz=5, eval=True)

    images = []
    # enumerate() supplies the index used to pick the matching codes row.
    for i, t in enumerate(texts):
        pixels = self.sampling(t, prompt, top_k=64, num_candidates=5, labels=codes[i]).cpu().numpy()
        pixels = np.transpose(pixels, (0, 2, 3, 1))
        images.append(pixels)

    if return_images:
        return images
    else:
        save_image(orig_images, recon_pixels, './out/images/pororo_story', batch_idx+10)
        save_image(orig_images, images, './out/images/pororo_story', batch_idx)