Spaces:
Build error
Build error
adymaharana
commited on
Commit
·
3d5e231
1
Parent(s):
84f1d0c
Added files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .idea/.gitignore +8 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +4 -0
- .idea/modules.xml +8 -0
- .idea/storydalle.iml +12 -0
- .idea/vcs.xml +6 -0
- 1.3B/config.yaml +38 -0
- 1.3B/tokenizer/bpe-16k-merges.txt +0 -0
- 1.3B/tokenizer/bpe-16k-vocab.json +0 -0
- app.py +353 -0
- dalle/__init__.py +0 -0
- dalle/__pycache__/__init__.cpython-38.pyc +0 -0
- dalle/__pycache__/trainer_prefix.cpython-38.pyc +0 -0
- dalle/models/__init__.py +1462 -0
- dalle/models/__pycache__/__init__.cpython-38.pyc +0 -0
- dalle/models/__pycache__/prefix_tuning_model.cpython-38.pyc +0 -0
- dalle/models/__pycache__/tokenizer.cpython-38.pyc +0 -0
- dalle/models/stage1/__pycache__/layers.cpython-38.pyc +0 -0
- dalle/models/stage1/__pycache__/vqgan.cpython-38.pyc +0 -0
- dalle/models/stage1/layers.py +373 -0
- dalle/models/stage1/vqgan.py +93 -0
- dalle/models/stage2/__pycache__/layers.cpython-38.pyc +0 -0
- dalle/models/stage2/__pycache__/transformer.cpython-38.pyc +0 -0
- dalle/models/stage2/layers.py +216 -0
- dalle/models/stage2/transformer.py +502 -0
- dalle/models/tokenizer.py +35 -0
- dalle/trainer_prefix.py +1629 -0
- dalle/utils/__init__.py +3 -0
- dalle/utils/__pycache__/__init__.cpython-38.pyc +0 -0
- dalle/utils/__pycache__/config.cpython-38.pyc +0 -0
- dalle/utils/__pycache__/sampling.cpython-38.pyc +0 -0
- dalle/utils/__pycache__/utils.cpython-38.pyc +0 -0
- dalle/utils/config.py +209 -0
- dalle/utils/sampling.py +369 -0
- dalle/utils/utils.py +131 -0
- demo/Barney.png +0 -0
- demo/Betty.png +0 -0
- demo/Crong.png +0 -0
- demo/Dino.png +0 -0
- demo/Eddy.png +0 -0
- demo/Fred.png +0 -0
- demo/Harry.png +0 -0
- demo/Loopy.png +0 -0
- demo/MrSlate.png +0 -0
- demo/Pebbles.png +0 -0
- demo/Petty.png +0 -0
- demo/Poby.png +0 -0
- demo/Pororo.png +0 -0
- demo/Rody.png +0 -0
- demo/Tongtong.png +0 -0
.idea/.gitignore
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Default ignored files
|
2 |
+
/shelf/
|
3 |
+
/workspace.xml
|
4 |
+
# Editor-based HTTP Client requests
|
5 |
+
/httpRequests/
|
6 |
+
# Datasource local storage ignored files
|
7 |
+
/dataSources/
|
8 |
+
/dataSources.local.xml
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<component name="InspectionProjectProfileManager">
|
2 |
+
<settings>
|
3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
4 |
+
<version value="1.0" />
|
5 |
+
</settings>
|
6 |
+
</component>
|
.idea/misc.xml
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8" project-jdk-type="Python SDK" />
|
4 |
+
</project>
|
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="ProjectModuleManager">
|
4 |
+
<modules>
|
5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/storydalle.iml" filepath="$PROJECT_DIR$/.idea/storydalle.iml" />
|
6 |
+
</modules>
|
7 |
+
</component>
|
8 |
+
</project>
|
.idea/storydalle.iml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<module type="PYTHON_MODULE" version="4">
|
3 |
+
<component name="NewModuleRootManager">
|
4 |
+
<content url="file://$MODULE_DIR$" />
|
5 |
+
<orderEntry type="inheritedJdk" />
|
6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
7 |
+
</component>
|
8 |
+
<component name="PyDocumentationSettings">
|
9 |
+
<option name="format" value="PLAIN" />
|
10 |
+
<option name="myDocStringFormat" value="Plain" />
|
11 |
+
</component>
|
12 |
+
</module>
|
.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="VcsDirectoryMappings">
|
4 |
+
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
5 |
+
</component>
|
6 |
+
</project>
|
1.3B/config.yaml
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dataset:
|
2 |
+
tokenizer_type: CharBPE
|
3 |
+
context_length: 64
|
4 |
+
image_resolution: 256
|
5 |
+
|
6 |
+
stage1:
|
7 |
+
type: vqgan
|
8 |
+
embed_dim: 256
|
9 |
+
n_embed: 16384
|
10 |
+
hparams:
|
11 |
+
double_z: False
|
12 |
+
z_channels: 256
|
13 |
+
resolution: 256
|
14 |
+
in_channels: 3
|
15 |
+
out_ch: 3
|
16 |
+
ch: 128
|
17 |
+
ch_mult: [1, 1, 2, 2, 4]
|
18 |
+
num_res_blocks: 2
|
19 |
+
attn_resolutions: [16]
|
20 |
+
pdrop: 0.0
|
21 |
+
|
22 |
+
stage2:
|
23 |
+
type: transformer1d
|
24 |
+
vocab_size_txt: 16384
|
25 |
+
vocab_size_img: 16384
|
26 |
+
hparams:
|
27 |
+
embed_dim: 1536
|
28 |
+
n_layers: 42
|
29 |
+
n_heads: 24
|
30 |
+
n_dense_layers: 42
|
31 |
+
ctx_len_img: 256
|
32 |
+
ctx_len_txt: 64
|
33 |
+
embd_pdrop: 0.0
|
34 |
+
resid_pdrop: 0.0
|
35 |
+
attn_pdrop: 0.0
|
36 |
+
mlp_bias: True
|
37 |
+
attn_bias: True
|
38 |
+
gelu_use_approx: False
|
1.3B/tokenizer/bpe-16k-merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
1.3B/tokenizer/bpe-16k-vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
app.py
ADDED
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, torch
|
2 |
+
import gradio as gr
|
3 |
+
import torchvision.utils as vutils
|
4 |
+
import torchvision.transforms as transforms
|
5 |
+
from dalle.models import StoryDalle
|
6 |
+
import argparse
|
7 |
+
from PIL import Image
|
8 |
+
import numpy as np
|
9 |
+
from torchvision.utils import save_image
|
10 |
+
import tensorflow_hub as hub
|
11 |
+
import gdown
|
12 |
+
|
13 |
+
|
14 |
+
source_frame_paths = {
|
15 |
+
'Pororo': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH1_2/Pororo_ENGLISH1_2_ep6/12.png',
|
16 |
+
'Loopy': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH1_1/Pororo_ENGLISH1_1_ep12/26.png',
|
17 |
+
'Crong': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH1_1/Pororo_ENGLISH1_1_ep12/10.png',
|
18 |
+
'Poby': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH1_1/Pororo_ENGLISH1_1_ep9/34.png',
|
19 |
+
'Eddy': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH1_1/Pororo_ENGLISH1_1_ep12/46.png',
|
20 |
+
'Petty': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH2_1/Pororo_ENGLISH2_1_ep1/34.png',
|
21 |
+
'Tongtong': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH3_1/Pororo_ENGLISH3_1_ep7/8.png',
|
22 |
+
'Rody': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH3_1/Pororo_ENGLISH3_1_ep6/66.png',
|
23 |
+
'Harry': '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/Pororo_ENGLISH3_1/Pororo_ENGLISH3_1_ep7/39.png',
|
24 |
+
}
|
25 |
+
|
26 |
+
|
27 |
+
def inverse_normalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
|
28 |
+
mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
|
29 |
+
std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
|
30 |
+
if mean.ndim == 1:
|
31 |
+
mean = mean.view(-1, 1, 1)
|
32 |
+
if std.ndim == 1:
|
33 |
+
std = std.view(-1, 1, 1)
|
34 |
+
tensor.mul_(std).add_(mean)
|
35 |
+
return tensor
|
36 |
+
|
37 |
+
|
38 |
+
def save_story_results(images, video_len=4, n_candidates=1, mask=None):
|
39 |
+
# print("Generated Images shape: ", images.shape)
|
40 |
+
|
41 |
+
if mask is None:
|
42 |
+
mask = [1 for _ in range(len(video_len))]
|
43 |
+
|
44 |
+
all_images = []
|
45 |
+
for i in range(len(images)): # batch size = 1
|
46 |
+
for j in range(n_candidates):
|
47 |
+
story = []
|
48 |
+
for k, m in enumerate(mask):
|
49 |
+
if m == 1:
|
50 |
+
story.append(images[i][j][k])
|
51 |
+
all_images.append(vutils.make_grid(story, sum(mask), padding=0))
|
52 |
+
all_images = vutils.make_grid(all_images, 1, padding=20)
|
53 |
+
print(all_images)
|
54 |
+
|
55 |
+
pad_len = video_len - sum(mask)
|
56 |
+
|
57 |
+
if pad_len > 0:
|
58 |
+
pad_height = 256 * n_candidates + 20 * (n_candidates + 1)
|
59 |
+
pad_width = 256 * pad_len + 20 * (pad_len)
|
60 |
+
pad_image = torch.ones(3, pad_height, pad_width)
|
61 |
+
|
62 |
+
print(all_images.shape, pad_image.shape)
|
63 |
+
all_images = torch.cat([all_images[:, :, :-15], pad_image], dim=-1)
|
64 |
+
|
65 |
+
print(all_images.shape)
|
66 |
+
return all_images[:, 15:-15, 15:-15]
|
67 |
+
|
68 |
+
|
69 |
+
def main(args):
|
70 |
+
device = 'cuda:0'
|
71 |
+
|
72 |
+
model_url = 'https://drive.google.com/file/d/1lJ6zMZ6qTvFu6H35-VEdFlN13MMslivJ/view?usp=sharing'
|
73 |
+
png_url = 'https://drive.google.com/file/d/1C33A1IzSHDPoQ4QBsgFWbF61QWaAxRo_/view?usp=sharing'
|
74 |
+
|
75 |
+
gdown.download(model_url, quiet=True, use_cookies=False, output="./ckpt/25.pth")
|
76 |
+
gdown.download(png_url, quiet=True, use_cookies=False, output="demo_pororo_good.png")
|
77 |
+
|
78 |
+
if args.debug:
|
79 |
+
model = None
|
80 |
+
embed = None
|
81 |
+
else:
|
82 |
+
model, config = StoryDalle.from_pretrained(args)
|
83 |
+
model.tokenizer.add_tokens(['pororo', 'loopy', 'eddy', 'harry', 'poby', 'tongtong', 'crong', 'rody', 'petty'])
|
84 |
+
model.eval()
|
85 |
+
model.to(device=device)
|
86 |
+
embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder-large/5")
|
87 |
+
|
88 |
+
if model.config.story.condition:
|
89 |
+
for i in range(len(model.cross_attention_layers)):
|
90 |
+
model.cross_attention_layers[i].to(device)
|
91 |
+
print("Cross-attention layers are in cuda:", next(model.cross_attention_layers[0].parameters()).is_cuda)
|
92 |
+
|
93 |
+
valid_transform = transforms.Compose(
|
94 |
+
[transforms.Resize(config.dataset.image_resolution),
|
95 |
+
transforms.CenterCrop(config.dataset.image_resolution),
|
96 |
+
transforms.ToTensor(),
|
97 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
|
98 |
+
)
|
99 |
+
|
100 |
+
def predict(caption_1, caption_2, caption_3, caption_4, source='Pororo', top_k=32, top_p=0.2, n_candidates=4,
|
101 |
+
supercondition=False):
|
102 |
+
|
103 |
+
if not args.debug:
|
104 |
+
captions = [caption_1, caption_2, caption_3, caption_4]
|
105 |
+
mask = [1 if caption != '' else 0 for caption in captions]
|
106 |
+
print(captions, mask, source, n_candidates)
|
107 |
+
for i, caption in enumerate(captions):
|
108 |
+
if caption == "":
|
109 |
+
captions[i] = "Pororo is reading a book."
|
110 |
+
tokens = [model.tokenizer.encode(caption) for caption in captions]
|
111 |
+
texts = torch.stack([torch.LongTensor(token.ids) for token in tokens]).unsqueeze(0)
|
112 |
+
sent_embeds = torch.tensor(embed(captions).numpy())
|
113 |
+
# sent_embeds = torch.tensor(description_vecs[source_frame_paths[source].
|
114 |
+
# replace('/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/', '')[:-4]][0]).unsqueeze(0).repeat(4, 1)
|
115 |
+
|
116 |
+
src_image = valid_transform(Image.open('./demo/%s.png' % source).convert('RGB'))
|
117 |
+
|
118 |
+
stories = []
|
119 |
+
with torch.no_grad():
|
120 |
+
for i in range(texts.shape[0]):
|
121 |
+
pixels = model.sampling_batch(texts[i].to(device), src_image.unsqueeze(0).to(device),
|
122 |
+
sent_embeds.unsqueeze(0).to(device), top_k=top_k, top_p=top_p,
|
123 |
+
prompt=None, n_candidates=n_candidates).cpu()
|
124 |
+
stories.append(pixels)
|
125 |
+
|
126 |
+
img = save_story_results(stories, video_len=4, n_candidates=n_candidates, mask=mask)
|
127 |
+
save_image(img, "gradio_demo_pororo.png", normalize=True)
|
128 |
+
|
129 |
+
return "gradio_demo_pororo.png"
|
130 |
+
|
131 |
+
with gr.Blocks(css='#output {width:750px; height:750px; float:left;}') as demo:
|
132 |
+
gr.Markdown('''
|
133 |
+
<p style="text-align: center;font-size:40px;"><b>StoryDALL-E: Adapting Pretrained Text-to-Image Transformers for Story Continuation</b><br><font size="6">Adyasha Maharana, Darryl Hannan and Mohit Bansal (UNC Chapel Hill)<br>Published at <b>ECCV 2022</b></font></p>
|
134 |
+
|
135 |
+
StoryDALL-E \[1\] is a model trained for the task of Story Visualization \[2\].
|
136 |
+
The model receives a sequence of captions as input and generates a corresponding sequence of images which form a visual story depicting the narrative in the captions.
|
137 |
+
We modify this task to enable the model to receive an initial scene as input, which can be used as a cue for the setting of the story and also for generating unseen or low-resource visual elements. We refer to this task as Story Continuation \[1\].
|
138 |
+
StoryDALL-E is based on the [mega-dalle](https://github.com/borisdayma/dalle-mini) model and is adapted from the corresponding [PyTorch codebase](https://github.com/kuprel/min-dalle).
|
139 |
+
**This model has been developed for academic purposes only.**
|
140 |
+
|
141 |
+
\[[Paper](http://arxiv.org/abs/2209.06192)\] \[[Code](https://github.com/adymaharana/storydalle)\] \[[Model Card](https://github.com/adymaharana/storydalle/blob/main/MODEL_CARD.MD)\]
|
142 |
+
|
143 |
+
### Dataset
|
144 |
+
This model has been trained using the Pororo story visualization dataset \[1\].
|
145 |
+
The data was adapted from the popular cartoon series *Pororo the Little Penguin* and originally released by \[2\].
|
146 |
+
The Pororo dataset contains 9 recurring characters, as shown below, in the decreasing order of their frequency in the training data.
|
147 |
+
<p align="center">
|
148 |
+
<img src="file/pororo_characters.png" width="800">
|
149 |
+
</p>
|
150 |
+
The training dataset contains nearly 10,000 samples in the training set. Most of the scenes occur in a snowy village, surrounded by hills, trees and houses. A few episodes are located in gardens or water bodies. All the captions are in the English language and predominantly contain verbs in the present tense. Additionally, the training of this model starts from the pretrained checkpoint of mega-dalle, which is trained on the Conceptual Captions dataset.
|
151 |
+
|
152 |
+
### Intended Use
|
153 |
+
This model is intended for generating visual stories containing the 9 characters in the Pororo dataset. This version of the StoryDALL-E model is reasonable at the following scenarios:
|
154 |
+
* Frames containing a single character.
|
155 |
+
* Overtly visual actions such as *making cookies*, *walking*, *reading a book*, *sitting*.
|
156 |
+
* Scenes taking place in snowy settings, indoors and gardens.
|
157 |
+
* Visual stories contaning 1-3 characters across all frames.
|
158 |
+
* Scene transitions e.g. from day to night.
|
159 |
+
* Moderately capable of generating semantic concepts that do not appear in the story continuation dataset, such as *doughnut* and *lion*.
|
160 |
+
|
161 |
+
Here are some examples of generated visual stories for the above-mentioned settings.
|
162 |
+
|
163 |
+
<p align="center">
|
164 |
+
<img src="file/demo_pororo_good.png" width="1000">
|
165 |
+
</p>
|
166 |
+
|
167 |
+
Due to the small training dataset size for story visualization, the model has poor generalization to some unseen settings. The model struggles to generate coherent images in the following scenarios.
|
168 |
+
* Multiple characters in a frame.
|
169 |
+
* Non-visual actions such as *compliment*.
|
170 |
+
* Characters that are infrequent in the training dataset e.g. Rody, Harry.
|
171 |
+
* Background locations that are not found in the cartoon e.g. a busy city.
|
172 |
+
* Color-based descriptions for object.
|
173 |
+
* Completely new characters based on textual descriptions.
|
174 |
+
|
175 |
+
In the following demo, four or less captions can be entered in the `caption` text fields for the visual story.
|
176 |
+
Select a `source` frame based on the character that is predominant in your visual story.
|
177 |
+
`top_k` refers to the number of highest probability vocabulary tokens to keep for top-k-filtering.
|
178 |
+
Only the most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.
|
179 |
+
Set `supercondition` to True to enable generation using a null hypothesis.
|
180 |
+
Select between 1-4 `n_candidates` to generate a diverse set of stories for the given captions.
|
181 |
+
<br><br>
|
182 |
+
Feel free to send feedback to [email protected].
|
183 |
+
''')
|
184 |
+
|
185 |
+
with gr.Row():
|
186 |
+
with gr.Column():
|
187 |
+
caption_1 = gr.Textbox(label="Caption 1", value='Pororo is reading a book.')
|
188 |
+
caption_2 = gr.Textbox(label="Caption 2", value='Pororo is sleeping on the couch.')
|
189 |
+
caption_3 = gr.Textbox(label="Caption 3", value='Pororo wakes up in the middle of the night in his bed.')
|
190 |
+
caption_4 = gr.Textbox(label="Caption 4", value='Pororo is in his bedroom and looks terrified.')
|
191 |
+
source = gr.Radio(["Pororo", "Loopy", "Crong", "Poby", "Eddy", "Petty", "Tongtong", "Rody", "Harry"],
|
192 |
+
label="Source", value="Pororo")
|
193 |
+
top_k = gr.Slider(16, 128, label="top_k", value=32)
|
194 |
+
top_p = gr.Slider(0.01, 1.0, label="top_p", value=0.2)
|
195 |
+
supercondition = gr.Checkbox(value=False, label='supercondition')
|
196 |
+
n_candidates = gr.Dropdown([1, 2, 3, 4], value=4, label='n_candidates')
|
197 |
+
|
198 |
+
with gr.Row():
|
199 |
+
# clear_btn = gr.Button("Clear")
|
200 |
+
submit_btn = gr.Button("Submit")
|
201 |
+
|
202 |
+
with gr.Column():
|
203 |
+
with gr.Row():
|
204 |
+
frame_1_label = gr.Button("Frame 1")
|
205 |
+
frame_2_label = gr.Button("Frame 2")
|
206 |
+
frame_3_label = gr.Button("Frame 3")
|
207 |
+
frame_4_label = gr.Button("Frame 4")
|
208 |
+
# frame_1_label = gr.Label("Frame 1")
|
209 |
+
# frame_2_label = gr.Label("Frame 2")
|
210 |
+
# frame_3_label = gr.Label("Frame 3")
|
211 |
+
# frame_4_label = gr.Label("Frame 4")
|
212 |
+
output = gr.Image(label="", elem_id='output')
|
213 |
+
|
214 |
+
submit_btn.click(fn=predict,
|
215 |
+
inputs=[caption_1, caption_2, caption_3, caption_4, source, top_k, top_p, n_candidates,
|
216 |
+
supercondition], outputs=output)
|
217 |
+
|
218 |
+
gr.Markdown('''
|
219 |
+
### References
|
220 |
+
|
221 |
+
\[1\] Maharana, Adyasha, et al. "StoryDALL-E: Adapting Pretrained Text-to-Image Transformers for Story Continuation." ECCV. 2022.
|
222 |
+
|
223 |
+
\[2\] Li, Yitong, et al. "Storygan: A sequential conditional gan for story visualization." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.
|
224 |
+
|
225 |
+
\[3\] Kim, Kyung-Min, et al. "DeepStory: video story QA by deep embedded memory networks." Proceedings of the 26th International Joint Conference on Artificial Intelligence. 2017.
|
226 |
+
|
227 |
+
\[4\] Sharma, Piyush, et al. "Conceptual captions: A cleaned, hypernymed, image alt-text dataset for automatic image captioning." Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 2018.
|
228 |
+
''')
|
229 |
+
|
230 |
+
demo.launch(share=True)
|
231 |
+
|
232 |
+
|
233 |
+
if __name__ == "__main__":
|
234 |
+
args_list = ['--model_name_or_path', './ckpt/25.pth',
|
235 |
+
'--prefix_model_name_or_path', './1.3B/',
|
236 |
+
'--dataset_name', 'pororo',
|
237 |
+
'--tuning_mode', 'story',
|
238 |
+
'--preseqlen', '32',
|
239 |
+
'--condition',
|
240 |
+
'--story_len', '4',
|
241 |
+
'--sent_embed', '512',
|
242 |
+
'--prefix_dropout', '0.2',
|
243 |
+
'--data_dir', '/playpen-ssd/adyasha/projects/StoryGAN/pororo_png/',
|
244 |
+
'--dataloader_num_workers', '1',
|
245 |
+
'--do_eval',
|
246 |
+
'--per_gpu_eval_batch_size', '16',
|
247 |
+
'--mode', 'story']
|
248 |
+
|
249 |
+
parser = argparse.ArgumentParser(description='arguments for training/evaluating prefix-tuning DALLE')
|
250 |
+
|
251 |
+
# Model Arguments
|
252 |
+
parser.add_argument('--model_name_or_path', type=str, default=None,
|
253 |
+
help='The model checkpoint for weights initialization.')
|
254 |
+
parser.add_argument('--prefix_model_name_or_path', type=str, default=None,
|
255 |
+
help='The prefix model checkpoint for weights initialization.')
|
256 |
+
parser.add_argument('--prefix_mode', type=str, default='activation', help='activation or embedding')
|
257 |
+
parser.add_argument('--preseqlen', type=int, default=0, help='how many tokens of prefix should we include.')
|
258 |
+
parser.add_argument('--optim_prefix', action="store_true",
|
259 |
+
help='set to True if optimizing prefix directly; no if through amortized function')
|
260 |
+
parser.add_argument('--tuning_mode', type=str, default='prefixtune', help='prefixtune or finetune')
|
261 |
+
parser.add_argument('--top_k_layers', type=int, default=2,
|
262 |
+
help='In finetuning setting, if we only tune the top k layers.')
|
263 |
+
parser.add_argument('--parameterize_mode', type=str, default='mlp',
|
264 |
+
help="mlp or emb to parametrize when we optimize for the embeddings.")
|
265 |
+
parser.add_argument('--prefix_dropout', type=float, default=0.0, help='dropout rate for the prefix tuning model.')
|
266 |
+
parser.add_argument('--teacher_dropout', type=float, default=0.0, help='dropout rate for the teacher model.')
|
267 |
+
parser.add_argument('--init_random', action="store_true", help="set True if initializing random embeddings")
|
268 |
+
parser.add_argument('--init_shallow', action="store_true", help="set True if not using reparameterization")
|
269 |
+
parser.add_argument('--init_shallow_word', type=bool, default=False,
|
270 |
+
help="set True if init_shallow and specify words")
|
271 |
+
parser.add_argument('--replay_buffer', action="store_true", help="set True if using replay buffer in training")
|
272 |
+
parser.add_argument('--gumbel', action="store_true", help="set True if using the gumbel softmax in training")
|
273 |
+
parser.add_argument('--hidden_dim_prefix', type=float, default=512, help="hidden dim of MLP for generating prefix?")
|
274 |
+
|
275 |
+
# Data Arguments
|
276 |
+
parser.add_argument('--dataset_name', type=str, default='pororo', help="dataset name")
|
277 |
+
parser.add_argument('--data_dir', type=str, default=None, help="Path to data directory")
|
278 |
+
parser.add_argument('--lowdata_token', type=str, default='story',
|
279 |
+
help="The token to be prepended at initialization time.")
|
280 |
+
parser.add_argument('--use_lowdata_token', type=bool, default=True,
|
281 |
+
help="Whether we should use the lowdata token for prefix-tuning")
|
282 |
+
parser.add_argument('--train_embeddings', action="store_true", help="Whether to train word embeddings")
|
283 |
+
parser.add_argument('--train_max_target_length', type=int, default=100,
|
284 |
+
help='the max target length for training data.')
|
285 |
+
parser.add_argument('--val_max_target_length', type=int, default=100, help='the max target length for dev data.')
|
286 |
+
parser.add_argument('--dataloader_num_workers', type=int, default=8, help='number of workers when loading data')
|
287 |
+
|
288 |
+
# new arguments for story
|
289 |
+
parser.add_argument('--prompt', action="store_true", help="set True if using prompts in StoryDALLE")
|
290 |
+
parser.add_argument('--story_len', type=int, default=4, help='the max target length for dev data.')
|
291 |
+
parser.add_argument('--sent_embed', type=int, default=384, help='the max target length for dev data.')
|
292 |
+
parser.add_argument('--condition', action="store_true", help="set True if using prompts in StoryDALLE")
|
293 |
+
parser.add_argument('--clip_embed', action="store_true", help="set True if using prompts in StoryDALLE")
|
294 |
+
|
295 |
+
# Training Arguments
|
296 |
+
parser.add_argument('--output_dir', type=str, default=None, help="Path to data directory")
|
297 |
+
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
298 |
+
parser.add_argument("--do_eval", action="store_true", help="Whether to run evaluation.")
|
299 |
+
parser.add_argument("--do_test", action="store_true", help="Whether to run test.")
|
300 |
+
parser.add_argument('--seed', type=int, default=42, help='seed for reproducibility')
|
301 |
+
parser.add_argument("--overwrite_output_dir", action="store_true", help="Whether to overwrite output dir.")
|
302 |
+
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
303 |
+
parser.add_argument(
|
304 |
+
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
305 |
+
)
|
306 |
+
parser.add_argument(
|
307 |
+
"--gradient_accumulation_steps",
|
308 |
+
type=int,
|
309 |
+
default=1,
|
310 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
311 |
+
)
|
312 |
+
|
313 |
+
parser.add_argument('--mode', type=str, default='val', help="mval or test.")
|
314 |
+
|
315 |
+
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
316 |
+
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
|
317 |
+
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
318 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
319 |
+
parser.add_argument(
|
320 |
+
"--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform."
|
321 |
+
)
|
322 |
+
parser.add_argument(
|
323 |
+
"--max_steps",
|
324 |
+
default=-1,
|
325 |
+
type=int,
|
326 |
+
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
327 |
+
)
|
328 |
+
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
329 |
+
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
330 |
+
parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
|
331 |
+
parser.add_argument(
|
332 |
+
"--eval_all_checkpoints",
|
333 |
+
action="store_true",
|
334 |
+
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
335 |
+
)
|
336 |
+
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
337 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
338 |
+
parser.add_argument(
|
339 |
+
"--fp16",
|
340 |
+
action="store_true",
|
341 |
+
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
342 |
+
)
|
343 |
+
|
344 |
+
parser.add_argument("--debug", action="store_true", help="Whether to debug the demo.")
|
345 |
+
|
346 |
+
args = parser.parse_args(args_list)
|
347 |
+
|
348 |
+
main(args)
|
349 |
+
|
350 |
+
|
351 |
+
|
352 |
+
|
353 |
+
|
dalle/__init__.py
ADDED
File without changes
|
dalle/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (148 Bytes). View file
|
|
dalle/__pycache__/trainer_prefix.cpython-38.pyc
ADDED
Binary file (52.7 kB). View file
|
|
dalle/models/__init__.py
ADDED
@@ -0,0 +1,1462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------
|
2 |
+
# Minimal DALL-E
|
3 |
+
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------------
|
6 |
+
|
7 |
+
import os
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import pytorch_lightning as pl
|
11 |
+
from typing import Optional, Tuple, Union
|
12 |
+
from omegaconf import OmegaConf
|
13 |
+
from torch.cuda.amp import autocast
|
14 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
|
15 |
+
from torch.nn import functional as F
|
16 |
+
from .stage1.vqgan import VQGAN
|
17 |
+
from .stage2.transformer import Transformer1d, iGPT
|
18 |
+
from .stage2.layers import Block
|
19 |
+
from .. import utils
|
20 |
+
from ..utils.config import get_base_config
|
21 |
+
from ..utils.sampling import sampling, sampling_igpt, get_positional_encoding, sampling_prefix, sampling_conditional
|
22 |
+
from ..utils.utils import save_image
|
23 |
+
from .tokenizer import build_tokenizer
|
24 |
+
import numpy as np
|
25 |
+
from .stage2.layers import CrossAttentionLayer
|
26 |
+
|
27 |
+
_MODELS = {
|
28 |
+
'minDALL-E/1.3B': 'https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz'
|
29 |
+
}
|
30 |
+
|
31 |
+
class Dalle(pl.LightningModule):
|
32 |
+
def __init__(self,
|
33 |
+
config: OmegaConf) -> None:
|
34 |
+
super().__init__()
|
35 |
+
self.tokenizer = None
|
36 |
+
self.stage1 = VQGAN(n_embed=config.stage1.n_embed,
|
37 |
+
embed_dim=config.stage1.embed_dim,
|
38 |
+
hparams=config.stage1.hparams)
|
39 |
+
self.stage2 = Transformer1d(vocab_size_txt=config.stage2.vocab_size_txt,
|
40 |
+
vocab_size_img=config.stage2.vocab_size_img,
|
41 |
+
hparams=config.stage2.hparams)
|
42 |
+
self.config = config
|
43 |
+
self.config_stage1 = config.stage1
|
44 |
+
self.config_stage2 = config.stage2
|
45 |
+
self.config_dataset = config.dataset
|
46 |
+
|
47 |
+
# # make the parameters in stage 1 not trainable
|
48 |
+
# self.stage1.eval()
|
49 |
+
# for p in self.stage1.parameters():
|
50 |
+
# p.requires_grad = False
|
51 |
+
|
52 |
+
@classmethod
|
53 |
+
def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
|
54 |
+
|
55 |
+
path = args.model_name_or_path
|
56 |
+
config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
|
57 |
+
if args.do_train:
|
58 |
+
config_base = get_base_config('finetuning')
|
59 |
+
config_update = OmegaConf.merge(config_base, config_new)
|
60 |
+
for key, val in vars(args).items():
|
61 |
+
if key in config_update.optimizer.keys():
|
62 |
+
OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
|
63 |
+
if key in config_update.experiment.keys():
|
64 |
+
OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)
|
65 |
+
else:
|
66 |
+
config_base = get_base_config('default')
|
67 |
+
config_update = OmegaConf.merge(config_base, config_new)
|
68 |
+
|
69 |
+
model = cls(config_update)
|
70 |
+
model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
|
71 |
+
context_length=model.config_dataset.context_length,
|
72 |
+
lowercase=True,
|
73 |
+
dropout=None)
|
74 |
+
|
75 |
+
print("Loading models from checkpoint %s" % path)
|
76 |
+
|
77 |
+
if hasattr(args, 'dalle_path') and args.dalle_path and args.dalle_path.endswith('.pth'):
|
78 |
+
model.load_state_dict(torch.load(args.dalle_path)["model_state_dict"])
|
79 |
+
else:
|
80 |
+
model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
|
81 |
+
model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
|
82 |
+
|
83 |
+
return model, config_update
|
84 |
+
|
85 |
+
|
86 |
+
@torch.no_grad()
|
87 |
+
def sampling(self,
|
88 |
+
prompt: Union[str, torch.LongTensor],
|
89 |
+
top_k: int = 256,
|
90 |
+
top_p: Optional[float] = None,
|
91 |
+
softmax_temperature: float = 1.0,
|
92 |
+
num_candidates: int = 96,
|
93 |
+
device: str = 'cuda:0',
|
94 |
+
use_fp16: bool = True) -> torch.FloatTensor:
|
95 |
+
self.stage1.eval()
|
96 |
+
self.stage2.eval()
|
97 |
+
|
98 |
+
if type(prompt) == str:
|
99 |
+
tokens = self.tokenizer.encode(prompt)
|
100 |
+
tokens = torch.LongTensor(tokens.ids)
|
101 |
+
else:
|
102 |
+
tokens = prompt
|
103 |
+
tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
|
104 |
+
|
105 |
+
# Check if the encoding works as intended
|
106 |
+
# print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
|
107 |
+
|
108 |
+
tokens = tokens.to(device)
|
109 |
+
codes = sampling(self.stage2,
|
110 |
+
tokens,
|
111 |
+
top_k=top_k,
|
112 |
+
top_p=top_p,
|
113 |
+
softmax_temperature=softmax_temperature,
|
114 |
+
use_fp16=use_fp16)
|
115 |
+
codes = codes.view(num_candidates, 16, 16) # [B, 16, 16]
|
116 |
+
pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
|
117 |
+
return pixels
|
118 |
+
|
119 |
+
def forward(self,
|
120 |
+
images: torch.FloatTensor,
|
121 |
+
texts: Optional[torch.LongTensor],
|
122 |
+
past=None
|
123 |
+
) -> tuple:
|
124 |
+
B, C, H, W = images.shape
|
125 |
+
with torch.no_grad():
|
126 |
+
with autocast(enabled=False):
|
127 |
+
codes = self.stage1.get_codes(images).detach()
|
128 |
+
pos_enc_tokens = get_positional_encoding(texts, mode='1d')
|
129 |
+
codes = codes.clone().detach()
|
130 |
+
pos_enc_code = get_positional_encoding(codes, mode='1d')
|
131 |
+
# codes = codes.unsqueeze(-1)
|
132 |
+
# pos_enc_code = pos_enc_code.unsqueeze(-1)
|
133 |
+
logits_img, logits_txt = self.stage2(codes, texts, pos_enc_code, pos_enc_tokens, past)
|
134 |
+
return logits_img, logits_txt, codes
|
135 |
+
|
136 |
+
def training_step(self, batch, batch_idx):
|
137 |
+
images, texts = batch
|
138 |
+
logits_img, logits_txt, codes = self(images, texts)
|
139 |
+
|
140 |
+
loss_img = F.cross_entropy(logits_img.view(-1, logits_img.shape[-1]), codes.view(-1))
|
141 |
+
loss_txt = F.cross_entropy(logits_txt.view(-1, logits_txt.shape[-1]), texts[:, 1:].reshape(-1))
|
142 |
+
self.log("train/loss_img", loss_img, on_step=True, on_epoch=True, prog_bar=False, logger=True)
|
143 |
+
self.log("train/loss_txt", loss_txt, on_step=True, on_epoch=True, prog_bar=False, logger=True)
|
144 |
+
return loss_img + loss_txt
|
145 |
+
|
146 |
+
def validation_step(self, batch, batch_idx):
|
147 |
+
images, texts = batch
|
148 |
+
logits_img, logits_txt, codes = self(images, texts)
|
149 |
+
# print(logits_img.shape, logits_txt.shape, codes.shape, texts.shape)
|
150 |
+
|
151 |
+
loss_img = F.cross_entropy(logits_img.view(-1, logits_img.shape[-1]), codes.view(-1))
|
152 |
+
loss_txt = F.cross_entropy(logits_txt.view(-1, logits_txt.shape[-1]), texts[:, 1:].reshape(-1))
|
153 |
+
self.log("val/loss_img", loss_img, on_step=False, on_epoch=True, prog_bar=False, logger=True)
|
154 |
+
self.log("val/loss_txt", loss_txt, on_step=False, on_epoch=True, prog_bar=False, logger=True)
|
155 |
+
return loss_img + loss_txt
|
156 |
+
|
157 |
+
def configure_optimizers(self):
|
158 |
+
assert self.config.optimizer.opt_type == 'adamW'
|
159 |
+
# assert self.config.optimizer.sched_type == 'cosine'
|
160 |
+
|
161 |
+
opt = torch.optim.AdamW(self.parameters(),
|
162 |
+
lr=self.config.optimizer.learning_rate,
|
163 |
+
betas=self.config.optimizer.betas,
|
164 |
+
weight_decay=self.config.optimizer.weight_decay)
|
165 |
+
# sched = CosineAnnealingLR(opt,
|
166 |
+
# T_max=self.config.optimizer.max_steps,
|
167 |
+
# eta_min=self.config.optimizer.min_lr)
|
168 |
+
|
169 |
+
def lr_lambda(current_step: int):
|
170 |
+
return max(
|
171 |
+
0.0, float(self.config.optimizer.max_steps - current_step) / float(max(1, self.config.optimizer.max_steps))
|
172 |
+
)
|
173 |
+
|
174 |
+
sched = LambdaLR(opt, lr_lambda)
|
175 |
+
sched = {
|
176 |
+
'scheduler': sched,
|
177 |
+
'name': 'linear'
|
178 |
+
}
|
179 |
+
return [opt], [sched]
|
180 |
+
|
181 |
+
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
|
182 |
+
on_tpu=False, using_native_amp=False, using_lbfgs=False):
|
183 |
+
optimizer.step(closure=optimizer_closure)
|
184 |
+
self.lr_schedulers().step()
|
185 |
+
self.log("lr", self.lr_schedulers().get_last_lr()[0], on_step=True, on_epoch=False, prog_bar=True, logger=True)
|
186 |
+
|
187 |
+
def on_epoch_start(self):
|
188 |
+
self.stage1.eval()
|
189 |
+
|
190 |
+
|
191 |
+
class ImageGPT(pl.LightningModule):
|
192 |
+
def __init__(self,
|
193 |
+
config: OmegaConf) -> None:
|
194 |
+
super().__init__()
|
195 |
+
self.stage1 = VQGAN(n_embed=config.stage1.n_embed,
|
196 |
+
embed_dim=config.stage1.embed_dim,
|
197 |
+
hparams=config.stage1.hparams)
|
198 |
+
self.stage2 = iGPT(vocab_size_img=config.stage2.vocab_size_img,
|
199 |
+
use_cls_cond=config.stage2.use_cls_cond,
|
200 |
+
hparams=config.stage2.hparams)
|
201 |
+
self.config = config
|
202 |
+
self.use_cls_cond = config.stage2.use_cls_cond
|
203 |
+
|
204 |
+
# make the parameters in stage 1 not trainable
|
205 |
+
self.stage1.eval()
|
206 |
+
for p in self.stage1.parameters():
|
207 |
+
p.requires_grad = False
|
208 |
+
|
209 |
+
@classmethod
|
210 |
+
def from_pretrained(cls,
|
211 |
+
path_upstream: str,
|
212 |
+
path_downstream: str) -> Tuple[nn.Module, OmegaConf]:
|
213 |
+
config_base = get_base_config(use_default=False)
|
214 |
+
config_down = OmegaConf.load(path_downstream)
|
215 |
+
config_down = OmegaConf.merge(config_base, config_down)
|
216 |
+
|
217 |
+
model = cls(config_down)
|
218 |
+
model.stage1.from_ckpt(os.path.join(path_upstream, 'stage1_last.ckpt'), strict=True)
|
219 |
+
model.stage2.from_ckpt(os.path.join(path_upstream, 'stage2_last.ckpt'), strict=False)
|
220 |
+
return model, config_down
|
221 |
+
|
222 |
+
def sample(self,
|
223 |
+
cls_idx: Optional[int] = None,
|
224 |
+
top_k: int = 256,
|
225 |
+
top_p: Optional[float] = None,
|
226 |
+
softmax_temperature: float = 1.0,
|
227 |
+
num_candidates: int = 16,
|
228 |
+
device: str = 'cuda:0',
|
229 |
+
use_fp16: bool = True,
|
230 |
+
is_tqdm: bool = True) -> torch.FloatTensor:
|
231 |
+
self.stage1.eval()
|
232 |
+
self.stage2.eval()
|
233 |
+
|
234 |
+
if cls_idx is None:
|
235 |
+
sos = self.stage2.sos.repeat(num_candidates, 1, 1)
|
236 |
+
else:
|
237 |
+
sos = torch.LongTensor([cls_idx]).to(device=device)
|
238 |
+
sos = sos.repeat(num_candidates)
|
239 |
+
sos = self.stage2.sos(sos).unsqueeze(1)
|
240 |
+
|
241 |
+
codes = sampling_igpt(self.stage2,
|
242 |
+
sos=sos,
|
243 |
+
top_k=top_k,
|
244 |
+
top_p=top_p,
|
245 |
+
softmax_temperature=softmax_temperature,
|
246 |
+
use_fp16=use_fp16,
|
247 |
+
is_tqdm=is_tqdm)
|
248 |
+
codes = codes.view(num_candidates, 16, 16) # [B, 16, 16]
|
249 |
+
pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
|
250 |
+
return pixels
|
251 |
+
|
252 |
+
def forward(self,
|
253 |
+
images: torch.FloatTensor,
|
254 |
+
labels: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
|
255 |
+
B, C, H, W = images.shape
|
256 |
+
with torch.no_grad():
|
257 |
+
with autocast(enabled=False):
|
258 |
+
codes = self.stage1.get_codes(images).detach()
|
259 |
+
logits = self.stage2(codes, labels)
|
260 |
+
return logits, codes
|
261 |
+
|
262 |
+
def training_step(self, batch, batch_idx):
|
263 |
+
images, labels = batch
|
264 |
+
logits, codes = self(images, labels=labels if self.use_cls_cond else None)
|
265 |
+
loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), codes.view(-1))
|
266 |
+
self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
|
267 |
+
return loss
|
268 |
+
|
269 |
+
def validation_step(self, batch, batch_idx):
|
270 |
+
images, labels = batch
|
271 |
+
logits, codes = self(images, labels=labels if self.use_cls_cond else None)
|
272 |
+
loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), codes.view(-1))
|
273 |
+
self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
|
274 |
+
return loss
|
275 |
+
|
276 |
+
def configure_optimizers(self):
|
277 |
+
assert self.config.optimizer.opt_type == 'adamW'
|
278 |
+
assert self.config.optimizer.sched_type == 'cosine'
|
279 |
+
|
280 |
+
opt = torch.optim.AdamW(self.parameters(),
|
281 |
+
lr=self.config.optimizer.base_lr,
|
282 |
+
betas=self.config.optimizer.betas,
|
283 |
+
weight_decay=self.config.optimizer.weight_decay)
|
284 |
+
sched = CosineAnnealingLR(opt,
|
285 |
+
T_max=self.config.optimizer.max_steps,
|
286 |
+
eta_min=self.config.optimizer.min_lr)
|
287 |
+
sched = {
|
288 |
+
'scheduler': sched,
|
289 |
+
'name': 'cosine'
|
290 |
+
}
|
291 |
+
return [opt], [sched]
|
292 |
+
|
293 |
+
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
|
294 |
+
on_tpu=False, using_native_amp=False, using_lbfgs=False):
|
295 |
+
optimizer.step(closure=optimizer_closure)
|
296 |
+
self.lr_schedulers().step()
|
297 |
+
self.log("lr", self.lr_schedulers().get_last_lr()[0], on_step=True, on_epoch=False, prog_bar=True, logger=True)
|
298 |
+
|
299 |
+
def on_epoch_start(self):
|
300 |
+
self.stage1.eval()
|
301 |
+
|
302 |
+
|
303 |
+
class PromptDalle(Dalle):
|
304 |
+
"""Classification Head for transformer encoders"""
|
305 |
+
def __init__(self, config):
|
306 |
+
super().__init__(config)
|
307 |
+
print('Initializing the PromptTuning model')
|
308 |
+
|
309 |
+
self.config = config
|
310 |
+
self.n_embd = config.stage2.hparams.embed_dim
|
311 |
+
self.preseqlen = config.prompt.preseqlen
|
312 |
+
self.prefix_dropout = config.prompt.prefix_dropout
|
313 |
+
|
314 |
+
# DIFFERENT PARAMETRIZATION:
|
315 |
+
|
316 |
+
print('[Full prompt-tuning Setting :) ]')
|
317 |
+
self.input_tokens = torch.arange(self.preseqlen).long()
|
318 |
+
self.wte = nn.Embedding(self.preseqlen, self.n_embd)
|
319 |
+
self.control_trans = nn.Sequential(
|
320 |
+
nn.Linear(self.n_embd, self.n_embd),
|
321 |
+
nn.Tanh(),
|
322 |
+
nn.Linear(self.n_embd, self.n_embd))
|
323 |
+
self.get_prompt = self.get_prompt_p5
|
324 |
+
self.dropout = nn.Dropout(self.prefix_dropout)
|
325 |
+
|
326 |
+
###### NUM PARAMS #########
|
327 |
+
total_param = 0
|
328 |
+
for name, param in self.named_parameters():
|
329 |
+
# print(param.shape)
|
330 |
+
total_param += param.numel()
|
331 |
+
print('Total parameters is {}'.format(total_param))
|
332 |
+
|
333 |
+
|
334 |
+
@classmethod
|
335 |
+
def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
|
336 |
+
|
337 |
+
# if not args.model_name_or_path:
|
338 |
+
# args.model_name_or_path = args.prefix_model_name_or_path
|
339 |
+
|
340 |
+
path = args.prefix_model_name_or_path
|
341 |
+
path = _MODELS[path] if path in _MODELS else path
|
342 |
+
path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))
|
343 |
+
|
344 |
+
config_base = get_base_config('prompt_tuning')
|
345 |
+
config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
|
346 |
+
config_update = OmegaConf.merge(config_base, config_new)
|
347 |
+
|
348 |
+
for key, val in vars(args).items():
|
349 |
+
if key in config_update.prompt.keys():
|
350 |
+
OmegaConf.update(config_update, "prompt.%s" % key, val, merge=False)
|
351 |
+
if key in config_update.optimizer.keys():
|
352 |
+
OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
|
353 |
+
if key in config_update.experiment.keys():
|
354 |
+
OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)
|
355 |
+
|
356 |
+
model = cls(config_update)
|
357 |
+
model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
|
358 |
+
context_length=model.config_dataset.context_length,
|
359 |
+
lowercase=True,
|
360 |
+
dropout=None)
|
361 |
+
|
362 |
+
if args.model_name_or_path:
|
363 |
+
print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
|
364 |
+
# model.from_ckpt(args.model_name_or_path)
|
365 |
+
try:
|
366 |
+
model.load_state_dict(torch.load(args.model_name_or_path)['state_dict'])
|
367 |
+
except KeyError:
|
368 |
+
model.load_state_dict(torch.load(args.model_name_or_path)['model_state_dict'])
|
369 |
+
|
370 |
+
else:
|
371 |
+
print("Loading models from checkpoint %s" % path)
|
372 |
+
model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
|
373 |
+
model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
|
374 |
+
|
375 |
+
return model, config_update
|
376 |
+
|
377 |
+
def get_prompt_p5(self, bsz=None, eval=False):
|
378 |
+
input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
|
379 |
+
temp_control = self.wte(input_tokens)
|
380 |
+
past_key_values = self.control_trans(temp_control) #bsz, seqlen, layer*emb
|
381 |
+
if not eval:
|
382 |
+
past_key_values = self.dropout(past_key_values)
|
383 |
+
return past_key_values
|
384 |
+
|
385 |
+
def forward(self,
|
386 |
+
images: torch.FloatTensor,
|
387 |
+
texts: Optional[torch.LongTensor],
|
388 |
+
**kwargs,
|
389 |
+
):
|
390 |
+
|
391 |
+
#{"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 'src':src}
|
392 |
+
|
393 |
+
B, C, H, W = images.shape
|
394 |
+
prompt = self.get_prompt(bsz=B)
|
395 |
+
pos_enc_prompt = get_positional_encoding(self.input_tokens.unsqueeze(0).expand(B, -1).to(self.device), mode='1d')
|
396 |
+
|
397 |
+
# if self.mode_para == 2 and src_attn is not None and tgt_attn is not None:
|
398 |
+
# attention_mask = torch.cat([src_attn, tgt_attn], dim=1)
|
399 |
+
|
400 |
+
|
401 |
+
with torch.no_grad():
|
402 |
+
with autocast(enabled=False):
|
403 |
+
codes = self.stage1.get_codes(images).detach()
|
404 |
+
|
405 |
+
pos_enc_tokens = get_positional_encoding(texts, mode='1d')
|
406 |
+
codes = codes.clone().detach()
|
407 |
+
pos_enc_code = get_positional_encoding(codes, mode='1d')
|
408 |
+
# codes = codes.unsqueeze(-1)
|
409 |
+
# pos_enc_code = pos_enc_code.unsqueeze(-1)
|
410 |
+
# print(images.shape, codes.shape, texts.shape)
|
411 |
+
logits_img, logits_txt = self.stage2(codes, texts, pos_enc_code, pos_enc_tokens, prompt=prompt, pos_prompt=pos_enc_prompt)
|
412 |
+
return logits_img, logits_txt, codes
|
413 |
+
|
414 |
+
|
415 |
+
@torch.no_grad()
|
416 |
+
def sampling(self,
|
417 |
+
tokens: torch.LongTensor,
|
418 |
+
prompt: torch.FloatTensor,
|
419 |
+
top_k: int = 256,
|
420 |
+
top_p: Optional[float] = None,
|
421 |
+
softmax_temperature: float = 1.0,
|
422 |
+
num_candidates: int = 96,
|
423 |
+
device: str = 'cuda:0',
|
424 |
+
use_fp16: bool = True,
|
425 |
+
labels = None) -> torch.FloatTensor:
|
426 |
+
self.stage1.eval()
|
427 |
+
self.stage2.eval()
|
428 |
+
|
429 |
+
# tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
|
430 |
+
|
431 |
+
tokens = tokens.to(device)
|
432 |
+
pos_enc_prompt = get_positional_encoding(self.input_tokens.unsqueeze(0).expand(num_candidates, -1).to(self.device), mode='1d')
|
433 |
+
|
434 |
+
codes = sampling(self.stage2,
|
435 |
+
tokens,
|
436 |
+
top_k=top_k,
|
437 |
+
top_p=top_p,
|
438 |
+
softmax_temperature=softmax_temperature,
|
439 |
+
use_fp16=use_fp16,
|
440 |
+
prompt=prompt,
|
441 |
+
pos_prompt=pos_enc_prompt)
|
442 |
+
|
443 |
+
codes = codes.view(-1, 16, 16) # [B, 16, 16]
|
444 |
+
pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
|
445 |
+
return pixels
|
446 |
+
|
447 |
+
|
448 |
+
@torch.no_grad()
|
449 |
+
def predict_step(self, batch, batch_idx, return_images=False):
|
450 |
+
orig_images, texts = batch
|
451 |
+
|
452 |
+
# extra for checks
|
453 |
+
logits_img, logits_txt, codes = self(orig_images, texts)
|
454 |
+
pred = torch.argmax(logits_img.view(-1, logits_img.shape[-1]), dim=-1)
|
455 |
+
bs = orig_images.shape[0]
|
456 |
+
pred = pred.view(bs, 16, 16) # [B, 16, 16]
|
457 |
+
pixels = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy() # [B, 256, 256]
|
458 |
+
pixels = np.transpose(pixels, (0, 2, 3, 1))
|
459 |
+
|
460 |
+
# print(texts.shape, orig_images.shape)
|
461 |
+
prompt = self.get_prompt(bsz=5, eval=True)
|
462 |
+
|
463 |
+
images = []
|
464 |
+
for i, t in enumerate(texts):
|
465 |
+
pixels = self.sampling(t, prompt, top_k=16, num_candidates=5, labels=codes[i]).cpu().numpy()
|
466 |
+
pixels = np.transpose(pixels, (0, 2, 3, 1))
|
467 |
+
images.append(pixels)
|
468 |
+
|
469 |
+
if return_images:
|
470 |
+
return images
|
471 |
+
else:
|
472 |
+
save_image(orig_images, pixels, './out/images/pororo_prompt', batch_idx+10)
|
473 |
+
save_image(orig_images, images, './out/images/pororo_prompt', batch_idx)
|
474 |
+
|
475 |
+
|
476 |
+
class PrefixTuningDalle(Dalle):
|
477 |
+
"""Classification Head for transformer encoders"""
|
478 |
+
def __init__(self, config):
|
479 |
+
super().__init__(config)
|
480 |
+
print('Initializing the PrefixTuning model')
|
481 |
+
|
482 |
+
self.config = config
|
483 |
+
|
484 |
+
self.match_n_layer = config.stage2.hparams.n_layers
|
485 |
+
self.match_n_head = config.stage2.hparams.n_heads
|
486 |
+
self.match_n_embd = config.stage2.hparams.embed_dim // config.stage2.hparams.n_heads
|
487 |
+
self.n_embd = config.stage2.hparams.embed_dim
|
488 |
+
|
489 |
+
self.optim_prefix = config.prefix.optim_prefix
|
490 |
+
self.preseqlen = config.prefix.preseqlen
|
491 |
+
self.prefix_dropout = config.prefix.prefix_dropout
|
492 |
+
self.init_random = config.prefix.init_random
|
493 |
+
self.hidden_dim_prefix = config.prefix.hidden_dim_prefix
|
494 |
+
|
495 |
+
self.lowdata_token = config.prefix.lowdata_token
|
496 |
+
self.init_shallow = config.prefix.init_shallow
|
497 |
+
self.init_shallow_word = config.prefix.init_shallow_word
|
498 |
+
self.mode_para = 0
|
499 |
+
|
500 |
+
print('PrefixTuning')
|
501 |
+
print('preseqlen is {}, optimizing the prefix directly'.format(self.preseqlen))
|
502 |
+
|
503 |
+
# DIFFERENT PARAMETRIZATION:
|
504 |
+
|
505 |
+
print('[Full prefix-tuning Setting :) ]')
|
506 |
+
self.input_tokens = torch.arange(self.preseqlen).long()
|
507 |
+
self.wte = nn.Embedding(self.preseqlen, self.n_embd)
|
508 |
+
self.control_trans = nn.Sequential(
|
509 |
+
nn.Linear(self.n_embd, self.hidden_dim_prefix),
|
510 |
+
nn.Tanh(),
|
511 |
+
nn.Linear(self.hidden_dim_prefix, self.match_n_layer * 2 * self.n_embd))
|
512 |
+
self.get_prompt = self.get_prompt_p5
|
513 |
+
self.dropout = nn.Dropout(self.prefix_dropout)
|
514 |
+
|
515 |
+
###### NUM PARAMS #########
|
516 |
+
total_param = 0
|
517 |
+
for name, param in self.named_parameters():
|
518 |
+
# print(param.shape)
|
519 |
+
total_param += param.numel()
|
520 |
+
print('Total parameters is {}'.format(total_param))
|
521 |
+
|
522 |
+
|
523 |
+
@classmethod
|
524 |
+
def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
|
525 |
+
|
526 |
+
# if not args.model_name_or_path:
|
527 |
+
# args.model_name_or_path = args.prefix_model_name_or_path
|
528 |
+
|
529 |
+
path = args.prefix_model_name_or_path
|
530 |
+
path = _MODELS[path] if path in _MODELS else path
|
531 |
+
path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))
|
532 |
+
|
533 |
+
config_base = get_base_config('prefixtuning')
|
534 |
+
config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
|
535 |
+
config_update = OmegaConf.merge(config_base, config_new)
|
536 |
+
|
537 |
+
for key, val in vars(args).items():
|
538 |
+
if key in config_update.prefix.keys():
|
539 |
+
OmegaConf.update(config_update, "prefix.%s" % key, val, merge=False)
|
540 |
+
if key in config_update.optimizer.keys():
|
541 |
+
OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
|
542 |
+
if key in config_update.experiment.keys():
|
543 |
+
OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)
|
544 |
+
|
545 |
+
model = cls(config_update)
|
546 |
+
model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
|
547 |
+
context_length=model.config_dataset.context_length,
|
548 |
+
lowercase=True,
|
549 |
+
dropout=None)
|
550 |
+
|
551 |
+
if args.model_name_or_path:
|
552 |
+
print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
|
553 |
+
# model.from_ckpt(args.model_name_or_path)
|
554 |
+
try:
|
555 |
+
model.load_state_dict(torch.load(args.model_name_or_path)['state_dict'])
|
556 |
+
except KeyError:
|
557 |
+
model.load_state_dict(torch.load(args.model_name_or_path)['model_state_dict'])
|
558 |
+
|
559 |
+
else:
|
560 |
+
print("Loading models from checkpoint %s" % path)
|
561 |
+
model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
|
562 |
+
model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
|
563 |
+
|
564 |
+
return model, config_update
|
565 |
+
|
566 |
+
def get_prompt_p5(self, bsz=None, eval=False):
|
567 |
+
input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
|
568 |
+
temp_control = self.wte(input_tokens)
|
569 |
+
past_key_values = self.control_trans(temp_control) #bsz, seqlen, layer*emb
|
570 |
+
bsz, seqlen, _ = past_key_values.shape
|
571 |
+
past_key_values = past_key_values.view(bsz, seqlen, self.match_n_layer * 2, self.match_n_head,
|
572 |
+
self.match_n_embd)
|
573 |
+
if not eval:
|
574 |
+
past_key_values = self.dropout(past_key_values)
|
575 |
+
# past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
|
576 |
+
past_key_values = past_key_values.permute([2, 0, 3, 1, 4])
|
577 |
+
# print(past_key_values.shape)
|
578 |
+
return past_key_values.split(2)
|
579 |
+
|
580 |
+
def forward(self,
|
581 |
+
images: torch.FloatTensor,
|
582 |
+
texts: Optional[torch.LongTensor],
|
583 |
+
**kwargs,
|
584 |
+
):
|
585 |
+
|
586 |
+
#{"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 'src':src}
|
587 |
+
|
588 |
+
B, C, H, W = images.shape
|
589 |
+
|
590 |
+
if self.mode_para == 2:
|
591 |
+
past_key_values_prompt = self.get_prompt(bsz=B)
|
592 |
+
else:
|
593 |
+
past_key_values_prompt = self.get_prompt(bsz=B)
|
594 |
+
|
595 |
+
# if self.mode_para == 2 and src_attn is not None and tgt_attn is not None:
|
596 |
+
# attention_mask = torch.cat([src_attn, tgt_attn], dim=1)
|
597 |
+
|
598 |
+
|
599 |
+
with torch.no_grad():
|
600 |
+
with autocast(enabled=False):
|
601 |
+
codes = self.stage1.get_codes(images).detach()
|
602 |
+
|
603 |
+
pos_enc_tokens = get_positional_encoding(texts, mode='1d')
|
604 |
+
codes = codes.clone().detach()
|
605 |
+
pos_enc_code = get_positional_encoding(codes, mode='1d')
|
606 |
+
# codes = codes.unsqueeze(-1)
|
607 |
+
# pos_enc_code = pos_enc_code.unsqueeze(-1)
|
608 |
+
# print(images.shape, codes.shape, texts.shape)
|
609 |
+
logits_img, logits_txt = self.stage2(codes, texts, pos_enc_code, pos_enc_tokens, past_key_values_prompt)
|
610 |
+
return logits_img, logits_txt, codes
|
611 |
+
|
612 |
+
@torch.no_grad()
|
613 |
+
def sampling(self,
|
614 |
+
tokens: torch.LongTensor,
|
615 |
+
past: torch.FloatTensor,
|
616 |
+
top_k: int = 256,
|
617 |
+
top_p: Optional[float] = None,
|
618 |
+
softmax_temperature: float = 1.0,
|
619 |
+
num_candidates: int = 96,
|
620 |
+
device: str = 'cuda:0',
|
621 |
+
use_fp16: bool = True,
|
622 |
+
labels = None) -> torch.FloatTensor:
|
623 |
+
self.stage1.eval()
|
624 |
+
self.stage2.eval()
|
625 |
+
|
626 |
+
if len(past.shape) == 6:
|
627 |
+
n_layers, temp, bs, n_heads, seq_len, n_dim = past.shape
|
628 |
+
past = past.view(n_layers, temp, bs*n_heads, seq_len, n_dim)
|
629 |
+
|
630 |
+
tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
|
631 |
+
|
632 |
+
# Check if the encoding works as intended
|
633 |
+
# print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
|
634 |
+
|
635 |
+
tokens = tokens.to(device)
|
636 |
+
codes = sampling_prefix(self.stage2,
|
637 |
+
tokens,
|
638 |
+
past,
|
639 |
+
top_k=top_k,
|
640 |
+
top_p=top_p,
|
641 |
+
softmax_temperature=softmax_temperature,
|
642 |
+
use_fp16=use_fp16,
|
643 |
+
labels = None if labels is None else labels.view(-1))
|
644 |
+
|
645 |
+
# codes = sampling(self.stage2,
|
646 |
+
# tokens,
|
647 |
+
# top_k=top_k,
|
648 |
+
# top_p=top_p,
|
649 |
+
# softmax_temperature=softmax_temperature,
|
650 |
+
# use_fp16=use_fp16)
|
651 |
+
|
652 |
+
codes = codes.view(num_candidates, 16, 16) # [B, 16, 16]
|
653 |
+
pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
|
654 |
+
return pixels
|
655 |
+
|
656 |
+
def training_step(self, batch, batch_idx):
|
657 |
+
images, texts = batch
|
658 |
+
logits_img, logits_txt, codes = self(images, texts)
|
659 |
+
|
660 |
+
loss_img = F.cross_entropy(logits_img.view(-1, logits_img.shape[-1]), codes.view(-1))
|
661 |
+
loss_txt = F.cross_entropy(logits_txt.view(-1, logits_txt.shape[-1]), texts[:, 1:].reshape(-1))
|
662 |
+
self.log("train/loss_img", loss_img, on_step=True, on_epoch=True, prog_bar=False, logger=True)
|
663 |
+
self.log("train/loss_txt", loss_txt, on_step=True, on_epoch=True, prog_bar=False, logger=True)
|
664 |
+
return loss_img + loss_txt
|
665 |
+
|
666 |
+
def validation_step(self, batch, batch_idx):
|
667 |
+
images, texts = batch
|
668 |
+
logits_img, logits_txt, codes = self(images, texts)
|
669 |
+
# print(logits_img.shape, logits_txt.shape, codes.shape, texts.shape)
|
670 |
+
|
671 |
+
loss_img = F.cross_entropy(logits_img.view(-1, logits_img.shape[-1]), codes.view(-1))
|
672 |
+
loss_txt = F.cross_entropy(logits_txt.view(-1, logits_txt.shape[-1]), texts[:, 1:].reshape(-1))
|
673 |
+
self.log("val/loss_img", loss_img, on_step=False, on_epoch=True, prog_bar=False, logger=True)
|
674 |
+
self.log("val/loss_txt", loss_txt, on_step=False, on_epoch=True, prog_bar=False, logger=True)
|
675 |
+
return loss_img + loss_txt
|
676 |
+
|
677 |
+
@torch.no_grad()
|
678 |
+
def predict_step(self, batch, batch_idx, return_images=False):
|
679 |
+
orig_images, texts = batch
|
680 |
+
|
681 |
+
# extra for checks
|
682 |
+
logits_img, logits_txt, codes = self(orig_images, texts)
|
683 |
+
pred = torch.argmax(logits_img.view(-1, logits_img.shape[-1]), dim=-1)
|
684 |
+
bs = orig_images.shape[0]
|
685 |
+
pred = pred.view(bs, 16, 16) # [B, 16, 16]
|
686 |
+
pixels = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy() # [B, 256, 256]
|
687 |
+
pixels = np.transpose(pixels, (0, 2, 3, 1))
|
688 |
+
|
689 |
+
|
690 |
+
# print(texts.shape, orig_images.shape)
|
691 |
+
# concatenate the list of prompts (split by n_head) for better downstream processing
|
692 |
+
past_key_values_prompt = self.get_prompt(bsz=5, eval=True)
|
693 |
+
# print(past_key_values_prompt[0].shape, past_key_values_prompt[1].shape, len(past_key_values_prompt))
|
694 |
+
past_key_values_prompt = torch.cat([x.unsqueeze(0) for x in past_key_values_prompt], dim=0)
|
695 |
+
n_layers, temp, bs, n_heads, seq_len, n_dim = past_key_values_prompt.shape
|
696 |
+
past_key_values_prompt = past_key_values_prompt.view(n_layers, temp, bs*n_heads, seq_len, n_dim)
|
697 |
+
# print(past_key_values_prompt.shape)
|
698 |
+
images = []
|
699 |
+
for i, t in enumerate(texts):
|
700 |
+
pixels = self.sampling(t, past_key_values_prompt, top_k=16, num_candidates=5, labels=codes[i]).cpu().numpy()
|
701 |
+
pixels = np.transpose(pixels, (0, 2, 3, 1))
|
702 |
+
images.append(pixels)
|
703 |
+
# images.extend([p for p in pixels])
|
704 |
+
# print([i.shape for i in images])
|
705 |
+
|
706 |
+
|
707 |
+
if return_images:
|
708 |
+
return images
|
709 |
+
else:
|
710 |
+
save_image(orig_images, pixels, './out/images/pororo_prefix', batch_idx+10)
|
711 |
+
save_image(orig_images, images, './out/images/pororo_prefix', batch_idx)
|
712 |
+
|
713 |
+
|
714 |
+
class ConditionalDalle(Dalle):
|
715 |
+
"""Classification Head for transformer encoders"""
|
716 |
+
def __init__(self, config):
|
717 |
+
super().__init__(config)
|
718 |
+
print('Initializing the Conditional Dalle model')
|
719 |
+
|
720 |
+
self.config = config
|
721 |
+
|
722 |
+
print('Setting up Cross-attention Layers')
|
723 |
+
self.init_cross_attention(list(range(2,42,3)), config.stage2.hparams)
|
724 |
+
|
725 |
+
###### NUM PARAMS #########
|
726 |
+
total_param = 0
|
727 |
+
for name, param in self.named_parameters():
|
728 |
+
# print(param.shape)
|
729 |
+
total_param += param.numel()
|
730 |
+
print('Total parameters is {}'.format(total_param))
|
731 |
+
|
732 |
+
@classmethod
|
733 |
+
def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
|
734 |
+
|
735 |
+
# if not args.model_name_or_path:
|
736 |
+
# args.model_name_or_path = args.prefix_model_name_or_path
|
737 |
+
|
738 |
+
path = args.model_name_or_path
|
739 |
+
config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
|
740 |
+
if args.do_train:
|
741 |
+
config_base = get_base_config('finetuning')
|
742 |
+
config_update = OmegaConf.merge(config_base, config_new)
|
743 |
+
for key, val in vars(args).items():
|
744 |
+
if key in config_update.optimizer.keys():
|
745 |
+
OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
|
746 |
+
if key in config_update.experiment.keys():
|
747 |
+
OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)
|
748 |
+
else:
|
749 |
+
config_base = get_base_config('default')
|
750 |
+
config_update = OmegaConf.merge(config_base, config_new)
|
751 |
+
|
752 |
+
model = cls(config_update)
|
753 |
+
model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
|
754 |
+
context_length=model.config_dataset.context_length,
|
755 |
+
lowercase=True,
|
756 |
+
dropout=None)
|
757 |
+
print(model.cross_attention_idxs)
|
758 |
+
# print(next(model.cross_attention_layers[0].parameters()).is_cuda)
|
759 |
+
|
760 |
+
if args.dalle_path:
|
761 |
+
print("Loading model from pretrained checkpoint %s" % args.dalle_path)
|
762 |
+
# model.from_ckpt(args.model_name_or_path)
|
763 |
+
model.load_state_dict(torch.load(args.dalle_path)['model_state_dict'])
|
764 |
+
else:
|
765 |
+
print("Loading models from checkpoint %s" % path)
|
766 |
+
model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
|
767 |
+
model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
|
768 |
+
|
769 |
+
return model, config_update
|
770 |
+
|
771 |
+
|
772 |
+
def init_cross_attention(self, cross_attention_layers, hparams):
|
773 |
+
self.cross_attention_idxs = cross_attention_layers
|
774 |
+
self.cross_attention_layers = [CrossAttentionLayer(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt,
|
775 |
+
embed_dim=hparams.embed_dim,
|
776 |
+
n_heads=hparams.n_heads,
|
777 |
+
attn_bias=hparams.attn_bias,
|
778 |
+
resid_pdrop=hparams.resid_pdrop,
|
779 |
+
attn_pdrop=hparams.attn_pdrop) for i in cross_attention_layers]
|
780 |
+
|
781 |
+
|
782 |
+
def forward(self,
|
783 |
+
images: torch.FloatTensor,
|
784 |
+
src_images: Optional[torch.FloatTensor],
|
785 |
+
texts: Optional[torch.LongTensor],
|
786 |
+
**kwargs,
|
787 |
+
):
|
788 |
+
|
789 |
+
#{"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 'src':src}
|
790 |
+
|
791 |
+
# print(images.shape, src_images.shape, texts.shape)
|
792 |
+
with torch.no_grad():
|
793 |
+
with autocast(enabled=False):
|
794 |
+
codes = self.stage1.get_codes(images).detach()
|
795 |
+
src_codes = self.stage1.get_codes(src_images).detach()
|
796 |
+
|
797 |
+
pos_enc_tokens = get_positional_encoding(texts, mode='1d')
|
798 |
+
codes = codes.clone().detach()
|
799 |
+
pos_enc_code = get_positional_encoding(codes, mode='1d')
|
800 |
+
src_codes = src_codes.clone().detach()
|
801 |
+
src_pos_enc_code = get_positional_encoding(src_codes, mode='1d')
|
802 |
+
# codes = codes.unsqueeze(-1)
|
803 |
+
# pos_enc_code = pos_enc_code.unsqueeze(-1)
|
804 |
+
# print(images.shape, codes.shape, texts.shape)
|
805 |
+
logits_img, logits_txt = self.stage2.forward_with_context(codes, texts,
|
806 |
+
pos_enc_code, pos_enc_tokens, src_codes, src_pos_enc_code,
|
807 |
+
self.cross_attention_idxs, self.cross_attention_layers)
|
808 |
+
# print(logits_img.shape, logits_txt.shape, codes.shape, texts.shape)
|
809 |
+
return logits_img, logits_txt, codes
|
810 |
+
|
811 |
+
@torch.no_grad()
|
812 |
+
def sampling(self,
|
813 |
+
prompt: torch.LongTensor,
|
814 |
+
source: torch.FloatTensor,
|
815 |
+
top_k: int = 256,
|
816 |
+
top_p: Optional[float] = None,
|
817 |
+
softmax_temperature: float = 1.0,
|
818 |
+
num_candidates: int = 96,
|
819 |
+
device: str = 'cuda:0',
|
820 |
+
use_fp16: bool = True) -> torch.FloatTensor:
|
821 |
+
self.stage1.eval()
|
822 |
+
self.stage2.eval()
|
823 |
+
|
824 |
+
if type(prompt) == str:
|
825 |
+
tokens = self.tokenizer.encode(prompt)
|
826 |
+
tokens = torch.LongTensor(tokens.ids)
|
827 |
+
else:
|
828 |
+
tokens = prompt
|
829 |
+
|
830 |
+
tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
|
831 |
+
|
832 |
+
# Check if the encoding works as intended
|
833 |
+
# print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
|
834 |
+
|
835 |
+
tokens = tokens.to(device)
|
836 |
+
source = source.to(device)
|
837 |
+
|
838 |
+
with autocast(enabled=False):
|
839 |
+
src_codes = self.stage1.get_codes(source).detach()
|
840 |
+
src_codes = torch.repeat_interleave(src_codes, num_candidates, dim=0)
|
841 |
+
|
842 |
+
codes = sampling_conditional(self.stage2,
|
843 |
+
self.cross_attention_idxs,
|
844 |
+
self.cross_attention_layers,
|
845 |
+
tokens,
|
846 |
+
src_codes,
|
847 |
+
top_k=top_k,
|
848 |
+
top_p=top_p,
|
849 |
+
softmax_temperature=softmax_temperature,
|
850 |
+
use_fp16=use_fp16)
|
851 |
+
codes = codes.view(num_candidates, 16, 16) # [B, 16, 16]
|
852 |
+
pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
|
853 |
+
return pixels
|
854 |
+
|
855 |
+
def training_step(self, batch, batch_idx):
|
856 |
+
images, texts = batch
|
857 |
+
logits_img, logits_txt, codes = self(images, texts)
|
858 |
+
|
859 |
+
loss_img = F.cross_entropy(logits_img.view(-1, logits_img.shape[-1]), codes.view(-1))
|
860 |
+
loss_txt = F.cross_entropy(logits_txt.view(-1, logits_txt.shape[-1]), texts[:, 1:].reshape(-1))
|
861 |
+
self.log("train/loss_img", loss_img, on_step=True, on_epoch=True, prog_bar=False, logger=True)
|
862 |
+
self.log("train/loss_txt", loss_txt, on_step=True, on_epoch=True, prog_bar=False, logger=True)
|
863 |
+
return loss_img + loss_txt
|
864 |
+
|
865 |
+
def validation_step(self, batch, batch_idx):
|
866 |
+
images, texts = batch
|
867 |
+
logits_img, logits_txt, codes = self(images, texts)
|
868 |
+
# print(logits_img.shape, logits_txt.shape, codes.shape, texts.shape)
|
869 |
+
|
870 |
+
loss_img = F.cross_entropy(logits_img.view(-1, logits_img.shape[-1]), codes.view(-1))
|
871 |
+
loss_txt = F.cross_entropy(logits_txt.view(-1, logits_txt.shape[-1]), texts[:, 1:].reshape(-1))
|
872 |
+
self.log("val/loss_img", loss_img, on_step=False, on_epoch=True, prog_bar=False, logger=True)
|
873 |
+
self.log("val/loss_txt", loss_txt, on_step=False, on_epoch=True, prog_bar=False, logger=True)
|
874 |
+
return loss_img + loss_txt
|
875 |
+
|
876 |
+
@torch.no_grad()
|
877 |
+
def predict_step(self, batch, batch_idx):
|
878 |
+
orig_images, texts = batch
|
879 |
+
# concatenate the list of prompts (split by n_head) for better downstream processing
|
880 |
+
past_key_values_prompt = self.get_prompt(bsz=5)
|
881 |
+
past_key_values_prompt = torch.cat([x.unsqueeze(0) for x in past_key_values_prompt], dim=0)
|
882 |
+
images = []
|
883 |
+
for t in texts:
|
884 |
+
pixels = self.sampling(t, past_key_values_prompt, top_k=64, num_candidates=5).cpu().numpy()
|
885 |
+
pixels = np.transpose(pixels, (0, 2, 3, 1))
|
886 |
+
images.append(pixels)
|
887 |
+
# images.extend([p for p in pixels])
|
888 |
+
# print([i.shape for i in images])
|
889 |
+
|
890 |
+
save_image(orig_images, images, './out/images/', batch_idx)
|
891 |
+
|
892 |
+
|
893 |
+
class PromptConditionalDalle(Dalle):
|
894 |
+
"""Classification Head for transformer encoders"""
|
895 |
+
def __init__(self, config):
|
896 |
+
super().__init__(config)
|
897 |
+
print('Initializing the Conditional Dalle model')
|
898 |
+
|
899 |
+
self.config = config
|
900 |
+
|
901 |
+
print('Setting up Cross-attention Layers')
|
902 |
+
self.init_cross_attention(list(range(2,42,3)), config.stage2.hparams)
|
903 |
+
|
904 |
+
self.n_embd = config.stage2.hparams.embed_dim
|
905 |
+
self.preseqlen = config.story.preseqlen
|
906 |
+
self.prefix_dropout = config.story.prefix_dropout
|
907 |
+
|
908 |
+
# DIFFERENT PARAMETRIZATION:
|
909 |
+
|
910 |
+
print('[Full prompt-tuning Setting :) ]')
|
911 |
+
self.input_tokens = torch.arange(self.preseqlen).long()
|
912 |
+
self.wte = nn.Embedding(self.preseqlen, self.n_embd)
|
913 |
+
self.control_trans = nn.Sequential(
|
914 |
+
nn.Linear(self.n_embd, self.n_embd),
|
915 |
+
nn.Tanh(),
|
916 |
+
nn.Linear(self.n_embd, self.n_embd))
|
917 |
+
self.get_prompt = self.get_prompt_p5
|
918 |
+
self.dropout = nn.Dropout(self.prefix_dropout)
|
919 |
+
|
920 |
+
###### NUM PARAMS #########
|
921 |
+
total_param = 0
|
922 |
+
for name, param in self.named_parameters():
|
923 |
+
# print(param.shape)
|
924 |
+
total_param += param.numel()
|
925 |
+
print('Total parameters is {}'.format(total_param))
|
926 |
+
|
927 |
+
@classmethod
|
928 |
+
def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
|
929 |
+
|
930 |
+
# if not args.model_name_or_path:
|
931 |
+
# args.model_name_or_path = args.prefix_model_name_or_path
|
932 |
+
|
933 |
+
path = args.prefix_model_name_or_path
|
934 |
+
path = _MODELS[path] if path in _MODELS else path
|
935 |
+
path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))
|
936 |
+
|
937 |
+
config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
|
938 |
+
if args.do_train:
|
939 |
+
config_base = get_base_config('story')
|
940 |
+
config_update = OmegaConf.merge(config_base, config_new)
|
941 |
+
for key, val in vars(args).items():
|
942 |
+
if key in config_update.story.keys():
|
943 |
+
OmegaConf.update(config_update, "story.%s" % key, val, merge=False)
|
944 |
+
if key in config_update.optimizer.keys():
|
945 |
+
OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
|
946 |
+
if key in config_update.experiment.keys():
|
947 |
+
OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)
|
948 |
+
else:
|
949 |
+
config_base = get_base_config('default')
|
950 |
+
config_update = OmegaConf.merge(config_base, config_new)
|
951 |
+
|
952 |
+
model = cls(config_update)
|
953 |
+
model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
|
954 |
+
context_length=model.config_dataset.context_length,
|
955 |
+
lowercase=True,
|
956 |
+
dropout=None)
|
957 |
+
print(model.cross_attention_idxs)
|
958 |
+
# print(next(model.cross_attention_layers[0].parameters()).is_cuda)
|
959 |
+
|
960 |
+
if args.model_name_or_path:
|
961 |
+
print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
|
962 |
+
# model.from_ckpt(args.model_name_or_path)
|
963 |
+
try:
|
964 |
+
model.load_state_dict(torch.load(args.model_name_or_path)['state_dict'])
|
965 |
+
except KeyError:
|
966 |
+
model.load_state_dict(torch.load(args.model_name_or_path)['model_state_dict'])
|
967 |
+
|
968 |
+
else:
|
969 |
+
print("Loading models from checkpoint %s" % path)
|
970 |
+
model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
|
971 |
+
model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
|
972 |
+
|
973 |
+
return model, config_update
|
974 |
+
|
975 |
+
|
976 |
+
def init_cross_attention(self, cross_attention_layers, hparams):
|
977 |
+
self.cross_attention_idxs = cross_attention_layers
|
978 |
+
self.cross_attention_layers = [CrossAttentionLayer(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt,
|
979 |
+
embed_dim=hparams.embed_dim,
|
980 |
+
n_heads=hparams.n_heads,
|
981 |
+
attn_bias=hparams.attn_bias,
|
982 |
+
resid_pdrop=hparams.resid_pdrop,
|
983 |
+
attn_pdrop=hparams.attn_pdrop) for i in cross_attention_layers]
|
984 |
+
|
985 |
+
def get_prompt_p5(self, bsz=None, eval=False):
|
986 |
+
input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
|
987 |
+
temp_control = self.wte(input_tokens)
|
988 |
+
past_key_values = self.control_trans(temp_control) #bsz, seqlen, layer*emb
|
989 |
+
if not eval:
|
990 |
+
past_key_values = self.dropout(past_key_values)
|
991 |
+
return past_key_values
|
992 |
+
|
993 |
+
def forward(self,
|
994 |
+
images: torch.FloatTensor,
|
995 |
+
src_images: Optional[torch.FloatTensor],
|
996 |
+
texts: Optional[torch.LongTensor],
|
997 |
+
**kwargs,
|
998 |
+
):
|
999 |
+
|
1000 |
+
#{"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 'src':src}
|
1001 |
+
|
1002 |
+
# print(images.shape, src_images.shape, texts.shape)
|
1003 |
+
with torch.no_grad():
|
1004 |
+
with autocast(enabled=False):
|
1005 |
+
codes = self.stage1.get_codes(images).detach()
|
1006 |
+
src_codes = self.stage1.get_codes(src_images).detach()
|
1007 |
+
|
1008 |
+
B, C, H, W = images.shape
|
1009 |
+
prompt = self.get_prompt(bsz=B)
|
1010 |
+
pos_enc_prompt = get_positional_encoding(self.input_tokens.unsqueeze(0).expand(B, -1).to(self.device), mode='1d')
|
1011 |
+
|
1012 |
+
pos_enc_tokens = get_positional_encoding(texts, mode='1d')
|
1013 |
+
codes = codes.clone().detach()
|
1014 |
+
pos_enc_code = get_positional_encoding(codes, mode='1d')
|
1015 |
+
src_codes = src_codes.clone().detach()
|
1016 |
+
src_pos_enc_code = get_positional_encoding(src_codes, mode='1d')
|
1017 |
+
# codes = codes.unsqueeze(-1)
|
1018 |
+
# pos_enc_code = pos_enc_code.unsqueeze(-1)
|
1019 |
+
# print(images.shape, codes.shape, texts.shape)
|
1020 |
+
logits_img, logits_txt = self.stage2.forward_with_context(codes, texts,
|
1021 |
+
pos_enc_code, pos_enc_tokens, src_codes, src_pos_enc_code,
|
1022 |
+
self.cross_attention_idxs, self.cross_attention_layers,
|
1023 |
+
prompt=prompt, pos_prompt=pos_enc_prompt)
|
1024 |
+
# print(logits_img.shape, logits_txt.shape, codes.shape, texts.shape)
|
1025 |
+
return logits_img, logits_txt, codes
|
1026 |
+
|
1027 |
+
@torch.no_grad()
|
1028 |
+
def sampling(self,
|
1029 |
+
tokens: torch.LongTensor,
|
1030 |
+
prompt: torch.LongTensor,
|
1031 |
+
source: torch.FloatTensor,
|
1032 |
+
top_k: int = 256,
|
1033 |
+
top_p: Optional[float] = None,
|
1034 |
+
softmax_temperature: float = 1.0,
|
1035 |
+
num_candidates: int = 96,
|
1036 |
+
device: str = 'cuda:0',
|
1037 |
+
use_fp16: bool = True,
|
1038 |
+
labels=None) -> torch.FloatTensor:
|
1039 |
+
|
1040 |
+
self.stage1.eval()
|
1041 |
+
self.stage2.eval()
|
1042 |
+
|
1043 |
+
if type(tokens) == str:
|
1044 |
+
tokens = self.tokenizer.encode(prompt)
|
1045 |
+
tokens = torch.LongTensor(tokens.ids)
|
1046 |
+
else:
|
1047 |
+
pass
|
1048 |
+
|
1049 |
+
tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
|
1050 |
+
|
1051 |
+
# Check if the encoding works as intended
|
1052 |
+
# print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
|
1053 |
+
|
1054 |
+
tokens = tokens.to(device)
|
1055 |
+
source = source.to(device)
|
1056 |
+
|
1057 |
+
pos_enc_prompt = get_positional_encoding(self.input_tokens.unsqueeze(0).expand(num_candidates, -1).to(self.device), mode='1d')
|
1058 |
+
|
1059 |
+
with autocast(enabled=False):
|
1060 |
+
src_codes = self.stage1.get_codes(source).detach()
|
1061 |
+
src_codes = torch.repeat_interleave(src_codes, num_candidates, dim=0)
|
1062 |
+
|
1063 |
+
codes = sampling_conditional(self.stage2,
|
1064 |
+
self.cross_attention_idxs,
|
1065 |
+
self.cross_attention_layers,
|
1066 |
+
tokens,
|
1067 |
+
src_codes,
|
1068 |
+
top_k=top_k,
|
1069 |
+
top_p=top_p,
|
1070 |
+
softmax_temperature=softmax_temperature,
|
1071 |
+
use_fp16=use_fp16,
|
1072 |
+
prompt=prompt,
|
1073 |
+
pos_prompt=pos_enc_prompt)
|
1074 |
+
|
1075 |
+
codes = codes.view(num_candidates, 16, 16) # [B, 16, 16]
|
1076 |
+
pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
|
1077 |
+
return pixels
|
1078 |
+
|
1079 |
+
|
1080 |
+
@torch.no_grad()
|
1081 |
+
def predict_step(self, batch, batch_idx, return_images=False):
|
1082 |
+
orig_images, texts = batch
|
1083 |
+
# concatenate the list of prompts (split by n_head) for better downstream processing
|
1084 |
+
|
1085 |
+
# extra for checks
|
1086 |
+
logits_img, logits_txt, codes = self(orig_images, texts)
|
1087 |
+
pred = torch.argmax(logits_img.view(-1, logits_img.shape[-1]), dim=-1)
|
1088 |
+
bs = orig_images.shape[0]
|
1089 |
+
pred = pred.view(bs, 16, 16) # [B, 16, 16]
|
1090 |
+
pixels = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy() # [B, 256, 256]
|
1091 |
+
pixels = np.transpose(pixels, (0, 2, 3, 1))
|
1092 |
+
|
1093 |
+
prompt = self.get_prompt(bsz=5, eval=True)
|
1094 |
+
|
1095 |
+
images = []
|
1096 |
+
for t in texts:
|
1097 |
+
pixels = self.sampling(t, prompt, top_k=64, num_candidates=5, labels=codes[i]).cpu().numpy()
|
1098 |
+
pixels = np.transpose(pixels, (0, 2, 3, 1))
|
1099 |
+
images.append(pixels)
|
1100 |
+
# images.extend([p for p in pixels])
|
1101 |
+
# print([i.shape for i in images])
|
1102 |
+
|
1103 |
+
if return_images:
|
1104 |
+
return images
|
1105 |
+
else:
|
1106 |
+
save_image(orig_images, pixels, './out/images/pororo_story', batch_idx+10)
|
1107 |
+
save_image(orig_images, images, './out/images/pororo_story', batch_idx)
|
1108 |
+
|
1109 |
+
|
1110 |
+
class StoryDalle(Dalle):
|
1111 |
+
"""Base model with story block"""
|
1112 |
+
def __init__(self, config):
|
1113 |
+
super().__init__(config)
|
1114 |
+
print('Initializing the Conditional Dalle model')
|
1115 |
+
|
1116 |
+
self.config = config
|
1117 |
+
|
1118 |
+
self.story_linear = nn.Linear(config.story.sent_embed, config.stage2.hparams.embed_dim)
|
1119 |
+
self.story_block = Block(ctx_len=config.story.story_len,
|
1120 |
+
embed_dim=config.stage2.hparams.embed_dim,
|
1121 |
+
n_heads=config.stage2.hparams.n_heads,
|
1122 |
+
mlp_bias=config.stage2.hparams.mlp_bias,
|
1123 |
+
attn_bias=config.stage2.hparams.attn_bias,
|
1124 |
+
resid_pdrop=config.stage2.hparams.resid_pdrop,
|
1125 |
+
attn_pdrop=config.stage2.hparams.attn_pdrop,
|
1126 |
+
gelu_use_approx=config.stage2.hparams.gelu_use_approx)
|
1127 |
+
|
1128 |
+
if self.config.story.prompt:
|
1129 |
+
self.n_embd = config.stage2.hparams.embed_dim
|
1130 |
+
self.preseqlen = config.story.preseqlen
|
1131 |
+
self.prefix_dropout = config.story.prefix_dropout
|
1132 |
+
|
1133 |
+
# DIFFERENT PARAMETRIZATION:
|
1134 |
+
|
1135 |
+
print('[Full prompt-tuning Setting :) ]')
|
1136 |
+
self.input_tokens = torch.arange(self.preseqlen).long()
|
1137 |
+
self.wte = nn.Embedding(self.preseqlen, self.n_embd)
|
1138 |
+
self.control_trans = nn.Sequential(
|
1139 |
+
nn.Linear(self.n_embd, self.n_embd),
|
1140 |
+
nn.Tanh(),
|
1141 |
+
nn.Linear(self.n_embd, self.n_embd))
|
1142 |
+
self.get_prompt = self.get_prompt_p5
|
1143 |
+
self.dropout = nn.Dropout(self.prefix_dropout)
|
1144 |
+
|
1145 |
+
if self.config.story.condition:
|
1146 |
+
print('Setting up Cross-attention Layers')
|
1147 |
+
self.init_cross_attention(list(range(2,42,3)), config.stage2.hparams)
|
1148 |
+
|
1149 |
+
###### NUM PARAMS #########
|
1150 |
+
total_param = 0
|
1151 |
+
for name, param in self.named_parameters():
|
1152 |
+
# print(param.shape)
|
1153 |
+
total_param += param.numel()
|
1154 |
+
print('Total parameters is {}'.format(total_param))
|
1155 |
+
|
1156 |
+
@classmethod
|
1157 |
+
def from_pretrained(cls, args) -> Tuple[nn.Module, OmegaConf]:
|
1158 |
+
|
1159 |
+
# if not args.model_name_or_path:
|
1160 |
+
# args.model_name_or_path = args.prefix_model_name_or_path
|
1161 |
+
|
1162 |
+
path = args.prefix_model_name_or_path
|
1163 |
+
path = _MODELS[path] if path in _MODELS else path
|
1164 |
+
path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))
|
1165 |
+
|
1166 |
+
config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
|
1167 |
+
# if args.do_train:
|
1168 |
+
config_base = get_base_config('story')
|
1169 |
+
config_update = OmegaConf.merge(config_base, config_new)
|
1170 |
+
for key, val in vars(args).items():
|
1171 |
+
if key in config_update.story.keys():
|
1172 |
+
OmegaConf.update(config_update, "story.%s" % key, val, merge=False)
|
1173 |
+
if key in config_update.optimizer.keys():
|
1174 |
+
OmegaConf.update(config_update, "optimizer.%s" % key, val, merge=False)
|
1175 |
+
if key in config_update.experiment.keys():
|
1176 |
+
OmegaConf.update(config_update, "experiment.%s" % key, val, merge=False)
|
1177 |
+
# else:
|
1178 |
+
# config_base = get_base_config('story')
|
1179 |
+
# config_update = OmegaConf.merge(config_base, config_new)
|
1180 |
+
# print(next(model.cross_attention_layers[0].parameters()).is_cuda)
|
1181 |
+
|
1182 |
+
if args.model_name_or_path:
|
1183 |
+
if 'pororo' in args.model_name_or_path:
|
1184 |
+
config_update.stage2.vocab_size_txt = config_update.stage2.vocab_size_txt + 9
|
1185 |
+
elif 'flintstones' in args.model_name_or_path:
|
1186 |
+
config_update.stage2.vocab_size_txt = config_update.stage2.vocab_size_txt + 7
|
1187 |
+
model = cls(config_update)
|
1188 |
+
model_dir = os.path.dirname(args.model_name_or_path)
|
1189 |
+
print(model_dir)
|
1190 |
+
model.tokenizer = build_tokenizer(model_dir,
|
1191 |
+
context_length=model.config_dataset.context_length,
|
1192 |
+
lowercase=True,
|
1193 |
+
dropout=None)
|
1194 |
+
print("Loaded tokenizer from finetuned checkpoint")
|
1195 |
+
print(model.cross_attention_idxs)
|
1196 |
+
print("Loading model from pretrained checkpoint %s" % args.model_name_or_path)
|
1197 |
+
# model.from_ckpt(args.model_name_or_path)
|
1198 |
+
try:
|
1199 |
+
model.load_state_dict(torch.load(args.model_name_or_path)['state_dict'])
|
1200 |
+
except KeyError:
|
1201 |
+
model.load_state_dict(torch.load(args.model_name_or_path)['model_state_dict'])
|
1202 |
+
else:
|
1203 |
+
model = cls(config_update)
|
1204 |
+
print(model.cross_attention_idxs)
|
1205 |
+
print("Loading models from checkpoint %s" % path)
|
1206 |
+
model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
|
1207 |
+
model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
|
1208 |
+
|
1209 |
+
model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
|
1210 |
+
context_length=model.config_dataset.context_length,
|
1211 |
+
lowercase=True,
|
1212 |
+
dropout=None)
|
1213 |
+
|
1214 |
+
|
1215 |
+
return model, config_update
|
1216 |
+
|
1217 |
+
|
1218 |
+
def init_cross_attention(self, cross_attention_layers, hparams):
|
1219 |
+
self.cross_attention_idxs = cross_attention_layers
|
1220 |
+
self.cross_attention_layers = [CrossAttentionLayer(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt,
|
1221 |
+
embed_dim=hparams.embed_dim,
|
1222 |
+
n_heads=hparams.n_heads,
|
1223 |
+
attn_bias=hparams.attn_bias,
|
1224 |
+
resid_pdrop=hparams.resid_pdrop,
|
1225 |
+
attn_pdrop=hparams.attn_pdrop) for i in cross_attention_layers]
|
1226 |
+
|
1227 |
+
def get_prompt_p5(self, bsz=None, eval=False):
|
1228 |
+
input_tokens = self.input_tokens.unsqueeze(0).expand(bsz, -1).to(self.device)
|
1229 |
+
temp_control = self.wte(input_tokens)
|
1230 |
+
past_key_values = self.control_trans(temp_control) #bsz, seqlen, layer*emb
|
1231 |
+
if not eval:
|
1232 |
+
past_key_values = self.dropout(past_key_values)
|
1233 |
+
return past_key_values
|
1234 |
+
|
1235 |
+
def forward(self,
|
1236 |
+
images: torch.FloatTensor,
|
1237 |
+
src_images: Optional[torch.FloatTensor],
|
1238 |
+
texts: Optional[torch.LongTensor],
|
1239 |
+
sent_embeds: Optional[torch.FloatTensor],
|
1240 |
+
**kwargs,
|
1241 |
+
):
|
1242 |
+
|
1243 |
+
# print(images.shape, src_images.shape, texts.shape, sent_embeds.shape)
|
1244 |
+
|
1245 |
+
B, L, C, H, W = images.shape
|
1246 |
+
images = images.view(B*L, C, H, W)
|
1247 |
+
src_images = src_images.unsqueeze(1).expand(-1, L, -1, -1, -1).reshape(B*L, C, H, W)
|
1248 |
+
sent_embeds = self.story_block(self.story_linear(sent_embeds)).view(B * L, -1).unsqueeze(1)
|
1249 |
+
texts = texts.view(B * L, -1)
|
1250 |
+
|
1251 |
+
#{"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn, 'src':src}
|
1252 |
+
|
1253 |
+
with torch.no_grad():
|
1254 |
+
with autocast(enabled=False):
|
1255 |
+
codes = self.stage1.get_codes(images).detach()
|
1256 |
+
src_codes = self.stage1.get_codes(src_images).detach()
|
1257 |
+
|
1258 |
+
B, C, H, W = images.shape
|
1259 |
+
|
1260 |
+
if self.config.story.prompt:
|
1261 |
+
prompt = self.get_prompt(bsz=B)
|
1262 |
+
prompt = torch.cat([prompt, sent_embeds], dim=1)
|
1263 |
+
else:
|
1264 |
+
prompt = sent_embeds
|
1265 |
+
|
1266 |
+
# dim = 0 for full-model finetuning??
|
1267 |
+
pos_enc_prompt = get_positional_encoding(torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(B, -1).to(self.device),
|
1268 |
+
mode='1d')
|
1269 |
+
|
1270 |
+
pos_enc_tokens = get_positional_encoding(texts, mode='1d')
|
1271 |
+
codes = codes.clone().detach()
|
1272 |
+
pos_enc_code = get_positional_encoding(codes, mode='1d')
|
1273 |
+
src_codes = src_codes.clone().detach()
|
1274 |
+
src_pos_enc_code = get_positional_encoding(src_codes, mode='1d')
|
1275 |
+
# codes = codes.unsqueeze(-1)
|
1276 |
+
# pos_enc_code = pos_enc_code.unsqueeze(-1)
|
1277 |
+
# print(images.shape, codes.shape, texts.shape)
|
1278 |
+
if self.config.story.condition:
|
1279 |
+
logits_img, logits_txt = self.stage2.forward_with_context(codes, texts,
|
1280 |
+
pos_enc_code, pos_enc_tokens, src_codes, src_pos_enc_code,
|
1281 |
+
self.cross_attention_idxs, self.cross_attention_layers,
|
1282 |
+
prompt=prompt, pos_prompt=pos_enc_prompt)
|
1283 |
+
else:
|
1284 |
+
logits_img, logits_txt = self.stage2(codes, texts, pos_enc_code, pos_enc_tokens, prompt=prompt,
|
1285 |
+
pos_prompt=pos_enc_prompt)
|
1286 |
+
|
1287 |
+
# print(logits_img.shape, logits_txt.shape, codes.shape, texts.shape)
|
1288 |
+
return logits_img, logits_txt, codes
|
1289 |
+
|
1290 |
+
@torch.no_grad()
|
1291 |
+
def sampling(self,
|
1292 |
+
tokens: torch.LongTensor,
|
1293 |
+
source: torch.FloatTensor,
|
1294 |
+
sent_embeds: torch.FloatTensor,
|
1295 |
+
top_k: int = 256,
|
1296 |
+
top_p: Optional[float] = None,
|
1297 |
+
softmax_temperature: float = 1.0,
|
1298 |
+
num_candidates: int = 96,
|
1299 |
+
device: str = 'cuda:0',
|
1300 |
+
use_fp16: bool = True,
|
1301 |
+
labels=None,
|
1302 |
+
prompt = None) -> torch.FloatTensor:
|
1303 |
+
|
1304 |
+
self.stage1.eval()
|
1305 |
+
self.stage2.eval()
|
1306 |
+
|
1307 |
+
if type(tokens) == str:
|
1308 |
+
tokens = self.tokenizer.encode(tokens)
|
1309 |
+
tokens = torch.LongTensor(tokens.ids)
|
1310 |
+
|
1311 |
+
# tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
|
1312 |
+
|
1313 |
+
# Check if the encoding works as intended
|
1314 |
+
# print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
|
1315 |
+
|
1316 |
+
tokens = tokens.to(device)
|
1317 |
+
source = source.to(device)
|
1318 |
+
|
1319 |
+
# print(tokens.shape, sent_embeds.shape, prompt.shape)
|
1320 |
+
B, L, _ = sent_embeds.shape
|
1321 |
+
sent_embeds = self.story_block(self.story_linear(sent_embeds)).view(B * L, -1).unsqueeze(1)
|
1322 |
+
if prompt is not None:
|
1323 |
+
prompt = torch.cat([prompt, sent_embeds], dim=1)
|
1324 |
+
else:
|
1325 |
+
prompt = sent_embeds
|
1326 |
+
pos_enc_prompt = get_positional_encoding(torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(B*L, -1).to(self.device), mode='1d')
|
1327 |
+
|
1328 |
+
with autocast(enabled=False):
|
1329 |
+
src_codes = self.stage1.get_codes(source).detach()
|
1330 |
+
src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len, dim=0)
|
1331 |
+
print(tokens.shape, src_codes.shape, prompt.shape)
|
1332 |
+
if self.config.story.condition:
|
1333 |
+
codes = sampling_conditional(self.stage2,
|
1334 |
+
self.cross_attention_idxs,
|
1335 |
+
self.cross_attention_layers,
|
1336 |
+
tokens,
|
1337 |
+
src_codes,
|
1338 |
+
top_k=top_k,
|
1339 |
+
top_p=top_p,
|
1340 |
+
softmax_temperature=softmax_temperature,
|
1341 |
+
use_fp16=use_fp16,
|
1342 |
+
prompt=prompt,
|
1343 |
+
pos_prompt=pos_enc_prompt)
|
1344 |
+
else:
|
1345 |
+
codes = sampling(self.stage2,
|
1346 |
+
tokens,
|
1347 |
+
top_k=top_k,
|
1348 |
+
top_p=top_p,
|
1349 |
+
softmax_temperature=softmax_temperature,
|
1350 |
+
use_fp16=use_fp16,
|
1351 |
+
prompt=prompt,
|
1352 |
+
pos_prompt=pos_enc_prompt)
|
1353 |
+
|
1354 |
+
codes = codes.view(self.config.story.story_len, 16, 16) # [B, 16, 16]
|
1355 |
+
pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
|
1356 |
+
return pixels
|
1357 |
+
|
1358 |
+
@torch.no_grad()
|
1359 |
+
def sampling_batch(self,
|
1360 |
+
tokens: torch.LongTensor,
|
1361 |
+
source: torch.FloatTensor,
|
1362 |
+
sent_embeds: torch.FloatTensor,
|
1363 |
+
top_k: int = 256,
|
1364 |
+
top_p: Optional[float] = None,
|
1365 |
+
softmax_temperature: float = 1.0,
|
1366 |
+
num_candidates: int = 96,
|
1367 |
+
device: str = 'cuda:0',
|
1368 |
+
use_fp16: bool = True,
|
1369 |
+
labels=None,
|
1370 |
+
prompt=None, n_candidates=1) -> torch.FloatTensor:
|
1371 |
+
|
1372 |
+
self.stage1.eval()
|
1373 |
+
self.stage2.eval()
|
1374 |
+
|
1375 |
+
if type(tokens) == str:
|
1376 |
+
tokens = self.tokenizer.encode(tokens)
|
1377 |
+
tokens = torch.LongTensor(tokens.ids)
|
1378 |
+
|
1379 |
+
# tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
|
1380 |
+
|
1381 |
+
# Check if the encoding works as intended
|
1382 |
+
# print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
|
1383 |
+
|
1384 |
+
tokens = tokens.to(device)
|
1385 |
+
source = source.to(device)
|
1386 |
+
|
1387 |
+
# print(tokens.shape, sent_embeds.shape, prompt.shape)
|
1388 |
+
B, L, _ = sent_embeds.shape
|
1389 |
+
sent_embeds = self.story_block(self.story_linear(sent_embeds)).view(B * L, -1).unsqueeze(1)
|
1390 |
+
if prompt is not None:
|
1391 |
+
prompt = torch.cat([prompt, sent_embeds], dim=1)
|
1392 |
+
else:
|
1393 |
+
prompt = sent_embeds
|
1394 |
+
pos_enc_prompt = get_positional_encoding(
|
1395 |
+
torch.arange(prompt.shape[1]).long().unsqueeze(0).expand(B * L, -1).to(self.device), mode='1d')
|
1396 |
+
|
1397 |
+
with autocast(enabled=False):
|
1398 |
+
src_codes = self.stage1.get_codes(source).detach()
|
1399 |
+
|
1400 |
+
# repeat inputs to adjust to n_candidates and story length
|
1401 |
+
src_codes = torch.repeat_interleave(src_codes, self.config.story.story_len * n_candidates, dim=0)
|
1402 |
+
prompt = prompt.repeat(n_candidates, 1, 1)
|
1403 |
+
pos_enc_prompt = pos_enc_prompt.repeat(n_candidates, 1)
|
1404 |
+
tokens = tokens.repeat(n_candidates, 1)
|
1405 |
+
print(tokens.shape, src_codes.shape, prompt.shape, pos_enc_prompt.shape)
|
1406 |
+
if self.config.story.condition:
|
1407 |
+
codes = sampling_conditional(self.stage2,
|
1408 |
+
self.cross_attention_idxs,
|
1409 |
+
self.cross_attention_layers,
|
1410 |
+
tokens,
|
1411 |
+
src_codes,
|
1412 |
+
top_k=top_k,
|
1413 |
+
top_p=top_p,
|
1414 |
+
softmax_temperature=softmax_temperature,
|
1415 |
+
use_fp16=use_fp16,
|
1416 |
+
prompt=prompt,
|
1417 |
+
pos_prompt=pos_enc_prompt)
|
1418 |
+
else:
|
1419 |
+
codes = sampling(self.stage2,
|
1420 |
+
tokens,
|
1421 |
+
top_k=top_k,
|
1422 |
+
top_p=top_p,
|
1423 |
+
softmax_temperature=softmax_temperature,
|
1424 |
+
use_fp16=use_fp16,
|
1425 |
+
prompt=prompt,
|
1426 |
+
pos_prompt=pos_enc_prompt)
|
1427 |
+
|
1428 |
+
codes = codes.view(self.config.story.story_len * n_candidates, 16, 16) # [B, 16, 16]
|
1429 |
+
print(codes.shape)
|
1430 |
+
pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 3, 256, 256]
|
1431 |
+
print(pixels.shape)
|
1432 |
+
return pixels.view(n_candidates, self.config.story.story_len, pixels.shape[-3], pixels.shape[-2], pixels.shape[-1])
|
1433 |
+
|
1434 |
+
|
1435 |
+
@torch.no_grad()
|
1436 |
+
def predict_step(self, batch, batch_idx, return_images=False):
|
1437 |
+
orig_images, texts = batch
|
1438 |
+
# concatenate the list of prompts (split by n_head) for better downstream processing
|
1439 |
+
|
1440 |
+
# extra for checks
|
1441 |
+
logits_img, logits_txt, codes = self(orig_images, texts)
|
1442 |
+
pred = torch.argmax(logits_img.view(-1, logits_img.shape[-1]), dim=-1)
|
1443 |
+
bs = orig_images.shape[0]
|
1444 |
+
pred = pred.view(bs, 16, 16) # [B, 16, 16]
|
1445 |
+
pixels = torch.clamp(self.stage1.decode_code(pred) * 0.5 + 0.5, 0, 1).cpu().numpy() # [B, 256, 256]
|
1446 |
+
pixels = np.transpose(pixels, (0, 2, 3, 1))
|
1447 |
+
|
1448 |
+
prompt = self.get_prompt(bsz=5, eval=True)
|
1449 |
+
|
1450 |
+
images = []
|
1451 |
+
for t in texts:
|
1452 |
+
pixels = self.sampling(t, prompt, top_k=64, num_candidates=5, labels=codes[i]).cpu().numpy()
|
1453 |
+
pixels = np.transpose(pixels, (0, 2, 3, 1))
|
1454 |
+
images.append(pixels)
|
1455 |
+
# images.extend([p for p in pixels])
|
1456 |
+
# print([i.shape for i in images])
|
1457 |
+
|
1458 |
+
if return_images:
|
1459 |
+
return images
|
1460 |
+
else:
|
1461 |
+
save_image(orig_images, pixels, './out/images/pororo_story', batch_idx+10)
|
1462 |
+
save_image(orig_images, images, './out/images/pororo_story', batch_idx)
|
dalle/models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (34.9 kB). View file
|
|
dalle/models/__pycache__/prefix_tuning_model.cpython-38.pyc
ADDED
Binary file (5.05 kB). View file
|
|
dalle/models/__pycache__/tokenizer.cpython-38.pyc
ADDED
Binary file (974 Bytes). View file
|
|
dalle/models/stage1/__pycache__/layers.cpython-38.pyc
ADDED
Binary file (7.85 kB). View file
|
|
dalle/models/stage1/__pycache__/vqgan.cpython-38.pyc
ADDED
Binary file (4.04 kB). View file
|
|
dalle/models/stage1/layers.py
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------
|
2 |
+
# Modified from VQGAN (https://github.com/CompVis/taming-transformers)
|
3 |
+
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
|
4 |
+
# ------------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from typing import Tuple, Optional
|
9 |
+
|
10 |
+
|
11 |
+
def nonlinearity(x):
|
12 |
+
# swish
|
13 |
+
return x*torch.sigmoid(x)
|
14 |
+
|
15 |
+
|
16 |
+
def Normalize(in_channels):
|
17 |
+
return torch.nn.GroupNorm(num_groups=32,
|
18 |
+
num_channels=in_channels,
|
19 |
+
eps=1e-6,
|
20 |
+
affine=True)
|
21 |
+
|
22 |
+
|
23 |
+
class Upsample(nn.Module):
|
24 |
+
def __init__(self, in_channels, with_conv):
|
25 |
+
super().__init__()
|
26 |
+
self.with_conv = with_conv
|
27 |
+
if self.with_conv:
|
28 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
29 |
+
in_channels,
|
30 |
+
kernel_size=3,
|
31 |
+
stride=1,
|
32 |
+
padding=1)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
36 |
+
if self.with_conv:
|
37 |
+
x = self.conv(x)
|
38 |
+
return x
|
39 |
+
|
40 |
+
|
41 |
+
class Downsample(nn.Module):
|
42 |
+
def __init__(self, in_channels, with_conv):
|
43 |
+
super().__init__()
|
44 |
+
self.with_conv = with_conv
|
45 |
+
if self.with_conv:
|
46 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
47 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
48 |
+
in_channels,
|
49 |
+
kernel_size=3,
|
50 |
+
stride=2,
|
51 |
+
padding=0)
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
if self.with_conv:
|
55 |
+
pad = (0, 1, 0, 1)
|
56 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
57 |
+
x = self.conv(x)
|
58 |
+
else:
|
59 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
60 |
+
return x
|
61 |
+
|
62 |
+
|
63 |
+
class ResnetBlock(nn.Module):
|
64 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
65 |
+
dropout, temb_channels=512):
|
66 |
+
assert temb_channels == 0
|
67 |
+
super().__init__()
|
68 |
+
self.in_channels = in_channels
|
69 |
+
out_channels = in_channels if out_channels is None else out_channels
|
70 |
+
self.out_channels = out_channels
|
71 |
+
self.use_conv_shortcut = conv_shortcut
|
72 |
+
|
73 |
+
self.norm1 = Normalize(in_channels)
|
74 |
+
self.conv1 = torch.nn.Conv2d(in_channels,
|
75 |
+
out_channels,
|
76 |
+
kernel_size=3,
|
77 |
+
stride=1,
|
78 |
+
padding=1)
|
79 |
+
self.norm2 = Normalize(out_channels)
|
80 |
+
self.dropout = torch.nn.Dropout(dropout)
|
81 |
+
self.conv2 = torch.nn.Conv2d(out_channels,
|
82 |
+
out_channels,
|
83 |
+
kernel_size=3,
|
84 |
+
stride=1,
|
85 |
+
padding=1)
|
86 |
+
if self.in_channels != self.out_channels:
|
87 |
+
if self.use_conv_shortcut:
|
88 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
89 |
+
out_channels,
|
90 |
+
kernel_size=3,
|
91 |
+
stride=1,
|
92 |
+
padding=1)
|
93 |
+
else:
|
94 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
95 |
+
out_channels,
|
96 |
+
kernel_size=1,
|
97 |
+
stride=1,
|
98 |
+
padding=0)
|
99 |
+
|
100 |
+
def forward(self, x, temb=None):
|
101 |
+
assert temb is None
|
102 |
+
|
103 |
+
h = x
|
104 |
+
h = self.norm1(h)
|
105 |
+
h = nonlinearity(h)
|
106 |
+
h = self.conv1(h)
|
107 |
+
|
108 |
+
h = self.norm2(h)
|
109 |
+
h = nonlinearity(h)
|
110 |
+
h = self.dropout(h)
|
111 |
+
h = self.conv2(h)
|
112 |
+
|
113 |
+
if self.in_channels != self.out_channels:
|
114 |
+
if self.use_conv_shortcut:
|
115 |
+
x = self.conv_shortcut(x)
|
116 |
+
else:
|
117 |
+
x = self.nin_shortcut(x)
|
118 |
+
return x+h
|
119 |
+
|
120 |
+
|
121 |
+
class AttnBlock(nn.Module):
|
122 |
+
def __init__(self, in_channels):
|
123 |
+
super().__init__()
|
124 |
+
self.in_channels = in_channels
|
125 |
+
|
126 |
+
self.norm = Normalize(in_channels)
|
127 |
+
self.q = torch.nn.Conv2d(in_channels,
|
128 |
+
in_channels,
|
129 |
+
kernel_size=1,
|
130 |
+
stride=1,
|
131 |
+
padding=0)
|
132 |
+
self.k = torch.nn.Conv2d(in_channels,
|
133 |
+
in_channels,
|
134 |
+
kernel_size=1,
|
135 |
+
stride=1,
|
136 |
+
padding=0)
|
137 |
+
self.v = torch.nn.Conv2d(in_channels,
|
138 |
+
in_channels,
|
139 |
+
kernel_size=1,
|
140 |
+
stride=1,
|
141 |
+
padding=0)
|
142 |
+
self.proj_out = torch.nn.Conv2d(in_channels,
|
143 |
+
in_channels,
|
144 |
+
kernel_size=1,
|
145 |
+
stride=1,
|
146 |
+
padding=0)
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
h_ = x
|
150 |
+
h_ = self.norm(h_)
|
151 |
+
q = self.q(h_)
|
152 |
+
k = self.k(h_)
|
153 |
+
v = self.v(h_)
|
154 |
+
|
155 |
+
# compute attention
|
156 |
+
b, c, h, w = q.shape
|
157 |
+
q = q.reshape(b, c, h*w)
|
158 |
+
q = q.permute(0, 2, 1) # b,hw,c
|
159 |
+
k = k.reshape(b, c, h*w) # b,c,hw
|
160 |
+
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
161 |
+
w_ = w_ * (int(c)**(-0.5))
|
162 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
163 |
+
|
164 |
+
# attend to values
|
165 |
+
v = v.reshape(b, c, h*w)
|
166 |
+
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
167 |
+
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
168 |
+
h_ = h_.reshape(b, c, h, w)
|
169 |
+
|
170 |
+
h_ = self.proj_out(h_)
|
171 |
+
return x+h_
|
172 |
+
|
173 |
+
|
174 |
+
class Encoder(nn.Module):
|
175 |
+
def __init__(self,
|
176 |
+
*, # forced to use named arguments
|
177 |
+
ch: int,
|
178 |
+
out_ch: int,
|
179 |
+
ch_mult: Tuple[int] = (1, 2, 4, 8),
|
180 |
+
num_res_blocks: int,
|
181 |
+
attn_resolutions: Tuple[int],
|
182 |
+
pdrop: float = 0.0,
|
183 |
+
resamp_with_conv: bool = True,
|
184 |
+
in_channels: int,
|
185 |
+
resolution: int,
|
186 |
+
z_channels: int,
|
187 |
+
double_z: Optional[bool] = None) -> None:
|
188 |
+
super().__init__()
|
189 |
+
self.ch = ch
|
190 |
+
self.temb_ch = 0
|
191 |
+
self.num_resolutions = len(ch_mult)
|
192 |
+
self.num_res_blocks = num_res_blocks
|
193 |
+
self.resolution = resolution
|
194 |
+
self.in_channels = in_channels
|
195 |
+
|
196 |
+
# downsampling
|
197 |
+
self.conv_in = torch.nn.Conv2d(in_channels,
|
198 |
+
self.ch,
|
199 |
+
kernel_size=3,
|
200 |
+
stride=1,
|
201 |
+
padding=1)
|
202 |
+
|
203 |
+
curr_res = resolution
|
204 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
205 |
+
self.down = nn.ModuleList()
|
206 |
+
for i_level in range(self.num_resolutions):
|
207 |
+
block = nn.ModuleList()
|
208 |
+
attn = nn.ModuleList()
|
209 |
+
block_in = ch*in_ch_mult[i_level]
|
210 |
+
block_out = ch*ch_mult[i_level]
|
211 |
+
for i_block in range(self.num_res_blocks):
|
212 |
+
block.append(ResnetBlock(in_channels=block_in,
|
213 |
+
out_channels=block_out,
|
214 |
+
temb_channels=self.temb_ch,
|
215 |
+
dropout=pdrop))
|
216 |
+
block_in = block_out
|
217 |
+
if curr_res in attn_resolutions:
|
218 |
+
attn.append(AttnBlock(block_in))
|
219 |
+
down = nn.Module()
|
220 |
+
down.block = block
|
221 |
+
down.attn = attn
|
222 |
+
if i_level != self.num_resolutions-1:
|
223 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
224 |
+
curr_res = curr_res // 2
|
225 |
+
self.down.append(down)
|
226 |
+
|
227 |
+
# middle
|
228 |
+
self.mid = nn.Module()
|
229 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
230 |
+
out_channels=block_in,
|
231 |
+
temb_channels=self.temb_ch,
|
232 |
+
dropout=pdrop)
|
233 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
234 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
235 |
+
out_channels=block_in,
|
236 |
+
temb_channels=self.temb_ch,
|
237 |
+
dropout=pdrop)
|
238 |
+
|
239 |
+
# end
|
240 |
+
self.norm_out = Normalize(block_in)
|
241 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
242 |
+
2*z_channels if double_z else z_channels,
|
243 |
+
kernel_size=3,
|
244 |
+
stride=1,
|
245 |
+
padding=1)
|
246 |
+
|
247 |
+
def forward(self, x):
|
248 |
+
assert x.shape[2] == x.shape[3] == self.resolution, \
|
249 |
+
"{}, {}".format(x.shape, self.resolution)
|
250 |
+
|
251 |
+
# downsampling
|
252 |
+
h = self.conv_in(x)
|
253 |
+
for i_level in range(self.num_resolutions):
|
254 |
+
for i_block in range(self.num_res_blocks):
|
255 |
+
h = self.down[i_level].block[i_block](h)
|
256 |
+
if len(self.down[i_level].attn) > 0:
|
257 |
+
h = self.down[i_level].attn[i_block](h)
|
258 |
+
if i_level != self.num_resolutions-1:
|
259 |
+
h = self.down[i_level].downsample(h)
|
260 |
+
|
261 |
+
# middle
|
262 |
+
h = self.mid.block_1(h)
|
263 |
+
h = self.mid.attn_1(h)
|
264 |
+
h = self.mid.block_2(h)
|
265 |
+
|
266 |
+
# end
|
267 |
+
h = self.norm_out(h)
|
268 |
+
h = nonlinearity(h)
|
269 |
+
h = self.conv_out(h)
|
270 |
+
return h
|
271 |
+
|
272 |
+
|
273 |
+
class Decoder(nn.Module):
|
274 |
+
def __init__(self,
|
275 |
+
*, # forced to use named arguments
|
276 |
+
ch: int,
|
277 |
+
out_ch: int,
|
278 |
+
ch_mult: Tuple[int] = (1, 2, 4, 8),
|
279 |
+
num_res_blocks: int,
|
280 |
+
attn_resolutions: Tuple[int],
|
281 |
+
pdrop: float = 0.0,
|
282 |
+
resamp_with_conv: bool = True,
|
283 |
+
in_channels: int,
|
284 |
+
resolution: int,
|
285 |
+
z_channels: int,
|
286 |
+
double_z: bool) -> None:
|
287 |
+
super().__init__()
|
288 |
+
self.ch = ch
|
289 |
+
self.temb_ch = 0
|
290 |
+
self.num_resolutions = len(ch_mult)
|
291 |
+
self.num_res_blocks = num_res_blocks
|
292 |
+
self.resolution = resolution
|
293 |
+
self.in_channels = in_channels
|
294 |
+
|
295 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
296 |
+
block_in = ch*ch_mult[self.num_resolutions-1]
|
297 |
+
curr_res = resolution // 2**(self.num_resolutions-1)
|
298 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
299 |
+
|
300 |
+
# z to block_in
|
301 |
+
self.conv_in = torch.nn.Conv2d(z_channels,
|
302 |
+
block_in,
|
303 |
+
kernel_size=3,
|
304 |
+
stride=1,
|
305 |
+
padding=1)
|
306 |
+
|
307 |
+
# middle
|
308 |
+
self.mid = nn.Module()
|
309 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
310 |
+
out_channels=block_in,
|
311 |
+
temb_channels=self.temb_ch,
|
312 |
+
dropout=pdrop)
|
313 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
314 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
315 |
+
out_channels=block_in,
|
316 |
+
temb_channels=self.temb_ch,
|
317 |
+
dropout=pdrop)
|
318 |
+
|
319 |
+
# upsampling
|
320 |
+
self.up = nn.ModuleList()
|
321 |
+
for i_level in reversed(range(self.num_resolutions)):
|
322 |
+
block = nn.ModuleList()
|
323 |
+
attn = nn.ModuleList()
|
324 |
+
block_out = ch*ch_mult[i_level]
|
325 |
+
for i_block in range(self.num_res_blocks+1):
|
326 |
+
block.append(ResnetBlock(in_channels=block_in,
|
327 |
+
out_channels=block_out,
|
328 |
+
temb_channels=self.temb_ch,
|
329 |
+
dropout=pdrop))
|
330 |
+
block_in = block_out
|
331 |
+
if curr_res in attn_resolutions:
|
332 |
+
attn.append(AttnBlock(block_in))
|
333 |
+
up = nn.Module()
|
334 |
+
up.block = block
|
335 |
+
up.attn = attn
|
336 |
+
if i_level != 0:
|
337 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
338 |
+
curr_res = curr_res * 2
|
339 |
+
self.up.insert(0, up) # prepend to get consistent order
|
340 |
+
|
341 |
+
# end
|
342 |
+
self.norm_out = Normalize(block_in)
|
343 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
344 |
+
out_ch,
|
345 |
+
kernel_size=3,
|
346 |
+
stride=1,
|
347 |
+
padding=1)
|
348 |
+
|
349 |
+
def forward(self, z):
|
350 |
+
assert z.shape[1:] == self.z_shape[1:]
|
351 |
+
self.last_z_shape = z.shape
|
352 |
+
|
353 |
+
# z to block_in
|
354 |
+
h = self.conv_in(z)
|
355 |
+
|
356 |
+
# middle
|
357 |
+
h = self.mid.block_1(h)
|
358 |
+
h = self.mid.attn_1(h)
|
359 |
+
h = self.mid.block_2(h)
|
360 |
+
|
361 |
+
# upsampling
|
362 |
+
for i_level in reversed(range(self.num_resolutions)):
|
363 |
+
for i_block in range(self.num_res_blocks+1):
|
364 |
+
h = self.up[i_level].block[i_block](h)
|
365 |
+
if len(self.up[i_level].attn) > 0:
|
366 |
+
h = self.up[i_level].attn[i_block](h)
|
367 |
+
if i_level != 0:
|
368 |
+
h = self.up[i_level].upsample(h)
|
369 |
+
|
370 |
+
h = self.norm_out(h)
|
371 |
+
h = nonlinearity(h)
|
372 |
+
h = self.conv_out(h)
|
373 |
+
return h
|
dalle/models/stage1/vqgan.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------
|
2 |
+
# Modified from VQGAN (https://github.com/CompVis/taming-transformers)
|
3 |
+
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
|
4 |
+
# ------------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from typing import List, Tuple, Optional
|
9 |
+
from einops import rearrange
|
10 |
+
from omegaconf import OmegaConf
|
11 |
+
from .layers import Encoder, Decoder
|
12 |
+
|
13 |
+
|
14 |
+
class VectorQuantizer(nn.Module):
|
15 |
+
"""
|
16 |
+
Simplified VectorQuantizer in the original VQGAN repository
|
17 |
+
by removing unncessary modules for sampling
|
18 |
+
"""
|
19 |
+
def __init__(self, dim: int, n_embed: int, beta: float) -> None:
|
20 |
+
super().__init__()
|
21 |
+
self.n_embed = n_embed
|
22 |
+
self.dim = dim
|
23 |
+
self.beta = beta
|
24 |
+
|
25 |
+
self.embedding = nn.Embedding(self.n_embed, self.dim)
|
26 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_embed, 1.0 / self.n_embed)
|
27 |
+
|
28 |
+
def forward(self,
|
29 |
+
z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
|
30 |
+
z = rearrange(z, 'b c h w -> b h w c').contiguous() # [B,C,H,W] -> [B,H,W,C]
|
31 |
+
z_flattened = z.view(-1, self.dim)
|
32 |
+
|
33 |
+
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
|
34 |
+
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
|
35 |
+
torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
|
36 |
+
|
37 |
+
min_encoding_indices = torch.argmin(d, dim=1)
|
38 |
+
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
39 |
+
return z_q, min_encoding_indices
|
40 |
+
|
41 |
+
def get_codebook_entry(self,
|
42 |
+
indices: torch.LongTensor,
|
43 |
+
shape: Optional[List[int]] = None) -> torch.FloatTensor:
|
44 |
+
z_q = self.embedding(indices)
|
45 |
+
if shape is not None:
|
46 |
+
z_q = z_q.view(shape)
|
47 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
48 |
+
return z_q
|
49 |
+
|
50 |
+
|
51 |
+
class VQGAN(nn.Module):
|
52 |
+
def __init__(self, n_embed: int, embed_dim: int, hparams: OmegaConf) -> None:
|
53 |
+
super().__init__()
|
54 |
+
self.encoder = Encoder(**hparams)
|
55 |
+
self.decoder = Decoder(**hparams)
|
56 |
+
self.quantize = VectorQuantizer(dim=embed_dim, n_embed=n_embed, beta=0.25)
|
57 |
+
self.quant_conv = torch.nn.Conv2d(hparams.z_channels, embed_dim, 1)
|
58 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, hparams.z_channels, 1)
|
59 |
+
self.latent_dim = hparams.attn_resolutions[0]
|
60 |
+
|
61 |
+
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
62 |
+
quant = self.encode(x)
|
63 |
+
dec = self.decode(quant)
|
64 |
+
return dec
|
65 |
+
|
66 |
+
def encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
|
67 |
+
h = self.encoder(x)
|
68 |
+
h = self.quant_conv(h)
|
69 |
+
quant = self.quantize(h)[0]
|
70 |
+
quant = rearrange(quant, 'b h w c -> b c h w').contiguous()
|
71 |
+
return quant
|
72 |
+
|
73 |
+
def decode(self, quant: torch.FloatTensor) -> torch.FloatTensor:
|
74 |
+
quant = self.post_quant_conv(quant)
|
75 |
+
dec = self.decoder(quant)
|
76 |
+
return dec
|
77 |
+
|
78 |
+
def decode_code(self, code: torch.LongTensor) -> torch.FloatTensor:
|
79 |
+
quant = self.quantize.get_codebook_entry(code)
|
80 |
+
quant = quant.permute(0, 3, 1, 2)
|
81 |
+
dec = self.decode(quant)
|
82 |
+
return dec
|
83 |
+
|
84 |
+
def get_codes(self, x: torch.FloatTensor) -> torch.LongTensor:
|
85 |
+
h = self.encoder(x)
|
86 |
+
h = self.quant_conv(h)
|
87 |
+
codes = self.quantize(h)[1].view(x.shape[0], self.latent_dim ** 2)
|
88 |
+
return codes
|
89 |
+
|
90 |
+
def from_ckpt(self, path: str, strict: bool = True) -> None:
|
91 |
+
ckpt = torch.load(path, map_location='cpu')['state_dict']
|
92 |
+
self.load_state_dict(ckpt, strict=strict)
|
93 |
+
print(f'{path} successfully restored..')
|
dalle/models/stage2/__pycache__/layers.cpython-38.pyc
ADDED
Binary file (5.71 kB). View file
|
|
dalle/models/stage2/__pycache__/transformer.cpython-38.pyc
ADDED
Binary file (11.7 kB). View file
|
|
dalle/models/stage2/layers.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------
|
2 |
+
# Minimal DALL-E
|
3 |
+
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------------
|
6 |
+
# Modified from minGPT (https://github.com/karpathy/minGPT)
|
7 |
+
# Copyright (c) 2020 Andrej Karpathy. All Rights Reserved.
|
8 |
+
# ------------------------------------------------------------------------------------
|
9 |
+
|
10 |
+
import math
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch.nn import functional as F
|
14 |
+
|
15 |
+
|
16 |
+
class GELU(nn.Module):
|
17 |
+
def __init__(self, use_approx=False):
|
18 |
+
super().__init__()
|
19 |
+
self.use_approx = use_approx
|
20 |
+
|
21 |
+
def forward(self, x):
|
22 |
+
if self.use_approx:
|
23 |
+
return x * torch.sigmoid(1.702 * x)
|
24 |
+
else:
|
25 |
+
return F.gelu(x)
|
26 |
+
|
27 |
+
|
28 |
+
class MultiHeadSelfAttention(nn.Module):
|
29 |
+
|
30 |
+
def __init__(self,
|
31 |
+
ctx_len: int,
|
32 |
+
embed_dim: int,
|
33 |
+
n_heads: int,
|
34 |
+
resid_pdrop: float,
|
35 |
+
attn_pdrop: float,
|
36 |
+
attn_bias: bool,
|
37 |
+
use_mask: bool = True):
|
38 |
+
super().__init__()
|
39 |
+
assert embed_dim % n_heads == 0
|
40 |
+
|
41 |
+
# key, query, value projections for all heads
|
42 |
+
self.key = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
|
43 |
+
self.query = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
|
44 |
+
self.value = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
|
45 |
+
|
46 |
+
# regularization
|
47 |
+
self.attn_drop = nn.Dropout(attn_pdrop)
|
48 |
+
self.resid_drop = nn.Dropout(resid_pdrop)
|
49 |
+
|
50 |
+
# output projection
|
51 |
+
self.proj = nn.Linear(embed_dim, embed_dim, attn_bias)
|
52 |
+
|
53 |
+
self.n_heads = n_heads
|
54 |
+
self.ctx_len = ctx_len
|
55 |
+
self.use_mask = use_mask
|
56 |
+
if self.use_mask:
|
57 |
+
self.register_buffer("mask", torch.ones(ctx_len, ctx_len), persistent=False)
|
58 |
+
self.mask = torch.tril(self.mask).view(1, ctx_len, ctx_len)
|
59 |
+
|
60 |
+
def forward(self, x, use_cache=False, layer_past=None):
|
61 |
+
B, T, C = x.shape
|
62 |
+
x = x.transpose(0, 1).contiguous() # (B, T, C) -> (T, B, C)
|
63 |
+
|
64 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
65 |
+
k = self.key(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
|
66 |
+
q = self.query(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
|
67 |
+
v = self.value(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
|
68 |
+
|
69 |
+
if use_cache:
|
70 |
+
present = torch.stack([k, v])
|
71 |
+
|
72 |
+
if layer_past is not None:
|
73 |
+
# print(layer_past.shape, k.shape, v.shape, q.shape)
|
74 |
+
# print("LayerPast shape", layer_past.shape)
|
75 |
+
past_key, past_value = layer_past
|
76 |
+
|
77 |
+
if len(past_key.shape) == 4:
|
78 |
+
_, _, seq_len, dim = past_key.shape
|
79 |
+
k = torch.cat([past_key.reshape(-1, seq_len, dim), k], dim=-2)
|
80 |
+
v = torch.cat([past_value.reshape(-1, seq_len, dim), v], dim=-2)
|
81 |
+
elif len(past_key.shape) == 3:
|
82 |
+
past_key, past_value = layer_past
|
83 |
+
k = torch.cat([past_key, k], dim=-2)
|
84 |
+
v = torch.cat([past_value, v], dim=-2)
|
85 |
+
else:
|
86 |
+
raise ValueError
|
87 |
+
|
88 |
+
if use_cache and layer_past is not None:
|
89 |
+
# Tensor shape below: (B * nh, 1, hs) X (B * nh, hs, K) -> (B * nh, 1, K)
|
90 |
+
att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
|
91 |
+
att = F.softmax(att, dim=-1)
|
92 |
+
att = self.attn_drop(att)
|
93 |
+
y = torch.bmm(att, v) # (B*nh, 1, K) X (B*nh, K, hs) -> (B*nh, 1, hs)
|
94 |
+
else:
|
95 |
+
# Tensor shape below: (B * nh, T, hs) X (B * nh, hs, T) -> (B * nh, T, T)
|
96 |
+
att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
|
97 |
+
if self.use_mask:
|
98 |
+
# TODO : Flip when not prompt tunign
|
99 |
+
# mask = self.mask if T == self.ctx_len else self.mask[:, :T, :T]
|
100 |
+
if T == self.ctx_len:
|
101 |
+
mask = self.mask
|
102 |
+
else:
|
103 |
+
mask = torch.tril(torch.ones(T, T)).view(1, T, T).to(att.device)
|
104 |
+
att = att.masked_fill(mask == 0, float('-inf'))
|
105 |
+
att = F.softmax(att, dim=-1)
|
106 |
+
att = self.attn_drop(att)
|
107 |
+
y = torch.bmm(att, v) # (B*nh, T, T) X (B*nh, T, hs) -> (B*nh, T, hs)
|
108 |
+
y = y.transpose(0, 1).contiguous().view(T, B, C) # re-assemble all head outputs side by side
|
109 |
+
|
110 |
+
# output projection
|
111 |
+
y = self.resid_drop(self.proj(y))
|
112 |
+
if use_cache:
|
113 |
+
return y.transpose(0, 1).contiguous(), present # (T, B, C) -> (B, T, C)
|
114 |
+
else:
|
115 |
+
return y.transpose(0, 1).contiguous() # (T, B, C) -> (B, T, C)
|
116 |
+
|
117 |
+
def forward_with_context(self, x, context, mask=None):
|
118 |
+
B, T, C = x.shape
|
119 |
+
x = x.transpose(0, 1).contiguous() # (B, T, C) -> (T, B, C)
|
120 |
+
|
121 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
122 |
+
q = self.query(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
|
123 |
+
|
124 |
+
B, T_c, C = context.shape
|
125 |
+
k = self.key(context).view(T_c, B * self.n_heads, C // self.n_heads).transpose(0, 1) # (B*nh, T, hs)
|
126 |
+
v = self.value(context).view(T_c, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
|
127 |
+
|
128 |
+
# Tensor shape below: (B * nh, T, hs) X (B * nh, hs, Tc) -> (B * nh, T, Tc)
|
129 |
+
att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
|
130 |
+
att = F.softmax(att, dim=-1)
|
131 |
+
att = self.attn_drop(att)
|
132 |
+
y = torch.bmm(att, v) # (B*nh, T, T) X (B*nh, T, hs) -> (B*nh, T, hs)
|
133 |
+
y = y.transpose(0, 1).contiguous().view(T, B, C) # re-assemble all head outputs side by side
|
134 |
+
|
135 |
+
# output projection
|
136 |
+
y = self.resid_drop(self.proj(y)).transpose(0, 1).contiguous()
|
137 |
+
if mask is not None:
|
138 |
+
y = y.masked_fill(mask == 0, float('0.0'))
|
139 |
+
return y # (T, B, C) -> (B, T, C)
|
140 |
+
|
141 |
+
|
142 |
+
class Block(nn.Module):
|
143 |
+
|
144 |
+
def __init__(self,
|
145 |
+
ctx_len: int,
|
146 |
+
embed_dim: int,
|
147 |
+
n_heads: int,
|
148 |
+
mlp_bias: bool,
|
149 |
+
attn_bias: bool,
|
150 |
+
resid_pdrop: bool,
|
151 |
+
attn_pdrop: bool,
|
152 |
+
gelu_use_approx: bool):
|
153 |
+
super().__init__()
|
154 |
+
self.ln1 = nn.LayerNorm(embed_dim)
|
155 |
+
self.ln2 = nn.LayerNorm(embed_dim)
|
156 |
+
|
157 |
+
self.attn = MultiHeadSelfAttention(ctx_len=ctx_len,
|
158 |
+
embed_dim=embed_dim,
|
159 |
+
n_heads=n_heads,
|
160 |
+
attn_pdrop=attn_pdrop,
|
161 |
+
resid_pdrop=resid_pdrop,
|
162 |
+
attn_bias=attn_bias,
|
163 |
+
use_mask=True)
|
164 |
+
self.mlp = nn.Sequential(
|
165 |
+
nn.Linear(embed_dim, 4 * embed_dim, bias=mlp_bias),
|
166 |
+
GELU(gelu_use_approx),
|
167 |
+
nn.Linear(4 * embed_dim, embed_dim, bias=mlp_bias),
|
168 |
+
nn.Dropout(resid_pdrop),
|
169 |
+
)
|
170 |
+
|
171 |
+
def forward(self, x, layer_past=None):
|
172 |
+
x = x + self.attn(self.ln1(x), layer_past=layer_past)
|
173 |
+
x = x + self.mlp(self.ln2(x))
|
174 |
+
return x
|
175 |
+
|
176 |
+
def sample(self, x, layer_past=None):
|
177 |
+
attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
|
178 |
+
x = x + attn
|
179 |
+
x = x + self.mlp(self.ln2(x))
|
180 |
+
return x, present
|
181 |
+
|
182 |
+
def sample_with_context(self, x, context, context_mask, cross_attn_layer, layer_past=None):
|
183 |
+
attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
|
184 |
+
x = x + attn
|
185 |
+
c_attn = cross_attn_layer(x, context, context_mask)
|
186 |
+
x = x + c_attn
|
187 |
+
x = x + self.mlp(self.ln2(x))
|
188 |
+
return x, present
|
189 |
+
|
190 |
+
|
191 |
+
class CrossAttentionLayer(nn.Module):
|
192 |
+
|
193 |
+
def __init__(self,
|
194 |
+
ctx_len: int,
|
195 |
+
embed_dim: int,
|
196 |
+
n_heads: int,
|
197 |
+
attn_bias: bool,
|
198 |
+
resid_pdrop: bool,
|
199 |
+
attn_pdrop: bool):
|
200 |
+
super().__init__()
|
201 |
+
|
202 |
+
self.ln1 = nn.LayerNorm(embed_dim)
|
203 |
+
self.ln2 = nn.LayerNorm(embed_dim)
|
204 |
+
self.attn = MultiHeadSelfAttention(ctx_len=ctx_len,
|
205 |
+
embed_dim=embed_dim,
|
206 |
+
n_heads=n_heads,
|
207 |
+
attn_pdrop=attn_pdrop,
|
208 |
+
resid_pdrop=resid_pdrop,
|
209 |
+
attn_bias=attn_bias,
|
210 |
+
use_mask=False)
|
211 |
+
|
212 |
+
def forward(self, x, context, context_mask=None):
|
213 |
+
attn = self.attn.forward_with_context(self.ln1(x), self.ln2(context), context_mask)
|
214 |
+
# x = x + attn
|
215 |
+
# return x
|
216 |
+
return attn
|
dalle/models/stage2/transformer.py
ADDED
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------
|
2 |
+
# Minimal DALL-E
|
3 |
+
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------------
|
6 |
+
# Modified from minGPT (https://github.com/karpathy/minGPT)
|
7 |
+
# Copyright (c) 2020 Andrej Karpathy. All Rights Reserved.
|
8 |
+
# ------------------------------------------------------------------------------------
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
from typing import Optional, Tuple, List
|
13 |
+
from torch.cuda.amp import autocast
|
14 |
+
from omegaconf import OmegaConf
|
15 |
+
from .layers import Block
|
16 |
+
|
17 |
+
class Transformer1d(nn.Module):
|
18 |
+
|
19 |
+
def __init__(self,
|
20 |
+
vocab_size_txt: int,
|
21 |
+
vocab_size_img: int,
|
22 |
+
hparams: OmegaConf) -> None:
|
23 |
+
super().__init__()
|
24 |
+
assert hparams.n_layers == hparams.n_dense_layers
|
25 |
+
|
26 |
+
# input embedding for image and text
|
27 |
+
self.tok_emb_img = nn.Embedding(vocab_size_img, hparams.embed_dim)
|
28 |
+
self.tok_emb_txt = nn.Embedding(vocab_size_txt, hparams.embed_dim)
|
29 |
+
|
30 |
+
self.pos_emb_img = nn.Embedding(hparams.ctx_len_img, hparams.embed_dim)
|
31 |
+
self.pos_emb_txt = nn.Embedding(hparams.ctx_len_txt, hparams.embed_dim)
|
32 |
+
|
33 |
+
self.drop = nn.Dropout(hparams.embd_pdrop)
|
34 |
+
|
35 |
+
# transformer blocks
|
36 |
+
self.blocks = [Block(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt,
|
37 |
+
embed_dim=hparams.embed_dim,
|
38 |
+
n_heads=hparams.n_heads,
|
39 |
+
mlp_bias=hparams.mlp_bias,
|
40 |
+
attn_bias=hparams.attn_bias,
|
41 |
+
resid_pdrop=hparams.resid_pdrop,
|
42 |
+
attn_pdrop=hparams.attn_pdrop,
|
43 |
+
gelu_use_approx=hparams.gelu_use_approx) for i in range(1, hparams.n_layers+1)]
|
44 |
+
self.blocks = nn.Sequential(*self.blocks)
|
45 |
+
|
46 |
+
# heads for image and text
|
47 |
+
self.ln_f = nn.LayerNorm(hparams.embed_dim)
|
48 |
+
self.head_img = nn.Linear(hparams.embed_dim, vocab_size_img, bias=False)
|
49 |
+
self.head_txt = nn.Linear(hparams.embed_dim, vocab_size_txt, bias=False)
|
50 |
+
|
51 |
+
self.ctx_len_img = hparams.ctx_len_img
|
52 |
+
self.ctx_len_txt = hparams.ctx_len_txt
|
53 |
+
self.n_layers = hparams.n_layers
|
54 |
+
|
55 |
+
self.apply(self._init_weights)
|
56 |
+
|
57 |
+
|
58 |
+
def _init_weights(self, module: nn.Module) -> None:
|
59 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
60 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
61 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
62 |
+
module.bias.data.zero_()
|
63 |
+
elif isinstance(module, nn.LayerNorm):
|
64 |
+
module.bias.data.zero_()
|
65 |
+
module.weight.data.fill_(1.0)
|
66 |
+
|
67 |
+
|
68 |
+
def resize_token_embeddings(self, new_num_tokens):
|
69 |
+
|
70 |
+
old_num_tokens, old_embedding_dim = self.tok_emb_txt.weight.size()
|
71 |
+
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
|
72 |
+
new_embeddings.to(self.tok_emb_txt.weight.device, dtype=self.tok_emb_txt.weight.dtype)
|
73 |
+
self._init_weights(new_embeddings)
|
74 |
+
# numbers of tokens to copy
|
75 |
+
n = min(old_num_tokens, new_num_tokens)
|
76 |
+
new_embeddings.weight.data[:n, :] = self.tok_emb_txt.weight.data[:n, :]
|
77 |
+
self.tok_emb_txt = new_embeddings
|
78 |
+
|
79 |
+
self.resize_lm_head(new_num_tokens)
|
80 |
+
# TODO: also change config to reflect new vocab size
|
81 |
+
|
82 |
+
return new_embeddings
|
83 |
+
|
84 |
+
|
85 |
+
def resize_lm_head(
|
86 |
+
self, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False) -> nn.Linear:
|
87 |
+
|
88 |
+
old_num_tokens, old_lm_head_dim = (
|
89 |
+
self.head_txt.weight.size() if not transposed else self.head_txt.weight.t().size()
|
90 |
+
)
|
91 |
+
# Build new lm head
|
92 |
+
new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
|
93 |
+
has_new_lm_head_bias = self.head_txt.bias is not None
|
94 |
+
new_lm_head = nn.Linear(*new_lm_head_shape, bias=has_new_lm_head_bias)
|
95 |
+
new_lm_head = new_lm_head.to(self.head_txt.weight.device, dtype=self.head_txt.weight.dtype)
|
96 |
+
|
97 |
+
# initialize new lm head (in particular added tokens)
|
98 |
+
self._init_weights(new_lm_head)
|
99 |
+
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
|
100 |
+
# Copy old lm head weights to new lm head
|
101 |
+
if not transposed:
|
102 |
+
new_lm_head.weight.data[:num_tokens_to_copy, :] = self.head_txt.weight.data[:num_tokens_to_copy, :]
|
103 |
+
else:
|
104 |
+
new_lm_head.weight.data[:, :num_tokens_to_copy] = self.head_txt.weight.data[:, :num_tokens_to_copy]
|
105 |
+
|
106 |
+
# Copy bias weights to new lm head
|
107 |
+
if has_new_lm_head_bias:
|
108 |
+
new_lm_head.bias.data[:num_tokens_to_copy] = self.head_txt.bias.data[:num_tokens_to_copy]
|
109 |
+
|
110 |
+
self.head_txt = new_lm_head
|
111 |
+
|
112 |
+
return new_lm_head
|
113 |
+
|
114 |
+
|
115 |
+
def forward(self,
|
116 |
+
images: torch.LongTensor,
|
117 |
+
texts: torch.LongTensor,
|
118 |
+
pos_images: torch.LongTensor,
|
119 |
+
pos_texts: torch.LongTensor,
|
120 |
+
past: Optional[List[torch.Tensor]] = None,
|
121 |
+
prompt: Optional[List[torch.Tensor]] = None,
|
122 |
+
pos_prompt: Optional[List[torch.Tensor]] = None) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
123 |
+
|
124 |
+
|
125 |
+
B, T = images.shape
|
126 |
+
_, N = texts.shape
|
127 |
+
|
128 |
+
assert T <= self.ctx_len_img, "Already reached the maximum context length (image)."
|
129 |
+
assert N == self.ctx_len_txt, "Already reached the maximum context length (text)."
|
130 |
+
|
131 |
+
texts = self.tok_emb_txt(texts)
|
132 |
+
images = self.tok_emb_img(images)
|
133 |
+
|
134 |
+
texts = texts + self.pos_emb_txt(pos_texts)
|
135 |
+
images = images + self.pos_emb_img(pos_images)
|
136 |
+
|
137 |
+
if prompt is not None:
|
138 |
+
prompt = prompt + self.pos_emb_txt(pos_prompt)
|
139 |
+
texts = torch.cat([prompt, texts], dim=1).contiguous()
|
140 |
+
P = prompt.shape[1]
|
141 |
+
|
142 |
+
x = torch.cat([texts, images], dim=1).contiguous()
|
143 |
+
x = self.drop(x)
|
144 |
+
|
145 |
+
# x = self.blocks(x)
|
146 |
+
for i, block in enumerate(self.blocks):
|
147 |
+
x, _ = block.sample(x, layer_past=None if past is None else past[i])
|
148 |
+
|
149 |
+
x = self.ln_f(x)
|
150 |
+
|
151 |
+
if prompt is not None:
|
152 |
+
texts = x[:, P:N+P-1].contiguous()
|
153 |
+
images = x[:, N+P-1:-1].contiguous()
|
154 |
+
else:
|
155 |
+
texts = x[:, :N-1].contiguous()
|
156 |
+
images = x[:, N-1:-1].contiguous()
|
157 |
+
|
158 |
+
logits_txt = self.head_txt(texts)
|
159 |
+
logits_img = self.head_img(images)
|
160 |
+
return logits_img, logits_txt
|
161 |
+
|
162 |
+
def forward_with_context(self,
|
163 |
+
images: torch.LongTensor,
|
164 |
+
texts: torch.LongTensor,
|
165 |
+
pos_images: torch.LongTensor,
|
166 |
+
pos_texts: torch.LongTensor,
|
167 |
+
src_images: torch.LongTensor,
|
168 |
+
src_pos_images: torch.LongTensor,
|
169 |
+
cross_attention_idxs: List,
|
170 |
+
cross_attention_layers,
|
171 |
+
past: Optional[List[torch.Tensor]] = None,
|
172 |
+
prompt: Optional[List[torch.Tensor]] = None,
|
173 |
+
pos_prompt: Optional[List[torch.Tensor]] = None) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
174 |
+
|
175 |
+
|
176 |
+
B, T = images.shape
|
177 |
+
_, N = texts.shape
|
178 |
+
|
179 |
+
assert T <= self.ctx_len_img, "Already reached the maximum context length (image)."
|
180 |
+
assert N == self.ctx_len_txt, "Already reached the maximum context length (text)."
|
181 |
+
|
182 |
+
texts = self.tok_emb_txt(texts)
|
183 |
+
images = self.tok_emb_img(images)
|
184 |
+
src_images = self.tok_emb_img(src_images)
|
185 |
+
|
186 |
+
texts = texts + self.pos_emb_txt(pos_texts)
|
187 |
+
images = images + self.pos_emb_img(pos_images)
|
188 |
+
src_images = src_images + self.pos_emb_img(src_pos_images)
|
189 |
+
|
190 |
+
if prompt is not None:
|
191 |
+
prompt = prompt + self.pos_emb_txt(pos_prompt)
|
192 |
+
texts = torch.cat([prompt, texts], dim=1).contiguous()
|
193 |
+
P = prompt.shape[1]
|
194 |
+
else:
|
195 |
+
P = 0
|
196 |
+
|
197 |
+
x = torch.cat([texts, images], axis=1).contiguous()
|
198 |
+
x = self.drop(x)
|
199 |
+
|
200 |
+
# prepare mask
|
201 |
+
mask = torch.zeros_like(x[0])
|
202 |
+
mask[self.ctx_len_txt+P-1:, :].fill_(1.0)
|
203 |
+
mask = mask.unsqueeze(0)
|
204 |
+
|
205 |
+
# print(images.shape, texts.shape, src_images.shape, mask.shape, x.shape)
|
206 |
+
|
207 |
+
# x = self.blocks(x)
|
208 |
+
for i, block in enumerate(self.blocks):
|
209 |
+
if i in cross_attention_idxs:
|
210 |
+
x, _ = block.sample_with_context(x, src_images, mask, cross_attention_layers[int(((i+1)/3)-1)], layer_past=None if past is None else past[i])
|
211 |
+
else:
|
212 |
+
x, _ = block.sample(x, layer_past=None if past is None else past[i])
|
213 |
+
|
214 |
+
x = self.ln_f(x)
|
215 |
+
|
216 |
+
if prompt is not None:
|
217 |
+
texts = x[:, P:N+P-1].contiguous()
|
218 |
+
images = x[:, N+P-1:-1].contiguous()
|
219 |
+
else:
|
220 |
+
texts = x[:, :N-1].contiguous()
|
221 |
+
images = x[:, N-1:-1].contiguous()
|
222 |
+
|
223 |
+
logits_txt = self.head_txt(texts)
|
224 |
+
logits_img = self.head_img(images)
|
225 |
+
return logits_img, logits_txt
|
226 |
+
|
227 |
+
@torch.no_grad()
|
228 |
+
def sampling(self,
|
229 |
+
images: torch.LongTensor,
|
230 |
+
texts: torch.LongTensor,
|
231 |
+
pos_images: torch.LongTensor,
|
232 |
+
pos_texts: torch.LongTensor,
|
233 |
+
use_fp16: bool = True,
|
234 |
+
past: Optional[List[torch.Tensor]] = None,
|
235 |
+
prompt: Optional[List[torch.Tensor]] = None,
|
236 |
+
pos_prompt: Optional[List[torch.Tensor]] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
|
237 |
+
|
238 |
+
_, N = texts.shape
|
239 |
+
assert N == self.ctx_len_txt, "Already reached the maximum context length (text)."
|
240 |
+
|
241 |
+
with autocast(enabled=use_fp16):
|
242 |
+
if images is None:
|
243 |
+
# assert past is None
|
244 |
+
|
245 |
+
texts = self.tok_emb_txt(texts)
|
246 |
+
x = texts + self.pos_emb_txt(pos_texts)
|
247 |
+
|
248 |
+
if prompt is not None:
|
249 |
+
prompt = prompt + self.pos_emb_txt(pos_prompt)
|
250 |
+
texts = torch.cat([prompt, texts], dim=1).contiguous()
|
251 |
+
|
252 |
+
x = self.drop(x)
|
253 |
+
|
254 |
+
if past is not None:
|
255 |
+
past = torch.cat(past, dim=-2)
|
256 |
+
|
257 |
+
presents = []
|
258 |
+
for i, block in enumerate(self.blocks):
|
259 |
+
x, present = block.sample(x, layer_past=None if past is None else past[i])
|
260 |
+
presents.append(present)
|
261 |
+
x = self.ln_f(x)
|
262 |
+
x = x[:, N-1].contiguous()
|
263 |
+
logits = self.head_img(x)
|
264 |
+
else:
|
265 |
+
if past is None:
|
266 |
+
texts = self.tok_emb_txt(texts)
|
267 |
+
images = self.tok_emb_img(images)
|
268 |
+
texts = texts + self.pos_emb_txt(pos_texts)
|
269 |
+
images = images + self.pos_emb_img(pos_images)
|
270 |
+
|
271 |
+
if prompt is not None:
|
272 |
+
prompt = prompt + self.pos_emb_txt(pos_prompt)
|
273 |
+
texts = torch.cat([prompt, texts], dim=1).contiguous()
|
274 |
+
|
275 |
+
x = torch.cat([texts, images], axis=1).contiguous()
|
276 |
+
else:
|
277 |
+
images = self.tok_emb_img(images)
|
278 |
+
x = images + self.pos_emb_img(pos_images)
|
279 |
+
x = self.drop(x)
|
280 |
+
|
281 |
+
# if past is not None and len(past) > 1:
|
282 |
+
if past is not None:
|
283 |
+
past = torch.cat(past, dim=-2)
|
284 |
+
# print('Past', past.shape)
|
285 |
+
presents = []
|
286 |
+
# print(len(past), past[0].shape)
|
287 |
+
for i, block in enumerate(self.blocks):
|
288 |
+
x, present = block.sample(x, layer_past=None if past is None else past[i])
|
289 |
+
presents.append(present)
|
290 |
+
x = self.ln_f(x)
|
291 |
+
x = x[:, -1].contiguous()
|
292 |
+
logits = self.head_img(x)
|
293 |
+
return logits, presents
|
294 |
+
|
295 |
+
@torch.no_grad()
|
296 |
+
def sampling_with_context(self,
|
297 |
+
images: torch.LongTensor,
|
298 |
+
cross_attention_idxs,
|
299 |
+
cross_attention_layers,
|
300 |
+
texts: torch.LongTensor,
|
301 |
+
pos_images: torch.LongTensor,
|
302 |
+
pos_texts: torch.LongTensor,
|
303 |
+
source_image: torch.LongTensor,
|
304 |
+
use_fp16: bool = True,
|
305 |
+
past: Optional[List[torch.Tensor]] = None,
|
306 |
+
prompt: Optional[List[torch.Tensor]] = None,
|
307 |
+
pos_prompt: Optional[List[torch.Tensor]] = None
|
308 |
+
) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
|
309 |
+
|
310 |
+
_, N = texts.shape
|
311 |
+
assert N == self.ctx_len_txt, "Already reached the maximum context length (text)."
|
312 |
+
|
313 |
+
if prompt is not None:
|
314 |
+
P = prompt.shape[1]
|
315 |
+
else:
|
316 |
+
P = 0
|
317 |
+
|
318 |
+
with autocast(enabled=use_fp16):
|
319 |
+
if images is None:
|
320 |
+
# assert past is None
|
321 |
+
|
322 |
+
texts = self.tok_emb_txt(texts)
|
323 |
+
texts = texts + self.pos_emb_txt(pos_texts)
|
324 |
+
|
325 |
+
if prompt is not None:
|
326 |
+
prompt = prompt + self.pos_emb_txt(pos_prompt)
|
327 |
+
texts = torch.cat([prompt, texts], dim=1).contiguous()
|
328 |
+
|
329 |
+
x = self.drop(texts)
|
330 |
+
|
331 |
+
if past is not None:
|
332 |
+
past = torch.cat(past, dim=-2)
|
333 |
+
|
334 |
+
# prepare mask
|
335 |
+
mask = torch.zeros_like(x[0])
|
336 |
+
mask[self.ctx_len_txt+P - 1:, :].fill_(1.0)
|
337 |
+
mask = mask.unsqueeze(0)
|
338 |
+
|
339 |
+
presents = []
|
340 |
+
for i, block in enumerate(self.blocks):
|
341 |
+
if i in cross_attention_idxs:
|
342 |
+
x, present = block.sample_with_context(x, source_image, mask,
|
343 |
+
cross_attention_layers[int(((i + 1) / 3) - 1)],
|
344 |
+
layer_past=None if past is None else past[i])
|
345 |
+
else:
|
346 |
+
x, present = block.sample(x, layer_past=None if past is None else past[i])
|
347 |
+
presents.append(present)
|
348 |
+
x = self.ln_f(x)
|
349 |
+
x = x[:, N-1].contiguous()
|
350 |
+
logits = self.head_img(x)
|
351 |
+
else:
|
352 |
+
if past is None:
|
353 |
+
texts = self.tok_emb_txt(texts)
|
354 |
+
images = self.tok_emb_img(images)
|
355 |
+
texts = texts + self.pos_emb_txt(pos_texts)
|
356 |
+
images = images + self.pos_emb_img(pos_images)
|
357 |
+
|
358 |
+
if prompt is not None:
|
359 |
+
prompt = prompt + self.pos_emb_txt(pos_prompt)
|
360 |
+
texts = torch.cat([prompt, texts], dim=1).contiguous()
|
361 |
+
|
362 |
+
x = torch.cat([texts, images], axis=1).contiguous()
|
363 |
+
else:
|
364 |
+
images = self.tok_emb_img(images)
|
365 |
+
x = images + self.pos_emb_img(pos_images)
|
366 |
+
x = self.drop(x)
|
367 |
+
|
368 |
+
# if past is not None and len(past) > 1:
|
369 |
+
if past is not None:
|
370 |
+
past = torch.cat(past, dim=-2)
|
371 |
+
presents = []
|
372 |
+
|
373 |
+
# prepare mask
|
374 |
+
mask = torch.zeros_like(x[0])
|
375 |
+
mask[self.ctx_len_txt+P - 1:, :].fill_(1.0)
|
376 |
+
mask = mask.unsqueeze(0)
|
377 |
+
|
378 |
+
# print(len(past), past[0].shape)
|
379 |
+
for i, block in enumerate(self.blocks):
|
380 |
+
if i in cross_attention_idxs:
|
381 |
+
x, present = block.sample_with_context(x, source_image, mask,
|
382 |
+
cross_attention_layers[int(((i + 1) / 3) - 1)],
|
383 |
+
layer_past=None if past is None else past[i])
|
384 |
+
else:
|
385 |
+
x, present = block.sample(x, layer_past=None if past is None else past[i])
|
386 |
+
presents.append(present)
|
387 |
+
x = self.ln_f(x)
|
388 |
+
x = x[:, -1].contiguous()
|
389 |
+
logits = self.head_img(x)
|
390 |
+
return logits, presents
|
391 |
+
|
392 |
+
def from_ckpt(self, path: str) -> None:
|
393 |
+
ckpt = torch.load(path, map_location='cpu')['state_dict']
|
394 |
+
self.load_state_dict(ckpt, strict=True)
|
395 |
+
print(f'{path} succesfully restored..')
|
396 |
+
|
397 |
+
|
398 |
+
class iGPT(nn.Module):
|
399 |
+
def __init__(self,
|
400 |
+
vocab_size_img: int,
|
401 |
+
use_cls_cond: bool,
|
402 |
+
hparams: OmegaConf) -> None:
|
403 |
+
super().__init__()
|
404 |
+
self.use_cls_cond = use_cls_cond
|
405 |
+
|
406 |
+
# sos token embedding
|
407 |
+
if self.use_cls_cond:
|
408 |
+
self.sos = nn.Embedding(hparams.n_classes, hparams.embed_dim)
|
409 |
+
else:
|
410 |
+
self.sos = nn.Parameter(torch.randn(1, 1, hparams.embed_dim))
|
411 |
+
|
412 |
+
# input embedding
|
413 |
+
self.tok_emb_img = nn.Embedding(vocab_size_img, hparams.embed_dim)
|
414 |
+
self.pos_emb_img = nn.Embedding(hparams.ctx_len_img, hparams.embed_dim)
|
415 |
+
|
416 |
+
self.drop = nn.Dropout(hparams.embd_pdrop)
|
417 |
+
|
418 |
+
# transformer blocks
|
419 |
+
self.blocks = [Block(ctx_len=hparams.ctx_len_img + 1,
|
420 |
+
embed_dim=hparams.embed_dim,
|
421 |
+
n_heads=hparams.n_heads,
|
422 |
+
mlp_bias=hparams.mlp_bias,
|
423 |
+
attn_bias=hparams.attn_bias,
|
424 |
+
resid_pdrop=hparams.resid_pdrop,
|
425 |
+
attn_pdrop=hparams.attn_pdrop,
|
426 |
+
gelu_use_approx=hparams.gelu_use_approx) for i in range(1, hparams.n_layers+1)]
|
427 |
+
self.blocks = nn.Sequential(*self.blocks)
|
428 |
+
|
429 |
+
# head
|
430 |
+
self.ln_f = nn.LayerNorm(hparams.embed_dim)
|
431 |
+
self.head = nn.Linear(hparams.embed_dim, vocab_size_img, bias=False)
|
432 |
+
|
433 |
+
self.ctx_len_img = hparams.ctx_len_img
|
434 |
+
self.n_layers = hparams.n_layers
|
435 |
+
|
436 |
+
self.apply(self._init_weights)
|
437 |
+
|
438 |
+
def _init_weights(self, module: nn.Module) -> None:
|
439 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
440 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
441 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
442 |
+
module.bias.data.zero_()
|
443 |
+
elif isinstance(module, nn.LayerNorm):
|
444 |
+
module.bias.data.zero_()
|
445 |
+
module.weight.data.fill_(1.0)
|
446 |
+
|
447 |
+
@torch.no_grad()
|
448 |
+
def sampling(self,
|
449 |
+
sos: torch.FloatTensor,
|
450 |
+
codes: torch.LongTensor,
|
451 |
+
pos_codes: torch.LongTensor,
|
452 |
+
n_samples: int = 16,
|
453 |
+
use_fp16: bool = True,
|
454 |
+
past: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
|
455 |
+
with autocast(enabled=use_fp16):
|
456 |
+
if codes is None:
|
457 |
+
assert past is None
|
458 |
+
xs = self.drop(sos)
|
459 |
+
presents = []
|
460 |
+
for i, block in enumerate(self.blocks):
|
461 |
+
xs, present = block.sample(xs, layer_past=None)
|
462 |
+
presents.append(present)
|
463 |
+
xs = self.ln_f(xs)
|
464 |
+
logits = self.head(xs)[:, -1]
|
465 |
+
else:
|
466 |
+
if past is None:
|
467 |
+
xs = self.tok_emb_img(codes) + self.pos_emb_img(pos_codes)
|
468 |
+
xs = torch.cat([sos, xs], dim=1)
|
469 |
+
else:
|
470 |
+
xs = self.tok_emb_img(codes) + self.pos_emb_img(pos_codes)
|
471 |
+
xs = self.drop(xs)
|
472 |
+
|
473 |
+
past = torch.cat(past, dim=-2) if past is not None else past
|
474 |
+
presents = []
|
475 |
+
for i, block in enumerate(self.blocks):
|
476 |
+
xs, present = block.sample(xs, layer_past=None if past is None else past[i])
|
477 |
+
presents.append(present)
|
478 |
+
|
479 |
+
xs = self.ln_f(xs)
|
480 |
+
logits = self.head(xs)[:, -1]
|
481 |
+
return logits, presents
|
482 |
+
|
483 |
+
def forward(self,
|
484 |
+
codes: torch.LongTensor,
|
485 |
+
labels: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
|
486 |
+
B, T = codes.shape
|
487 |
+
xps = torch.arange(T, device=codes.device).repeat((B, 1))
|
488 |
+
sos = self.sos.repeat((B, 1, 1)) if labels is None else self.sos(labels).unsqueeze(1)
|
489 |
+
|
490 |
+
h = self.tok_emb_img(codes) + self.pos_emb_img(xps)
|
491 |
+
h = torch.cat([sos, h[:, :-1]], dim=1).contiguous()
|
492 |
+
|
493 |
+
h = self.drop(h)
|
494 |
+
h = self.blocks(h)
|
495 |
+
h = self.ln_f(h)
|
496 |
+
logits = self.head(h)
|
497 |
+
return logits
|
498 |
+
|
499 |
+
def from_ckpt(self, path: str, strict: bool = True) -> None:
|
500 |
+
ckpt = torch.load(path, map_location='cpu')['state_dict']
|
501 |
+
self.load_state_dict(ckpt, strict=strict)
|
502 |
+
print(f'{path} successfully restored..')
|
dalle/models/tokenizer.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------
|
2 |
+
# Minimal DALL-E
|
3 |
+
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------------
|
6 |
+
|
7 |
+
import os
|
8 |
+
from functools import partial
|
9 |
+
from tokenizers import CharBPETokenizer
|
10 |
+
|
11 |
+
|
12 |
+
def build_tokenizer(path: str,
|
13 |
+
context_length: int = 64,
|
14 |
+
*args,
|
15 |
+
**kwargs):
|
16 |
+
try:
|
17 |
+
from_file = partial(CharBPETokenizer.from_file,
|
18 |
+
vocab_filename=os.path.join(path, 'bpe-16k-vocab.json'),
|
19 |
+
merges_filename=os.path.join(path, 'bpe-16k-merges.txt'),
|
20 |
+
unk_token='[UNK]')
|
21 |
+
tokenizer = from_file(*args, **kwargs)
|
22 |
+
except:
|
23 |
+
from_file = partial(CharBPETokenizer.from_file,
|
24 |
+
vocab_filename=os.path.join(path, 'vocab.json'),
|
25 |
+
merges_filename=os.path.join(path, 'merges.txt'),
|
26 |
+
unk_token='[UNK]')
|
27 |
+
tokenizer = from_file(*args, **kwargs)
|
28 |
+
|
29 |
+
# tokenizer = from_file(*args, **kwargs)
|
30 |
+
tokenizer.add_special_tokens(['[PAD]'])
|
31 |
+
tokenizer.enable_padding(length=context_length,
|
32 |
+
pad_id=tokenizer.token_to_id('[PAD]'))
|
33 |
+
tokenizer.enable_truncation(max_length=context_length)
|
34 |
+
print(f'{path} successfully restored..')
|
35 |
+
return tokenizer
|
dalle/trainer_prefix.py
ADDED
@@ -0,0 +1,1629 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
import json
|
3 |
+
import math
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
import shutil
|
7 |
+
import warnings
|
8 |
+
from contextlib import contextmanager
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
11 |
+
|
12 |
+
from nltk import word_tokenize
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
from packaging import version
|
16 |
+
from torch import nn
|
17 |
+
from torch.utils.data.dataloader import DataLoader
|
18 |
+
from torch.utils.data.dataset import Dataset
|
19 |
+
from torch.utils.data.distributed import DistributedSampler
|
20 |
+
from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
|
21 |
+
from tqdm.auto import tqdm, trange
|
22 |
+
from torch.nn.utils.rnn import pad_sequence
|
23 |
+
import random
|
24 |
+
|
25 |
+
from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
26 |
+
from transformers.file_utils import is_datasets_available, is_torch_tpu_available
|
27 |
+
from transformers.integrations import (
|
28 |
+
default_hp_search_backend,
|
29 |
+
is_comet_available,
|
30 |
+
is_optuna_available,
|
31 |
+
is_ray_available,
|
32 |
+
is_tensorboard_available,
|
33 |
+
is_wandb_available,
|
34 |
+
run_hp_search_optuna,
|
35 |
+
run_hp_search_ray,
|
36 |
+
)
|
37 |
+
|
38 |
+
from transformers.modeling_utils import PreTrainedModel
|
39 |
+
from transformers.optimization import AdamW, get_linear_schedule_with_warmup, get_constant_schedule_with_warmup
|
40 |
+
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
41 |
+
from transformers.trainer_utils import (
|
42 |
+
PREFIX_CHECKPOINT_DIR,
|
43 |
+
BestRun,
|
44 |
+
EvalPrediction,
|
45 |
+
EvaluationStrategy,
|
46 |
+
HPSearchBackend,
|
47 |
+
PredictionOutput,
|
48 |
+
TrainOutput,
|
49 |
+
default_compute_objective,
|
50 |
+
default_hp_space,
|
51 |
+
set_seed,
|
52 |
+
)
|
53 |
+
from transformers.training_args import TrainingArguments
|
54 |
+
from transformers.utils import logging
|
55 |
+
|
56 |
+
|
57 |
+
_use_native_amp = False
|
58 |
+
_use_apex = False
|
59 |
+
EPS = 1e-12
|
60 |
+
INIT_GUMBEL_TEMP = 5.0
|
61 |
+
|
62 |
+
control_lst = ['positive', 'negative', 'neutral']
|
63 |
+
Control_Temp = {'positive': 3967, 'negative':4633, 'neutral':8500}
|
64 |
+
control_Map = [torch.LongTensor([3967]), torch.LongTensor([4633]), torch.LongTensor([8500])]
|
65 |
+
sst_lst = [(0, 2), (1, 3), (4,)]
|
66 |
+
sst_standard = ["positive", "negative", "very positive", "very negative", "neutral"]
|
67 |
+
# Control_?Map = {j:i for i, j in enumerate(control_lst)}
|
68 |
+
|
69 |
+
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
|
70 |
+
if version.parse(torch.__version__) < version.parse("1.6"):
|
71 |
+
from transformers.file_utils import is_apex_available
|
72 |
+
|
73 |
+
if is_apex_available():
|
74 |
+
from apex import amp
|
75 |
+
_use_apex = True
|
76 |
+
else:
|
77 |
+
_use_native_amp = True
|
78 |
+
from torch.cuda.amp import autocast
|
79 |
+
|
80 |
+
if is_datasets_available():
|
81 |
+
import datasets
|
82 |
+
|
83 |
+
if is_torch_tpu_available():
|
84 |
+
import torch_xla.core.xla_model as xm
|
85 |
+
import torch_xla.debug.metrics as met
|
86 |
+
import torch_xla.distributed.parallel_loader as pl
|
87 |
+
|
88 |
+
if is_tensorboard_available():
|
89 |
+
try:
|
90 |
+
from torch.utils.tensorboard import SummaryWriter
|
91 |
+
except ImportError:
|
92 |
+
from tensorboardX import SummaryWriter
|
93 |
+
|
94 |
+
if is_wandb_available():
|
95 |
+
import wandb
|
96 |
+
|
97 |
+
if is_comet_available():
|
98 |
+
import comet_ml
|
99 |
+
|
100 |
+
if is_optuna_available():
|
101 |
+
import optuna
|
102 |
+
|
103 |
+
if is_ray_available():
|
104 |
+
from ray import tune
|
105 |
+
|
106 |
+
|
107 |
+
logger = logging.get_logger(__name__)
|
108 |
+
|
109 |
+
|
110 |
+
@contextmanager
|
111 |
+
def torch_distributed_zero_first(local_rank: int):
|
112 |
+
"""
|
113 |
+
Decorator to make all processes in distributed training wait for each local_master to do something.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
local_rank (:obj:`int`): The rank of the local process.
|
117 |
+
"""
|
118 |
+
if local_rank not in [-1, 0]:
|
119 |
+
torch.distributed.barrier()
|
120 |
+
yield
|
121 |
+
if local_rank == 0:
|
122 |
+
torch.distributed.barrier()
|
123 |
+
|
124 |
+
def helper_token2bpe(offsets):
|
125 |
+
full_lst = []
|
126 |
+
for example_offset in offsets:
|
127 |
+
bpe2token = []
|
128 |
+
token2bpe = []
|
129 |
+
token_idx = -1
|
130 |
+
# print(example_offset)
|
131 |
+
for bpe_idx, (a,b) in enumerate(example_offset):
|
132 |
+
# print(token2bpe, a, b, bpe_idx)
|
133 |
+
if b - a > 0:
|
134 |
+
if a == 0:
|
135 |
+
# new token
|
136 |
+
token_idx += 1
|
137 |
+
bpe2token.append(token_idx)
|
138 |
+
token2bpe.append([])
|
139 |
+
token2bpe[-1].append(bpe_idx)
|
140 |
+
else:
|
141 |
+
# prev token.
|
142 |
+
bpe2token.append(token_idx)
|
143 |
+
token2bpe[-1].append(bpe_idx)
|
144 |
+
else:
|
145 |
+
bpe2token.append(None)
|
146 |
+
full_lst.append((bpe2token, token2bpe))
|
147 |
+
return full_lst
|
148 |
+
|
149 |
+
class SequentialDistributedSampler(Sampler):
|
150 |
+
"""
|
151 |
+
Distributed Sampler that subsamples indicies sequentially,
|
152 |
+
making it easier to collate all results at the end.
|
153 |
+
|
154 |
+
Even though we only use this sampler for eval and predict (no training),
|
155 |
+
which means that the model params won't have to be synced (i.e. will not hang
|
156 |
+
for synchronization even if varied number of forward passes), we still add extra
|
157 |
+
samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
|
158 |
+
to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
|
159 |
+
"""
|
160 |
+
|
161 |
+
def __init__(self, dataset, num_replicas=None, rank=None):
|
162 |
+
if num_replicas is None:
|
163 |
+
if not torch.distributed.is_available():
|
164 |
+
raise RuntimeError("Requires distributed package to be available")
|
165 |
+
num_replicas = torch.distributed.get_world_size()
|
166 |
+
if rank is None:
|
167 |
+
if not torch.distributed.is_available():
|
168 |
+
raise RuntimeError("Requires distributed package to be available")
|
169 |
+
rank = torch.distributed.get_rank()
|
170 |
+
self.dataset = dataset
|
171 |
+
self.num_replicas = num_replicas
|
172 |
+
self.rank = rank
|
173 |
+
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
174 |
+
self.total_size = self.num_samples * self.num_replicas
|
175 |
+
|
176 |
+
def __iter__(self):
|
177 |
+
indices = list(range(len(self.dataset)))
|
178 |
+
|
179 |
+
# add extra samples to make it evenly divisible
|
180 |
+
indices += indices[: (self.total_size - len(indices))]
|
181 |
+
assert (
|
182 |
+
len(indices) == self.total_size
|
183 |
+
), f"Indices length {len(indices)} and total size {self.total_size} mismatched"
|
184 |
+
|
185 |
+
# subsample
|
186 |
+
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
|
187 |
+
assert (
|
188 |
+
len(indices) == self.num_samples
|
189 |
+
), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
|
190 |
+
|
191 |
+
return iter(indices)
|
192 |
+
|
193 |
+
def __len__(self):
|
194 |
+
return self.num_samples
|
195 |
+
|
196 |
+
|
197 |
+
def get_tpu_sampler(dataset: Dataset):
|
198 |
+
if xm.xrt_world_size() <= 1:
|
199 |
+
return RandomSampler(dataset)
|
200 |
+
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
|
201 |
+
|
202 |
+
|
203 |
+
class Trainer_Prefix:
|
204 |
+
"""
|
205 |
+
Trainer is a simple but feature-complete training and eval loop for PyTorch,
|
206 |
+
optimized for 🤗 Transformers.
|
207 |
+
|
208 |
+
Args:
|
209 |
+
model (:class:`~transformers.PreTrainedModel`, `optional`):
|
210 |
+
The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
|
211 |
+
args (:class:`~transformers.TrainingArguments`, `optional`):
|
212 |
+
The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
|
213 |
+
with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
|
214 |
+
data_collator (:obj:`DataCollator`, `optional`):
|
215 |
+
The function to use to form a batch from a list of elements of :obj:`train_dataset` or
|
216 |
+
:obj:`eval_dataset`. Will default to :func:`~transformers.default_data_collator` if no ``tokenizer`` is
|
217 |
+
provided, an instance of :func:`~transformers.DataCollatorWithPadding` otherwise.
|
218 |
+
train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
219 |
+
The dataset to use for training. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
220 |
+
``model.forward()`` method are automatically removed.
|
221 |
+
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
222 |
+
The dataset to use for evaluation. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
223 |
+
``model.forward()`` method are automatically removed.
|
224 |
+
tokenizer (:class:`PreTrainedTokenizerBase`, `optional`):
|
225 |
+
The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs the
|
226 |
+
maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
|
227 |
+
interrupted training or reuse the fine-tuned model.
|
228 |
+
model_init (:obj:`Callable[[], PreTrainedModel]`, `optional`):
|
229 |
+
A function that instantiates the model to be used. If provided, each call to
|
230 |
+
:meth:`~transformers.Trainer.train` will start from a new instance of the model as given by this function.
|
231 |
+
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
|
232 |
+
The function that will be used to compute metrics at evaluation. Must take a
|
233 |
+
:class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
|
234 |
+
tb_writer (:obj:`SummaryWriter`, `optional`):
|
235 |
+
Object to write to TensorBoard.
|
236 |
+
optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR`, `optional`):
|
237 |
+
A tuple containing the optimizer and the scheduler to use. Will default to an instance of
|
238 |
+
:class:`~transformers.AdamW` on your model and a scheduler given by
|
239 |
+
:func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
|
240 |
+
kwargs:
|
241 |
+
Deprecated keyword arguments.
|
242 |
+
"""
|
243 |
+
|
244 |
+
def __init__(
|
245 |
+
self,
|
246 |
+
model: Optional[PreTrainedModel] = None,
|
247 |
+
model_gpt2 : Optional[PreTrainedModel] = None,
|
248 |
+
args: TrainingArguments = None,
|
249 |
+
data_collator: Optional[DataCollator] = None,
|
250 |
+
train_dataset: Optional[Dataset] = None,
|
251 |
+
eval_dataset: Optional[Dataset] = None,
|
252 |
+
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
|
253 |
+
model_init: Callable[[], PreTrainedModel] = None,
|
254 |
+
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
255 |
+
tb_writer: Optional["SummaryWriter"] = None,
|
256 |
+
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
257 |
+
task_mode: Optional[str] = None,
|
258 |
+
use_dropout: Optional[bool] = False,
|
259 |
+
distill: Optional[bool] = False,
|
260 |
+
matching_objective:Optional[str]= None,
|
261 |
+
finetuned_gpt2: Optional[PreTrainedModel] = None,
|
262 |
+
**kwargs,
|
263 |
+
):
|
264 |
+
if args is None:
|
265 |
+
logger.info("No `TrainingArguments` passed, using the current path as `output_dir`.")
|
266 |
+
args = TrainingArguments("tmp_trainer")
|
267 |
+
self.args = args
|
268 |
+
# Seed must be set before instantiating the model when using model
|
269 |
+
set_seed(self.args.seed)
|
270 |
+
assert (
|
271 |
+
model is not None or model_init is not None
|
272 |
+
), "You must provide a model to use `Trainer`, either by using the `model` argument or the `model_init` argument."
|
273 |
+
assert model_init is None
|
274 |
+
self.model = model.to(args.device) if model is not None else None
|
275 |
+
self.gpt2 = model_gpt2.to(args.device) if model_gpt2 is not None else None
|
276 |
+
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
|
277 |
+
self.data_collator = data_collator if data_collator is not None else default_collator
|
278 |
+
self.train_dataset = train_dataset
|
279 |
+
self.eval_dataset = eval_dataset
|
280 |
+
self.tokenizer = tokenizer
|
281 |
+
self.model_init = model_init
|
282 |
+
self.compute_metrics = compute_metrics
|
283 |
+
self.optimizer, self.lr_scheduler = optimizers
|
284 |
+
self.task_mode = task_mode
|
285 |
+
self.use_dropout = use_dropout
|
286 |
+
|
287 |
+
self.curr_best_eval = 10000000.
|
288 |
+
|
289 |
+
self.distill = distill
|
290 |
+
if self.distill:
|
291 |
+
self.matching_objective = matching_objective
|
292 |
+
self.finetuned_gpt2 = finetuned_gpt2
|
293 |
+
|
294 |
+
if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
|
295 |
+
raise RuntimeError(
|
296 |
+
"Passing a `model_init` is incompatible with providing the `optimizers` argument."
|
297 |
+
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
298 |
+
)
|
299 |
+
self.tb_writer = tb_writer
|
300 |
+
self.log_history = []
|
301 |
+
if "prediction_loss_only" in kwargs:
|
302 |
+
warnings.warn(
|
303 |
+
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
|
304 |
+
FutureWarning,
|
305 |
+
)
|
306 |
+
self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
|
307 |
+
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
308 |
+
|
309 |
+
if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
|
310 |
+
self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
|
311 |
+
if not is_tensorboard_available():
|
312 |
+
logger.warning(
|
313 |
+
"You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
|
314 |
+
)
|
315 |
+
|
316 |
+
# Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
|
317 |
+
self._loggers_initialized = False
|
318 |
+
|
319 |
+
# Create output directory if needed
|
320 |
+
if self.is_world_process_zero():
|
321 |
+
os.makedirs(self.args.output_dir, exist_ok=True)
|
322 |
+
if is_torch_tpu_available():
|
323 |
+
# Set an xla_device flag on the model's config.
|
324 |
+
# We'll find a more elegant and not need to do this in the future.
|
325 |
+
self.model.config.xla_device = True
|
326 |
+
if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
|
327 |
+
self.data_collator = self.data_collator.collate_batch
|
328 |
+
warnings.warn(
|
329 |
+
(
|
330 |
+
"The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
|
331 |
+
+ "with a `collate_batch` are deprecated and won't be supported in a future version."
|
332 |
+
),
|
333 |
+
FutureWarning,
|
334 |
+
)
|
335 |
+
|
336 |
+
if is_datasets_available():
|
337 |
+
if isinstance(train_dataset, datasets.Dataset):
|
338 |
+
self._remove_unused_columns(self.train_dataset, description="training")
|
339 |
+
if isinstance(eval_dataset, datasets.Dataset):
|
340 |
+
self._remove_unused_columns(self.eval_dataset, description="evaluation")
|
341 |
+
|
342 |
+
self.global_step = None
|
343 |
+
self.epoch = None
|
344 |
+
self.total_flos = None
|
345 |
+
if self.args.fp16 and _use_native_amp:
|
346 |
+
self.scaler = torch.cuda.amp.GradScaler()
|
347 |
+
self.hp_search_backend = None
|
348 |
+
self.use_tune_checkpoints = False
|
349 |
+
if self.args.label_names is None:
|
350 |
+
self.args.label_names = (["labels"]
|
351 |
+
)
|
352 |
+
|
353 |
+
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
|
354 |
+
if not self.args.remove_unused_columns:
|
355 |
+
return
|
356 |
+
# Inspect model forward signature to keep only the arguments it accepts.
|
357 |
+
signature = inspect.signature(self.model.forward)
|
358 |
+
signature_columns = list(signature.parameters.keys())
|
359 |
+
# Labels may be named label or label_ids, the default data collator handles that.
|
360 |
+
signature_columns += ["label", "label_ids"]
|
361 |
+
columns = [k for k in signature_columns if k in dataset.column_names]
|
362 |
+
ignored_columns = list(set(dataset.column_names) - set(signature_columns))
|
363 |
+
dset_description = "" if description is None else f"in the {description} set "
|
364 |
+
logger.info(
|
365 |
+
f"The following columns {dset_description}don't have a corresponding argument in `{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
|
366 |
+
)
|
367 |
+
dataset.set_format(type=dataset.format["type"], columns=columns)
|
368 |
+
|
369 |
+
def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
|
370 |
+
if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
|
371 |
+
return None
|
372 |
+
elif is_torch_tpu_available():
|
373 |
+
return get_tpu_sampler(self.train_dataset)
|
374 |
+
else:
|
375 |
+
return (
|
376 |
+
RandomSampler(self.train_dataset)
|
377 |
+
if self.args.local_rank == -1
|
378 |
+
else DistributedSampler(self.train_dataset)
|
379 |
+
)
|
380 |
+
|
381 |
+
def get_train_dataloader(self) -> DataLoader:
|
382 |
+
"""
|
383 |
+
Returns the training :class:`~torch.utils.data.DataLoader`.
|
384 |
+
|
385 |
+
Will use no sampler if :obj:`self.train_dataset` is a :obj:`torch.utils.data.IterableDataset`, a random sampler
|
386 |
+
(adapted to distributed training if necessary) otherwise.
|
387 |
+
|
388 |
+
Subclass and override this method if you want to inject some custom behavior.
|
389 |
+
"""
|
390 |
+
if self.train_dataset is None:
|
391 |
+
raise ValueError("Trainer: training requires a train_dataset.")
|
392 |
+
train_sampler = self._get_train_sampler()
|
393 |
+
|
394 |
+
return DataLoader(
|
395 |
+
self.train_dataset,
|
396 |
+
batch_size=self.args.train_batch_size,
|
397 |
+
sampler=train_sampler,
|
398 |
+
collate_fn=self.data_collator,
|
399 |
+
drop_last=self.args.dataloader_drop_last,
|
400 |
+
num_workers=self.args.dataloader_num_workers,
|
401 |
+
worker_init_fn=np.random.seed(self.args.seed)
|
402 |
+
)
|
403 |
+
|
404 |
+
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
|
405 |
+
if isinstance(eval_dataset, torch.utils.data.IterableDataset):
|
406 |
+
return None
|
407 |
+
elif is_torch_tpu_available():
|
408 |
+
return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
|
409 |
+
elif self.args.local_rank != -1:
|
410 |
+
return SequentialDistributedSampler(eval_dataset)
|
411 |
+
else:
|
412 |
+
return SequentialSampler(eval_dataset)
|
413 |
+
|
414 |
+
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
415 |
+
"""
|
416 |
+
Returns the evaluation :class:`~torch.utils.data.DataLoader`.
|
417 |
+
|
418 |
+
Will use no sampler if :obj:`self.eval_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
|
419 |
+
sampler (adapted to distributed training if necessary) otherwise.
|
420 |
+
|
421 |
+
Subclass and override this method if you want to inject some custom behavior.
|
422 |
+
|
423 |
+
Args:
|
424 |
+
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
425 |
+
If provided, will override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`, columns not
|
426 |
+
accepted by the ``model.forward()`` method are automatically removed.
|
427 |
+
"""
|
428 |
+
if eval_dataset is None and self.eval_dataset is None:
|
429 |
+
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
430 |
+
elif eval_dataset is not None and is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
|
431 |
+
self._remove_unused_columns(eval_dataset, description="evaluation")
|
432 |
+
eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
|
433 |
+
eval_sampler = self._get_eval_sampler(eval_dataset)
|
434 |
+
|
435 |
+
return DataLoader(
|
436 |
+
eval_dataset,
|
437 |
+
sampler=eval_sampler,
|
438 |
+
batch_size=self.args.eval_batch_size,
|
439 |
+
collate_fn=self.data_collator,
|
440 |
+
drop_last=self.args.dataloader_drop_last,
|
441 |
+
num_workers=self.args.dataloader_num_workers,
|
442 |
+
worker_init_fn=np.random.seed(self.args.seed)
|
443 |
+
)
|
444 |
+
|
445 |
+
def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
|
446 |
+
"""
|
447 |
+
Returns the test :class:`~torch.utils.data.DataLoader`.
|
448 |
+
|
449 |
+
Will use no sampler if :obj:`test_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
|
450 |
+
sampler (adapted to distributed training if necessary) otherwise.
|
451 |
+
|
452 |
+
Subclass and override this method if you want to inject some custom behavior.
|
453 |
+
|
454 |
+
Args:
|
455 |
+
eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
|
456 |
+
The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
457 |
+
``model.forward()`` method are automatically removed.
|
458 |
+
"""
|
459 |
+
if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
|
460 |
+
self._remove_unused_columns(test_dataset, description="test")
|
461 |
+
test_sampler = self._get_eval_sampler(test_dataset)
|
462 |
+
|
463 |
+
# We use the same batch_size as for eval.
|
464 |
+
return DataLoader(
|
465 |
+
test_dataset,
|
466 |
+
sampler=test_sampler,
|
467 |
+
batch_size=self.args.eval_batch_size,
|
468 |
+
collate_fn=self.data_collator,
|
469 |
+
drop_last=self.args.dataloader_drop_last,
|
470 |
+
worker_init_fn=np.random.seed(self.args.seed)
|
471 |
+
)
|
472 |
+
|
473 |
+
def create_optimizer_and_scheduler(self, num_training_steps: int):
|
474 |
+
"""
|
475 |
+
Setup the optimizer and the learning rate scheduler.
|
476 |
+
|
477 |
+
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
478 |
+
Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
|
479 |
+
"""
|
480 |
+
if self.optimizer is None:
|
481 |
+
no_decay = ["bias", "LayerNorm.weight"]
|
482 |
+
optimizer_grouped_parameters = [
|
483 |
+
{
|
484 |
+
"params": [p for n, p in self.model.named_parameters() if (not any(nd in n for nd in no_decay)) and p.requires_grad],
|
485 |
+
"weight_decay": self.args.weight_decay,
|
486 |
+
},
|
487 |
+
{
|
488 |
+
"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
|
489 |
+
"weight_decay": 0.0,
|
490 |
+
},
|
491 |
+
]
|
492 |
+
|
493 |
+
self.optimizer = AdamW(
|
494 |
+
optimizer_grouped_parameters,
|
495 |
+
lr=self.args.learning_rate,
|
496 |
+
betas=(self.args.adam_beta1, self.args.adam_beta2),
|
497 |
+
eps=self.args.adam_epsilon,
|
498 |
+
)
|
499 |
+
|
500 |
+
|
501 |
+
# for n, p in self.model.named_parameters():
|
502 |
+
# print(n,p.requires_grad)
|
503 |
+
print(self.optimizer.state_dict())
|
504 |
+
if self.lr_scheduler is None:
|
505 |
+
self.lr_scheduler = get_linear_schedule_with_warmup(
|
506 |
+
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
|
507 |
+
)
|
508 |
+
|
509 |
+
|
510 |
+
def setup_wandb(self):
|
511 |
+
"""
|
512 |
+
Setup the optional Weights & Biases (`wandb`) integration.
|
513 |
+
|
514 |
+
One can subclass and override this method to customize the setup if needed. Find more information
|
515 |
+
`here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
|
516 |
+
|
517 |
+
Environment:
|
518 |
+
WANDB_WATCH:
|
519 |
+
(Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
|
520 |
+
or "all" to log gradients and parameters
|
521 |
+
WANDB_PROJECT:
|
522 |
+
(Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
|
523 |
+
WANDB_DISABLED:
|
524 |
+
(Optional): boolean - defaults to false, set to "true" to disable wandb entirely
|
525 |
+
"""
|
526 |
+
if hasattr(self, "_setup_wandb"):
|
527 |
+
warnings.warn(
|
528 |
+
"The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
|
529 |
+
FutureWarning,
|
530 |
+
)
|
531 |
+
return self._setup_wandb()
|
532 |
+
|
533 |
+
if self.is_world_process_zero():
|
534 |
+
logger.info(
|
535 |
+
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
|
536 |
+
)
|
537 |
+
try:
|
538 |
+
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
|
539 |
+
except AttributeError:
|
540 |
+
# in case the model has no config
|
541 |
+
combined_dict = {**self.args.to_sanitized_dict()}
|
542 |
+
wandb.init(
|
543 |
+
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
|
544 |
+
)
|
545 |
+
# keep track of model topology and gradients, unsupported on TPU
|
546 |
+
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
|
547 |
+
wandb.watch(
|
548 |
+
self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
|
549 |
+
)
|
550 |
+
|
551 |
+
def setup_comet(self):
|
552 |
+
"""
|
553 |
+
Setup the optional Comet.ml integration.
|
554 |
+
|
555 |
+
Environment:
|
556 |
+
COMET_MODE:
|
557 |
+
(Optional): str - "OFFLINE", "ONLINE", or "DISABLED"
|
558 |
+
COMET_PROJECT_NAME:
|
559 |
+
(Optional): str - Comet.ml project name for experiments
|
560 |
+
COMET_OFFLINE_DIRECTORY:
|
561 |
+
(Optional): str - folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"
|
562 |
+
|
563 |
+
For a number of configurable items in the environment,
|
564 |
+
see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__
|
565 |
+
"""
|
566 |
+
if self.is_world_master():
|
567 |
+
comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
|
568 |
+
args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
|
569 |
+
experiment = None
|
570 |
+
if comet_mode == "ONLINE":
|
571 |
+
experiment = comet_ml.Experiment(**args)
|
572 |
+
logger.info("Automatic Comet.ml online logging enabled")
|
573 |
+
elif comet_mode == "OFFLINE":
|
574 |
+
args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
|
575 |
+
experiment = comet_ml.OfflineExperiment(**args)
|
576 |
+
logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
|
577 |
+
if experiment is not None:
|
578 |
+
experiment._set_model_graph(self.model, framework="transformers")
|
579 |
+
experiment._log_parameters(self.args, prefix="args/", framework="transformers")
|
580 |
+
experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
|
581 |
+
|
582 |
+
def num_examples(self, dataloader: DataLoader) -> int:
|
583 |
+
"""
|
584 |
+
Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
|
585 |
+
"""
|
586 |
+
return len(dataloader.dataset)
|
587 |
+
|
588 |
+
def _setup_loggers(self):
|
589 |
+
if self._loggers_initialized:
|
590 |
+
return
|
591 |
+
if is_wandb_available():
|
592 |
+
self.setup_wandb()
|
593 |
+
elif os.environ.get("WANDB_DISABLED") != "true":
|
594 |
+
logger.info(
|
595 |
+
"You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
|
596 |
+
"run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
|
597 |
+
)
|
598 |
+
if is_comet_available():
|
599 |
+
self.setup_comet()
|
600 |
+
elif os.environ.get("COMET_MODE") != "DISABLED":
|
601 |
+
logger.info(
|
602 |
+
"To use comet_ml logging, run `pip/conda install comet_ml` "
|
603 |
+
"see https://www.comet.ml/docs/python-sdk/huggingface/"
|
604 |
+
)
|
605 |
+
self._loggers_initialized = True
|
606 |
+
|
607 |
+
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
|
608 |
+
""" HP search setup code """
|
609 |
+
if self.hp_search_backend is None or trial is None:
|
610 |
+
return
|
611 |
+
params = self.hp_space(trial) if self.hp_search_backend == HPSearchBackend.OPTUNA else trial
|
612 |
+
for key, value in params.items():
|
613 |
+
if not hasattr(self.args, key):
|
614 |
+
raise AttributeError(
|
615 |
+
f"Trying to set {key} in the hyperparameter search but there is no corresponding field in `TrainingArguments`."
|
616 |
+
)
|
617 |
+
old_attr = getattr(self.args, key, None)
|
618 |
+
# Casting value to the proper type
|
619 |
+
if old_attr is not None:
|
620 |
+
value = type(old_attr)(value)
|
621 |
+
setattr(self.args, key, value)
|
622 |
+
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
623 |
+
logger.info("Trial:", trial.params)
|
624 |
+
|
625 |
+
def _report_to_hp_search(
|
626 |
+
self, trial: Union["optuna.Trial", Dict[str, Any]], epoch: int, metrics: Dict[str, float]
|
627 |
+
):
|
628 |
+
if self.hp_search_backend is None or trial is None:
|
629 |
+
return
|
630 |
+
self.objective = self.compute_objective(metrics)
|
631 |
+
if self.hp_search_backend == HPSearchBackend.OPTUNA:
|
632 |
+
trial.report(self.objective, epoch)
|
633 |
+
if trial.should_prune():
|
634 |
+
raise optuna.TrialPruned()
|
635 |
+
elif self.hp_search_backend == HPSearchBackend.RAY:
|
636 |
+
if self.global_step % self.args.save_steps == 0:
|
637 |
+
self._tune_save_checkpoint()
|
638 |
+
tune.report(objective=self.objective, **metrics)
|
639 |
+
|
640 |
+
def _tune_save_checkpoint(self):
|
641 |
+
if not self.use_tune_checkpoints:
|
642 |
+
return
|
643 |
+
with tune.checkpoint_dir(step=self.global_step) as checkpoint_dir:
|
644 |
+
self.args.output_dir = checkpoint_dir
|
645 |
+
output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")
|
646 |
+
self.save_model(output_dir)
|
647 |
+
if self.is_world_master():
|
648 |
+
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
649 |
+
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
650 |
+
|
651 |
+
|
652 |
+
def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
|
653 |
+
"""
|
654 |
+
Main training entry point.
|
655 |
+
|
656 |
+
Args:
|
657 |
+
model_path (:obj:`str`, `optional`):
|
658 |
+
Local path to the model if the model to train has been instantiated from a local path. If present,
|
659 |
+
training will resume from the optimizer/scheduler states loaded here.
|
660 |
+
trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
|
661 |
+
The trial run or the hyperparameter dictionary for hyperparameter search.
|
662 |
+
"""
|
663 |
+
# This might change the seed so needs to run first.
|
664 |
+
self._hp_search_setup(trial)
|
665 |
+
|
666 |
+
# Model re-init
|
667 |
+
if self.model_init is not None:
|
668 |
+
# Seed must be set before instantiating the model when using model_init.
|
669 |
+
set_seed(self.args.seed)
|
670 |
+
model = self.model_init()
|
671 |
+
self.model = model.to(self.args.device)
|
672 |
+
|
673 |
+
# Reinitializes optimizer and scheduler
|
674 |
+
self.optimizer, self.lr_scheduler = None, None
|
675 |
+
|
676 |
+
# Data loader and number of training steps
|
677 |
+
train_dataloader = self.get_train_dataloader()
|
678 |
+
num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
|
679 |
+
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
|
680 |
+
if self.args.max_steps > 0:
|
681 |
+
t_total = self.args.max_steps
|
682 |
+
num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
|
683 |
+
self.args.max_steps % num_update_steps_per_epoch > 0
|
684 |
+
)
|
685 |
+
else:
|
686 |
+
t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs)
|
687 |
+
num_train_epochs = self.args.num_train_epochs
|
688 |
+
self.args.max_steps = t_total
|
689 |
+
|
690 |
+
self.create_optimizer_and_scheduler(num_training_steps=t_total)
|
691 |
+
|
692 |
+
# Check if saved optimizer or scheduler states exist
|
693 |
+
if (
|
694 |
+
model_path is not None
|
695 |
+
and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
|
696 |
+
and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
|
697 |
+
):
|
698 |
+
# Load in optimizer and scheduler states
|
699 |
+
self.optimizer.load_state_dict(
|
700 |
+
torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
|
701 |
+
)
|
702 |
+
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
|
703 |
+
|
704 |
+
model = self.model
|
705 |
+
if self.args.fp16 and _use_apex:
|
706 |
+
if not is_apex_available():
|
707 |
+
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
708 |
+
model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
|
709 |
+
|
710 |
+
# multi-gpu training (should be after apex fp16 initialization)
|
711 |
+
if self.args.n_gpu > 1:
|
712 |
+
model = torch.nn.DataParallel(model)
|
713 |
+
|
714 |
+
# Distributed training (should be after apex fp16 initialization)
|
715 |
+
if self.args.local_rank != -1:
|
716 |
+
model = torch.nn.parallel.DistributedDataParallel(
|
717 |
+
model,
|
718 |
+
device_ids=[self.args.local_rank],
|
719 |
+
output_device=self.args.local_rank,
|
720 |
+
find_unused_parameters=True,
|
721 |
+
)
|
722 |
+
|
723 |
+
if self.tb_writer is not None:
|
724 |
+
self.tb_writer.add_text("args", self.args.to_json_string())
|
725 |
+
self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
|
726 |
+
|
727 |
+
# Train!
|
728 |
+
if is_torch_tpu_available():
|
729 |
+
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
|
730 |
+
else:
|
731 |
+
total_train_batch_size = (
|
732 |
+
self.args.train_batch_size
|
733 |
+
* self.args.gradient_accumulation_steps
|
734 |
+
* (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
|
735 |
+
)
|
736 |
+
logger.info("***** Running training *****")
|
737 |
+
logger.info(" Num examples = %d", self.num_examples(train_dataloader))
|
738 |
+
logger.info(" Num Epochs = %d", num_train_epochs)
|
739 |
+
logger.info(" Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
|
740 |
+
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
|
741 |
+
logger.info(" Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
|
742 |
+
logger.info(" Total optimization steps = %d", t_total)
|
743 |
+
|
744 |
+
self.global_step = 0
|
745 |
+
self.epoch = 0
|
746 |
+
self.total_flos = 0
|
747 |
+
epochs_trained = 0
|
748 |
+
steps_trained_in_current_epoch = 0
|
749 |
+
# Check if continuing training from a checkpoint
|
750 |
+
if model_path is not None:
|
751 |
+
# set global_step to global_step of last saved checkpoint from model path
|
752 |
+
try:
|
753 |
+
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
|
754 |
+
# print(model, model.module)
|
755 |
+
if self.args.n_gpu > 1:
|
756 |
+
self.total_flos = getattr(model.module.config, "total_flos", 0)
|
757 |
+
else:
|
758 |
+
self.total_flos = getattr(model.config, "total_flos", 0)
|
759 |
+
|
760 |
+
epochs_trained = self.global_step // num_update_steps_per_epoch
|
761 |
+
steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)
|
762 |
+
|
763 |
+
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
764 |
+
logger.info(" Continuing training from epoch %d", epochs_trained)
|
765 |
+
logger.info(" Continuing training from global step %d", self.global_step)
|
766 |
+
logger.info(" Continuing training from %d non-embedding floating-point operations", self.total_flos)
|
767 |
+
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
768 |
+
except ValueError:
|
769 |
+
self.global_step = 0
|
770 |
+
self.total_flos = 0
|
771 |
+
logger.info(" Starting fine-tuning.")
|
772 |
+
|
773 |
+
tr_loss = torch.tensor(0.0).to(self.args.device)
|
774 |
+
logging_loss_scalar = 0.0
|
775 |
+
model.zero_grad()
|
776 |
+
disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
|
777 |
+
train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
|
778 |
+
for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
|
779 |
+
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
|
780 |
+
train_dataloader.sampler.set_epoch(epoch)
|
781 |
+
|
782 |
+
if is_torch_tpu_available():
|
783 |
+
parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
|
784 |
+
self.args.device
|
785 |
+
)
|
786 |
+
epoch_iterator = parallel_loader
|
787 |
+
else:
|
788 |
+
epoch_iterator = train_dataloader
|
789 |
+
|
790 |
+
# Reset the past mems state at the beginning of each epoch if necessary.
|
791 |
+
if self.args.past_index >= 0:
|
792 |
+
self._past = None
|
793 |
+
|
794 |
+
epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
|
795 |
+
for step, inputs in enumerate(epoch_iterator):
|
796 |
+
|
797 |
+
# Skip past any already trained steps if resuming training
|
798 |
+
if steps_trained_in_current_epoch > 0:
|
799 |
+
steps_trained_in_current_epoch -= 1
|
800 |
+
epoch_pbar.update(1)
|
801 |
+
continue
|
802 |
+
|
803 |
+
tr_loss += self.training_step(model, inputs)
|
804 |
+
|
805 |
+
self.total_flos += self.floating_point_ops(inputs)
|
806 |
+
|
807 |
+
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
|
808 |
+
# last step in epoch but step is always smaller than gradient_accumulation_steps
|
809 |
+
len(epoch_iterator) <= self.args.gradient_accumulation_steps
|
810 |
+
and (step + 1) == len(epoch_iterator)
|
811 |
+
):
|
812 |
+
if self.args.fp16 and _use_native_amp:
|
813 |
+
self.scaler.unscale_(self.optimizer)
|
814 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
|
815 |
+
elif self.args.fp16 and _use_apex:
|
816 |
+
torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
|
817 |
+
else:
|
818 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
|
819 |
+
|
820 |
+
if is_torch_tpu_available():
|
821 |
+
xm.optimizer_step(self.optimizer)
|
822 |
+
elif self.args.fp16 and _use_native_amp:
|
823 |
+
self.scaler.step(self.optimizer)
|
824 |
+
self.scaler.update()
|
825 |
+
else:
|
826 |
+
self.optimizer.step()
|
827 |
+
|
828 |
+
# URGENT
|
829 |
+
self.lr_scheduler.step()
|
830 |
+
model.zero_grad()
|
831 |
+
self.global_step += 1
|
832 |
+
self.epoch = epoch + (step + 1) / len(epoch_iterator)
|
833 |
+
|
834 |
+
|
835 |
+
if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
|
836 |
+
self.global_step == 1 and self.args.logging_first_step
|
837 |
+
):
|
838 |
+
logs: Dict[str, float] = {}
|
839 |
+
tr_loss_scalar = tr_loss.item()
|
840 |
+
logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
|
841 |
+
# backward compatibility for pytorch schedulers
|
842 |
+
logs["learning_rate"] = (
|
843 |
+
self.lr_scheduler.get_last_lr()[0]
|
844 |
+
if version.parse(torch.__version__) >= version.parse("1.4")
|
845 |
+
else self.lr_scheduler.get_lr()[0]
|
846 |
+
)
|
847 |
+
logging_loss_scalar = tr_loss_scalar
|
848 |
+
|
849 |
+
self.log(logs)
|
850 |
+
|
851 |
+
# print(self.args.evaluation_strategy == EvaluationStrategy.STEPS )
|
852 |
+
# print(self.global_step % self.args.eval_steps == 0)
|
853 |
+
# print()
|
854 |
+
|
855 |
+
if (
|
856 |
+
self.args.evaluation_strategy == EvaluationStrategy.STEPS
|
857 |
+
and self.global_step % self.args.eval_steps == 0
|
858 |
+
):
|
859 |
+
metrics = self.evaluate()
|
860 |
+
self._report_to_hp_search(trial, epoch, metrics)
|
861 |
+
|
862 |
+
#############################EARLY STOPPING########################
|
863 |
+
if 'lowdata' in self.args.output_dir or 'earlystop' in self.args.output_dir:
|
864 |
+
self.save_based_on_eval = True
|
865 |
+
else:
|
866 |
+
self.save_based_on_eval = False
|
867 |
+
print('if not see a line lowdata: below, then did not go into low data. ')
|
868 |
+
if self.save_based_on_eval and metrics["eval_loss"] < self.curr_best_eval:
|
869 |
+
print('lowdata:', self.global_step, self.curr_best_eval, metrics["eval_loss"],
|
870 |
+
'perplexity={}'.format(math.exp(metrics["eval_loss"])))
|
871 |
+
self.curr_best_eval = metrics["eval_loss"]
|
872 |
+
if hasattr(model, "module"):
|
873 |
+
assert (
|
874 |
+
model.module is self.model
|
875 |
+
), f"Module {model.module} should be a reference to self.model"
|
876 |
+
else:
|
877 |
+
assert model is self.model, f"Model {model} should be a reference to self.model"
|
878 |
+
# Save model checkpoint
|
879 |
+
output_dir_name = os.path.basename(self.args.output_dir)
|
880 |
+
checkpoint_folder = f"{output_dir_name}-{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
|
881 |
+
if self.hp_search_backend is not None and trial is not None:
|
882 |
+
run_id = (
|
883 |
+
trial.number
|
884 |
+
if self.hp_search_backend == HPSearchBackend.OPTUNA
|
885 |
+
else tune.get_trial_id()
|
886 |
+
)
|
887 |
+
checkpoint_folder += f"-run-{run_id}"
|
888 |
+
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
|
889 |
+
|
890 |
+
self.store_flos()
|
891 |
+
print('saving to output_dir', output_dir)
|
892 |
+
self.save_model(output_dir)
|
893 |
+
|
894 |
+
if self.is_world_process_zero():
|
895 |
+
self._rotate_checkpoints(use_mtime=True)
|
896 |
+
#####################################################
|
897 |
+
|
898 |
+
if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
|
899 |
+
print('saving model at a checkpoint!!')
|
900 |
+
# In all cases (even distributed/parallel), self.model is always a reference
|
901 |
+
# to the model we want to save.
|
902 |
+
if hasattr(model, "module"):
|
903 |
+
assert (
|
904 |
+
model.module is self.model
|
905 |
+
), f"Module {model.module} should be a reference to self.model"
|
906 |
+
else:
|
907 |
+
assert model is self.model, f"Model {model} should be a reference to self.model"
|
908 |
+
# Save model checkpoint
|
909 |
+
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
|
910 |
+
if self.hp_search_backend is not None and trial is not None:
|
911 |
+
run_id = (
|
912 |
+
trial.number
|
913 |
+
if self.hp_search_backend == HPSearchBackend.OPTUNA
|
914 |
+
else tune.get_trial_id()
|
915 |
+
)
|
916 |
+
checkpoint_folder += f"-run-{run_id}"
|
917 |
+
output_dir = os.path.join(self.args.output_dir, checkpoint_folder)
|
918 |
+
|
919 |
+
self.store_flos()
|
920 |
+
|
921 |
+
self.save_model(output_dir)
|
922 |
+
|
923 |
+
if self.is_world_process_zero():
|
924 |
+
self._rotate_checkpoints(use_mtime=True)
|
925 |
+
|
926 |
+
if is_torch_tpu_available():
|
927 |
+
xm.rendezvous("saving_optimizer_states")
|
928 |
+
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
929 |
+
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
930 |
+
elif self.is_world_process_zero():
|
931 |
+
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
932 |
+
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
933 |
+
|
934 |
+
epoch_pbar.update(1)
|
935 |
+
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
|
936 |
+
break
|
937 |
+
epoch_pbar.close()
|
938 |
+
train_pbar.update(1)
|
939 |
+
|
940 |
+
if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
|
941 |
+
metrics = self.evaluate()
|
942 |
+
self._report_to_hp_search(trial, epoch, metrics)
|
943 |
+
|
944 |
+
if self.args.tpu_metrics_debug or self.args.debug:
|
945 |
+
if is_torch_tpu_available():
|
946 |
+
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
|
947 |
+
xm.master_print(met.metrics_report())
|
948 |
+
else:
|
949 |
+
logger.warning(
|
950 |
+
"You enabled PyTorch/XLA debug metrics but you don't have a TPU "
|
951 |
+
"configured. Check your training configuration if this is unexpected."
|
952 |
+
)
|
953 |
+
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
|
954 |
+
break
|
955 |
+
|
956 |
+
train_pbar.close()
|
957 |
+
if self.tb_writer:
|
958 |
+
self.tb_writer.close()
|
959 |
+
if self.args.past_index and hasattr(self, "_past"):
|
960 |
+
# Clean the state at the end of training
|
961 |
+
delattr(self, "_past")
|
962 |
+
|
963 |
+
logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
|
964 |
+
return TrainOutput(self.global_step, tr_loss.item() / self.global_step)
|
965 |
+
|
966 |
+
def hyperparameter_search(
|
967 |
+
self,
|
968 |
+
hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
|
969 |
+
compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
|
970 |
+
n_trials: int = 20,
|
971 |
+
direction: str = "minimize",
|
972 |
+
backend: Optional[Union["str", HPSearchBackend]] = None,
|
973 |
+
**kwargs
|
974 |
+
) -> BestRun:
|
975 |
+
"""
|
976 |
+
Launch an hyperparameter search using ``optuna`` or ``Ray Tune``. The optimized quantity is determined by
|
977 |
+
:obj:`compute_objectie`, which defaults to a function returning the evaluation loss when no metric is provided,
|
978 |
+
the sum of all metrics otherwise.
|
979 |
+
|
980 |
+
.. warning::
|
981 |
+
|
982 |
+
To use this method, you need to have provided a ``model_init`` when initializing your
|
983 |
+
:class:`~transformers.Trainer`: we need to reinitialize the model at each new run. This is incompatible
|
984 |
+
with the ``optimizers`` argument, so you need to subclass :class:`~transformers.Trainer` and override the
|
985 |
+
method :meth:`~transformers.Trainer.create_optimizer_and_scheduler` for custom optimizer/scheduler.
|
986 |
+
|
987 |
+
Args:
|
988 |
+
hp_space (:obj:`Callable[["optuna.Trial"], Dict[str, float]]`, `optional`):
|
989 |
+
A function that defines the hyperparameter search space. Will default to
|
990 |
+
:func:`~transformers.trainer_utils.default_hp_space_optuna` or
|
991 |
+
:func:`~transformers.trainer_utils.default_hp_space_ray` depending on your backend.
|
992 |
+
compute_objective (:obj:`Callable[[Dict[str, float]], float]`, `optional`):
|
993 |
+
A function computing the objective to minimize or maximize from the metrics returned by the
|
994 |
+
:obj:`evaluate` method. Will default to :func:`~transformers.trainer_utils.default_compute_objective`.
|
995 |
+
n_trials (:obj:`int`, `optional`, defaults to 100):
|
996 |
+
The number of trial runs to test.
|
997 |
+
direction(:obj:`str`, `optional`, defaults to :obj:`"minimize"`):
|
998 |
+
Whether to optimize greater or lower objects. Can be :obj:`"minimize"` or :obj:`"maximize"`, you should
|
999 |
+
pick :obj:`"minimize"` when optimizing the validation loss, :obj:`"maximize"` when optimizing one or
|
1000 |
+
several metrics.
|
1001 |
+
backend(:obj:`str` or :class:`~transformers.training_utils.HPSearchBackend`, `optional`):
|
1002 |
+
The backend to use for hyperparameter search. Will default to optuna or Ray Tune, depending on which
|
1003 |
+
one is installed. If both are installed, will default to optuna.
|
1004 |
+
kwargs:
|
1005 |
+
Additional keyword arguments passed along to :obj:`optuna.create_study` or :obj:`ray.tune.run`. For
|
1006 |
+
more information see:
|
1007 |
+
|
1008 |
+
- the documentation of `optuna.create_study <https://optuna.readthedocs.io/en/stable/reference/alias_generated/optuna.create_study.html#optuna.create_study>`__
|
1009 |
+
- the documentation of `tune.run <https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run>`__
|
1010 |
+
|
1011 |
+
Returns:
|
1012 |
+
:class:`transformers.trainer_utils.BestRun`: All the informations about the best run.
|
1013 |
+
"""
|
1014 |
+
if backend is None:
|
1015 |
+
backend = default_hp_search_backend()
|
1016 |
+
if backend is None:
|
1017 |
+
raise RuntimeError(
|
1018 |
+
"At least one of optuna or ray should be installed. "
|
1019 |
+
"To install optuna run `pip install optuna`."
|
1020 |
+
"To install ray run `pip install ray[tune]`."
|
1021 |
+
)
|
1022 |
+
backend = HPSearchBackend(backend)
|
1023 |
+
if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
|
1024 |
+
raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
|
1025 |
+
if backend == HPSearchBackend.RAY and not is_ray_available():
|
1026 |
+
raise RuntimeError(
|
1027 |
+
"You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
|
1028 |
+
)
|
1029 |
+
self.hp_search_backend = backend
|
1030 |
+
|
1031 |
+
if self.model_init is None:
|
1032 |
+
raise RuntimeError(
|
1033 |
+
"To use hyperparameter search, you need to pass your model through a model_init function."
|
1034 |
+
)
|
1035 |
+
|
1036 |
+
self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
|
1037 |
+
self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
|
1038 |
+
|
1039 |
+
run_hp_search = run_hp_search_optuna if backend == HPSearchBackend.OPTUNA else run_hp_search_ray
|
1040 |
+
best_run = run_hp_search(self, n_trials, direction, **kwargs)
|
1041 |
+
|
1042 |
+
self.hp_search_backend = None
|
1043 |
+
return best_run
|
1044 |
+
|
1045 |
+
def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
|
1046 |
+
"""
|
1047 |
+
Log :obj:`logs` on the various objects watching training.
|
1048 |
+
|
1049 |
+
Subclass and override this method to inject custom behavior.
|
1050 |
+
|
1051 |
+
Args:
|
1052 |
+
logs (:obj:`Dict[str, float]`):
|
1053 |
+
The values to log.
|
1054 |
+
iterator (:obj:`tqdm`, `optional`):
|
1055 |
+
A potential tqdm progress bar to write the logs on.
|
1056 |
+
"""
|
1057 |
+
# Set up loggers like W&B or Comet ML
|
1058 |
+
self._setup_loggers()
|
1059 |
+
|
1060 |
+
if hasattr(self, "_log"):
|
1061 |
+
warnings.warn(
|
1062 |
+
"The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
|
1063 |
+
FutureWarning,
|
1064 |
+
)
|
1065 |
+
return self._log(logs, iterator=iterator)
|
1066 |
+
|
1067 |
+
if self.epoch is not None:
|
1068 |
+
logs["epoch"] = self.epoch
|
1069 |
+
if self.total_flos is not None:
|
1070 |
+
if self.args.local_rank != -1:
|
1071 |
+
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
|
1072 |
+
else:
|
1073 |
+
total_flos = self.total_flos
|
1074 |
+
if total_flos > 0:
|
1075 |
+
logs["total_flos"] = self.total_flos
|
1076 |
+
if self.global_step is None:
|
1077 |
+
# when logging evaluation metrics without training
|
1078 |
+
self.global_step = 0
|
1079 |
+
if self.tb_writer:
|
1080 |
+
for k, v in logs.items():
|
1081 |
+
if isinstance(v, (int, float)):
|
1082 |
+
self.tb_writer.add_scalar(k, v, self.global_step)
|
1083 |
+
else:
|
1084 |
+
logger.warning(
|
1085 |
+
"Trainer is attempting to log a value of "
|
1086 |
+
'"%s" of type %s for key "%s" as a scalar. '
|
1087 |
+
"This invocation of Tensorboard's writer.add_scalar() "
|
1088 |
+
"is incorrect so we dropped this attribute.",
|
1089 |
+
v,
|
1090 |
+
type(v),
|
1091 |
+
k,
|
1092 |
+
)
|
1093 |
+
self.tb_writer.flush()
|
1094 |
+
if is_wandb_available():
|
1095 |
+
if self.is_world_process_zero():
|
1096 |
+
wandb.log(logs, step=self.global_step)
|
1097 |
+
if is_comet_available():
|
1098 |
+
if self.is_world_process_zero():
|
1099 |
+
experiment = comet_ml.config.get_global_experiment()
|
1100 |
+
if experiment is not None:
|
1101 |
+
experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
|
1102 |
+
output = {**logs, **{"step": self.global_step}}
|
1103 |
+
if self.is_world_process_zero():
|
1104 |
+
self.log_history.append(output)
|
1105 |
+
if iterator is not None:
|
1106 |
+
iterator.write(output)
|
1107 |
+
else:
|
1108 |
+
print(output)
|
1109 |
+
|
1110 |
+
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
|
1111 |
+
"""
|
1112 |
+
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
|
1113 |
+
handling potential state.
|
1114 |
+
"""
|
1115 |
+
for k, v in inputs.items():
|
1116 |
+
if isinstance(v, torch.Tensor):
|
1117 |
+
inputs[k] = v.to(self.args.device)
|
1118 |
+
|
1119 |
+
if self.args.past_index >= 0 and self._past is not None:
|
1120 |
+
assert False
|
1121 |
+
inputs["mems"] = self._past
|
1122 |
+
|
1123 |
+
return inputs
|
1124 |
+
|
1125 |
+
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
|
1126 |
+
"""
|
1127 |
+
Perform a training step on a batch of inputs.
|
1128 |
+
|
1129 |
+
Subclass and override to inject custom behavior.
|
1130 |
+
|
1131 |
+
Args:
|
1132 |
+
model (:obj:`nn.Module`):
|
1133 |
+
The model to train.
|
1134 |
+
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
|
1135 |
+
The inputs and targets of the model.
|
1136 |
+
|
1137 |
+
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
1138 |
+
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
|
1139 |
+
|
1140 |
+
Return:
|
1141 |
+
:obj:`torch.Tensor`: The tensor with training loss on this batch.
|
1142 |
+
"""
|
1143 |
+
if hasattr(self, "_training_step"):
|
1144 |
+
warnings.warn(
|
1145 |
+
"The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.",
|
1146 |
+
FutureWarning,
|
1147 |
+
)
|
1148 |
+
return self._training_step(model, inputs, self.optimizer)
|
1149 |
+
|
1150 |
+
model.train()
|
1151 |
+
if self.use_dropout:
|
1152 |
+
if self.gpt2 is not None:
|
1153 |
+
self.gpt2.train()
|
1154 |
+
inputs = self._prepare_inputs(inputs)
|
1155 |
+
|
1156 |
+
if self.args.fp16 and _use_native_amp:
|
1157 |
+
with autocast():
|
1158 |
+
if self.distill:
|
1159 |
+
loss = self.compute_loss_distill(model, inputs, gpt2_model=self.gpt2, )
|
1160 |
+
else:
|
1161 |
+
loss = self.compute_loss(model, inputs, gpt2_model=self.gpt2)
|
1162 |
+
else:
|
1163 |
+
if self.distill:
|
1164 |
+
loss = self.compute_loss_distill(model, inputs, gpt2_model=self.gpt2)
|
1165 |
+
else:
|
1166 |
+
loss = self.compute_loss(model, inputs, gpt2_model=self.gpt2)
|
1167 |
+
|
1168 |
+
if self.args.n_gpu > 1:
|
1169 |
+
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
1170 |
+
|
1171 |
+
if self.args.gradient_accumulation_steps > 1:
|
1172 |
+
loss = loss / self.args.gradient_accumulation_steps
|
1173 |
+
|
1174 |
+
if self.args.fp16 and _use_native_amp:
|
1175 |
+
self.scaler.scale(loss).backward()
|
1176 |
+
elif self.args.fp16 and _use_apex:
|
1177 |
+
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
1178 |
+
scaled_loss.backward()
|
1179 |
+
else:
|
1180 |
+
# print(loss)
|
1181 |
+
loss.backward()
|
1182 |
+
|
1183 |
+
# print('max allocated_memory:', torch.cuda.max_memory_allocated(0), 'total_memory:', torch.cuda.get_device_properties(0).total_memory,
|
1184 |
+
# 'percentage', torch.cuda.max_memory_allocated(0)/torch.cuda.get_device_properties(0).total_memory)
|
1185 |
+
|
1186 |
+
|
1187 |
+
return loss.detach()
|
1188 |
+
|
1189 |
+
|
1190 |
+
|
1191 |
+
|
1192 |
+
|
1193 |
+
def compute_loss(self, model, inputs, gpt2_model=None):
|
1194 |
+
"""
|
1195 |
+
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
1196 |
+
|
1197 |
+
Subclass and override for custom behavior.
|
1198 |
+
"""
|
1199 |
+
# outputs = model.forward_weighted(**inputs)
|
1200 |
+
if 'prompt_lab' in inputs:
|
1201 |
+
prompt_lab_ = inputs['prompt_lab']
|
1202 |
+
k = torch.cat(self.discri_labels_code, dim=0)
|
1203 |
+
inputs['control_code'] = torch.index_select(k, 0, prompt_lab_)
|
1204 |
+
del inputs['prompt_lab']
|
1205 |
+
|
1206 |
+
outputs = model(**inputs, gpt2_model=gpt2_model)
|
1207 |
+
# Save past state if it exists
|
1208 |
+
if self.args.past_index >= 0:
|
1209 |
+
self._past = outputs[self.args.past_index]
|
1210 |
+
|
1211 |
+
# print(outputs[0])
|
1212 |
+
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
1213 |
+
# print(outputs[0], outputs.loss)
|
1214 |
+
# URGENT
|
1215 |
+
# print('compute_loss', outputs[0])
|
1216 |
+
return outputs[0].mean()
|
1217 |
+
|
1218 |
+
def compute_loss_distill(self, model, inputs, gpt2_model=None):
|
1219 |
+
"""
|
1220 |
+
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
1221 |
+
|
1222 |
+
Subclass and override for custom behavior.
|
1223 |
+
"""
|
1224 |
+
# outputs = model.forward_weighted(**inputs)
|
1225 |
+
|
1226 |
+
with torch.no_grad():
|
1227 |
+
output_finetuned = self.finetuned_gpt2(**inputs)
|
1228 |
+
|
1229 |
+
outputs = model(**inputs, gpt2_model=gpt2_model)
|
1230 |
+
# Save past state if it exists
|
1231 |
+
if self.args.past_index >= 0:
|
1232 |
+
self._past = outputs[self.args.past_index]
|
1233 |
+
|
1234 |
+
if self.matching_objective == 'kl':
|
1235 |
+
# distrib_finetuned=torch.log_softmax(output_finetuned.logits[:,:,:-2], dim=-1) #bsz, seqlen, vocab
|
1236 |
+
distrib_finetuned=torch.log_softmax(output_finetuned.logits, dim=-1) #bsz, seqlen, vocab
|
1237 |
+
distrib_prefix = torch.log_softmax(outputs.logits, dim=-1) # bsz, seqlen, vocab
|
1238 |
+
loss = torch.sum(distrib_finetuned.exp() * (distrib_finetuned - distrib_prefix), dim=-1) #bsz, seqlen
|
1239 |
+
|
1240 |
+
elif self.matching_objective == 'logits':
|
1241 |
+
loss = torch.norm(output_finetuned.logits - outputs.logits, dim=-1) #bsz, seqlen
|
1242 |
+
# loss = torch.norm(output_finetuned.logits[:,:,:-2] - outputs.logits, dim=-1) #bsz, seqlen
|
1243 |
+
|
1244 |
+
elif self.matching_objective == 'last_layer':
|
1245 |
+
activation_diff = output_finetuned.last_hidden_state - outputs.last_hidden_state
|
1246 |
+
loss = torch.norm(activation_diff, dim=-1) # bsz, seqlen
|
1247 |
+
else:
|
1248 |
+
assert False, "invalid matching_objective"
|
1249 |
+
|
1250 |
+
return loss.sum(dim=-1).mean()
|
1251 |
+
|
1252 |
+
def is_local_master(self) -> bool:
|
1253 |
+
"""
|
1254 |
+
Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
|
1255 |
+
several machines) main process.
|
1256 |
+
|
1257 |
+
.. warning::
|
1258 |
+
|
1259 |
+
This method is deprecated, use :meth:`~transformers.Trainer.is_local_process_zero` instead.
|
1260 |
+
"""
|
1261 |
+
warnings.warn("This method is deprecated, use `Trainer.is_local_process_zero()` instead.", FutureWarning)
|
1262 |
+
return self.is_local_process_zero()
|
1263 |
+
|
1264 |
+
def is_local_process_zero(self) -> bool:
|
1265 |
+
"""
|
1266 |
+
Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
|
1267 |
+
several machines) main process.
|
1268 |
+
"""
|
1269 |
+
if is_torch_tpu_available():
|
1270 |
+
return xm.is_master_ordinal(local=True)
|
1271 |
+
else:
|
1272 |
+
return self.args.local_rank in [-1, 0]
|
1273 |
+
|
1274 |
+
def is_world_master(self) -> bool:
|
1275 |
+
"""
|
1276 |
+
Whether or not this process is the global main process (when training in a distributed fashion on
|
1277 |
+
several machines, this is only going to be :obj:`True` for one process).
|
1278 |
+
|
1279 |
+
.. warning::
|
1280 |
+
|
1281 |
+
This method is deprecated, use :meth:`~transformers.Trainer.is_world_process_zero` instead.
|
1282 |
+
"""
|
1283 |
+
warnings.warn("This method is deprecated, use `Trainer.is_world_process_zero()` instead.", FutureWarning)
|
1284 |
+
return self.is_world_process_zero()
|
1285 |
+
|
1286 |
+
def is_world_process_zero(self) -> bool:
|
1287 |
+
"""
|
1288 |
+
Whether or not this process is the global main process (when training in a distributed fashion on
|
1289 |
+
several machines, this is only going to be :obj:`True` for one process).
|
1290 |
+
"""
|
1291 |
+
if is_torch_tpu_available():
|
1292 |
+
return xm.is_master_ordinal(local=False)
|
1293 |
+
else:
|
1294 |
+
return self.args.local_rank == -1 or torch.distributed.get_rank() == 0
|
1295 |
+
|
1296 |
+
def save_model(self, output_dir: Optional[str] = None):
|
1297 |
+
"""
|
1298 |
+
Will save the model, so you can reload it using :obj:`from_pretrained()`.
|
1299 |
+
|
1300 |
+
Will only save from the world_master process (unless in TPUs).
|
1301 |
+
"""
|
1302 |
+
|
1303 |
+
if is_torch_tpu_available():
|
1304 |
+
self._save_tpu(output_dir)
|
1305 |
+
elif self.is_world_process_zero():
|
1306 |
+
self._save(output_dir)
|
1307 |
+
|
1308 |
+
def _save_tpu(self, output_dir: Optional[str] = None):
|
1309 |
+
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
1310 |
+
logger.info("Saving model checkpoint to %s", output_dir)
|
1311 |
+
|
1312 |
+
if xm.is_master_ordinal():
|
1313 |
+
os.makedirs(output_dir, exist_ok=True)
|
1314 |
+
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
1315 |
+
json.dump(
|
1316 |
+
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
|
1317 |
+
)
|
1318 |
+
|
1319 |
+
# Save a trained model and configuration using `save_pretrained()`.
|
1320 |
+
# They can then be reloaded using `from_pretrained()`
|
1321 |
+
if not isinstance(self.model, PreTrainedModel):
|
1322 |
+
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
|
1323 |
+
|
1324 |
+
xm.rendezvous("saving_checkpoint")
|
1325 |
+
self.model.save_pretrained(output_dir)
|
1326 |
+
if self.tokenizer is not None:
|
1327 |
+
self.tokenizer.save_pretrained(output_dir)
|
1328 |
+
|
1329 |
+
def _save(self, output_dir: Optional[str] = None):
|
1330 |
+
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
1331 |
+
os.makedirs(output_dir, exist_ok=True)
|
1332 |
+
logger.info("Saving model checkpoint to %s", output_dir)
|
1333 |
+
# Save a trained model and configuration using `save_pretrained()`.
|
1334 |
+
# They can then be reloaded using `from_pretrained()`
|
1335 |
+
if not isinstance(self.model, PreTrainedModel):
|
1336 |
+
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
|
1337 |
+
self.model.save_pretrained(output_dir)
|
1338 |
+
if self.tokenizer is not None:
|
1339 |
+
self.tokenizer.save_pretrained(output_dir)
|
1340 |
+
|
1341 |
+
# Good practice: save your training arguments together with the trained model
|
1342 |
+
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
1343 |
+
json.dump(
|
1344 |
+
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
|
1345 |
+
)
|
1346 |
+
|
1347 |
+
def store_flos(self):
|
1348 |
+
# Storing the number of floating-point operations that went into the model
|
1349 |
+
if self.total_flos is not None:
|
1350 |
+
if self.args.local_rank != -1:
|
1351 |
+
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
|
1352 |
+
else:
|
1353 |
+
total_flos = self.total_flos
|
1354 |
+
if total_flos > 0:
|
1355 |
+
self.model.config.total_flos = total_flos
|
1356 |
+
|
1357 |
+
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
|
1358 |
+
output_dir_name = os.path.basename(self.args.output_dir)
|
1359 |
+
checkpoint_prefix = f"{output_dir_name}-{PREFIX_CHECKPOINT_DIR}"
|
1360 |
+
|
1361 |
+
ordering_and_checkpoint_path = []
|
1362 |
+
|
1363 |
+
glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]
|
1364 |
+
|
1365 |
+
for path in glob_checkpoints:
|
1366 |
+
if use_mtime:
|
1367 |
+
ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
|
1368 |
+
else:
|
1369 |
+
regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
|
1370 |
+
if regex_match and regex_match.groups():
|
1371 |
+
ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
|
1372 |
+
|
1373 |
+
checkpoints_sorted = sorted(ordering_and_checkpoint_path)
|
1374 |
+
checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
|
1375 |
+
return checkpoints_sorted
|
1376 |
+
|
1377 |
+
def _rotate_checkpoints(self, use_mtime=False) -> None:
|
1378 |
+
if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
|
1379 |
+
return
|
1380 |
+
|
1381 |
+
# Check if we should delete older checkpoint(s)
|
1382 |
+
checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
|
1383 |
+
if len(checkpoints_sorted) <= self.args.save_total_limit:
|
1384 |
+
return
|
1385 |
+
|
1386 |
+
number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
|
1387 |
+
checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
|
1388 |
+
for checkpoint in checkpoints_to_be_deleted:
|
1389 |
+
logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
|
1390 |
+
shutil.rmtree(checkpoint)
|
1391 |
+
|
1392 |
+
def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
|
1393 |
+
"""
|
1394 |
+
Run evaluation and returns metrics.
|
1395 |
+
|
1396 |
+
The calling script will be responsible for providing a method to compute metrics, as they are
|
1397 |
+
task-dependent (pass it to the init :obj:`compute_metrics` argument).
|
1398 |
+
|
1399 |
+
You can also subclass and override this method to inject custom behavior.
|
1400 |
+
|
1401 |
+
Args:
|
1402 |
+
eval_dataset (:obj:`Dataset`, `optional`):
|
1403 |
+
Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
|
1404 |
+
columns not accepted by the ``model.forward()`` method are automatically removed.
|
1405 |
+
|
1406 |
+
Returns:
|
1407 |
+
A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
|
1408 |
+
"""
|
1409 |
+
eval_dataloader = self.get_eval_dataloader(eval_dataset)
|
1410 |
+
|
1411 |
+
output = self.prediction_loop(eval_dataloader, description="Evaluation")
|
1412 |
+
|
1413 |
+
self.log(output.metrics)
|
1414 |
+
|
1415 |
+
if self.args.tpu_metrics_debug or self.args.debug:
|
1416 |
+
# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
|
1417 |
+
xm.master_print(met.metrics_report())
|
1418 |
+
|
1419 |
+
return output.metrics
|
1420 |
+
|
1421 |
+
|
1422 |
+
|
1423 |
+
def predict(self, test_dataset: Dataset) -> PredictionOutput:
|
1424 |
+
"""
|
1425 |
+
Run prediction and returns predictions and potential metrics.
|
1426 |
+
|
1427 |
+
Depending on the dataset and your use case, your test dataset may contain labels.
|
1428 |
+
In that case, this method will also return metrics, like in :obj:`evaluate()`.
|
1429 |
+
|
1430 |
+
Args:
|
1431 |
+
test_dataset (:obj:`Dataset`):
|
1432 |
+
Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
1433 |
+
``model.forward()`` method are automatically removed.
|
1434 |
+
|
1435 |
+
Returns:
|
1436 |
+
`NamedTuple`:
|
1437 |
+
predictions (:obj:`np.ndarray`):
|
1438 |
+
The predictions on :obj:`test_dataset`.
|
1439 |
+
label_ids (:obj:`np.ndarray`, `optional`):
|
1440 |
+
The labels (if the dataset contained some).
|
1441 |
+
metrics (:obj:`Dict[str, float]`, `optional`):
|
1442 |
+
The potential dictionary of metrics (if the dataset contained labels).
|
1443 |
+
"""
|
1444 |
+
test_dataloader = self.get_test_dataloader(test_dataset)
|
1445 |
+
|
1446 |
+
return self.prediction_loop(test_dataloader, description="Prediction")
|
1447 |
+
|
1448 |
+
def prediction_loop(
|
1449 |
+
self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
|
1450 |
+
) -> PredictionOutput:
|
1451 |
+
"""
|
1452 |
+
Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
|
1453 |
+
|
1454 |
+
Works both with or without labels.
|
1455 |
+
"""
|
1456 |
+
if hasattr(self, "_prediction_loop"):
|
1457 |
+
warnings.warn(
|
1458 |
+
"The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
|
1459 |
+
FutureWarning,
|
1460 |
+
)
|
1461 |
+
return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)
|
1462 |
+
|
1463 |
+
prediction_loss_only = (
|
1464 |
+
prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
|
1465 |
+
)
|
1466 |
+
|
1467 |
+
assert not getattr(
|
1468 |
+
self.model.config, "output_attentions", False
|
1469 |
+
), "The prediction loop does not work with `output_attentions=True`."
|
1470 |
+
assert not getattr(
|
1471 |
+
self.model.config, "output_hidden_states", False
|
1472 |
+
), "The prediction loop does not work with `output_hidden_states=True`."
|
1473 |
+
|
1474 |
+
model = self.model
|
1475 |
+
# multi-gpu eval
|
1476 |
+
if self.args.n_gpu > 1:
|
1477 |
+
model = torch.nn.DataParallel(model)
|
1478 |
+
else:
|
1479 |
+
model = self.model
|
1480 |
+
# Note: in torch.distributed mode, there's no point in wrapping the model
|
1481 |
+
# inside a DistributedDataParallel as we'll be under `no_grad` anyways.
|
1482 |
+
|
1483 |
+
batch_size = dataloader.batch_size
|
1484 |
+
logger.info("***** Running %s *****", description)
|
1485 |
+
logger.info(" Num examples = %d", self.num_examples(dataloader))
|
1486 |
+
logger.info(" Batch size = %d", batch_size)
|
1487 |
+
eval_losses: List[float] = []
|
1488 |
+
preds: torch.Tensor = None
|
1489 |
+
label_ids: torch.Tensor = None
|
1490 |
+
entropy_losses: List[float] = []
|
1491 |
+
model.eval()
|
1492 |
+
if self.gpt2 is not None:
|
1493 |
+
self.gpt2.eval()
|
1494 |
+
|
1495 |
+
print(model.training)
|
1496 |
+
print(self.gpt2.training)
|
1497 |
+
|
1498 |
+
if is_torch_tpu_available():
|
1499 |
+
dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
|
1500 |
+
|
1501 |
+
if self.args.past_index >= 0:
|
1502 |
+
self._past = None
|
1503 |
+
|
1504 |
+
disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
|
1505 |
+
for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
|
1506 |
+
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
|
1507 |
+
batch_size = inputs[list(inputs.keys())[0]].shape[0]
|
1508 |
+
if loss is not None:
|
1509 |
+
eval_losses.extend([loss] * batch_size)
|
1510 |
+
if logits is not None:
|
1511 |
+
preds = logits if preds is None else nested_concat(preds, logits, dim=0)
|
1512 |
+
temp_logits = [torch.log_softmax(x) for x in logits]
|
1513 |
+
entropy_losses.extend([(x.exp() * x).sum() for x in temp_logits])
|
1514 |
+
if labels is not None:
|
1515 |
+
label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)
|
1516 |
+
|
1517 |
+
if self.args.past_index and hasattr(self, "_past"):
|
1518 |
+
# Clean the state at the end of the evaluation loop
|
1519 |
+
delattr(self, "_past")
|
1520 |
+
|
1521 |
+
|
1522 |
+
|
1523 |
+
if self.compute_metrics is not None and preds is not None and label_ids is not None:
|
1524 |
+
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
|
1525 |
+
else:
|
1526 |
+
metrics = {}
|
1527 |
+
|
1528 |
+
# Prefix all keys with eval_
|
1529 |
+
for key in list(metrics.keys()):
|
1530 |
+
if not key.startswith("eval_"):
|
1531 |
+
metrics[f"eval_{key}"] = metrics.pop(key)
|
1532 |
+
if len(entropy_losses) > 0:
|
1533 |
+
metrics['entropy'] = np.mean(entropy_losses)
|
1534 |
+
print('entropy', metrics['entropy'] )
|
1535 |
+
|
1536 |
+
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
|
1537 |
+
|
1538 |
+
def prediction_step(
|
1539 |
+
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
|
1540 |
+
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
1541 |
+
"""
|
1542 |
+
Perform an evaluation step on :obj:`model` using obj:`inputs`.
|
1543 |
+
|
1544 |
+
Subclass and override to inject custom behavior.
|
1545 |
+
|
1546 |
+
Args:
|
1547 |
+
model (:obj:`nn.Module`):
|
1548 |
+
The model to evaluate.
|
1549 |
+
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
|
1550 |
+
The inputs and targets of the model.
|
1551 |
+
|
1552 |
+
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
1553 |
+
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
|
1554 |
+
prediction_loss_only (:obj:`bool`):
|
1555 |
+
Whether or not to return the loss only.
|
1556 |
+
|
1557 |
+
Return:
|
1558 |
+
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
1559 |
+
A tuple with the loss, logits and labels (each being optional).
|
1560 |
+
"""
|
1561 |
+
has_labels = all(inputs.get(k) is not None for k in self.args.label_names)
|
1562 |
+
inputs = self._prepare_inputs(inputs)
|
1563 |
+
|
1564 |
+
# At eval time, set the weights to 1/bsz. and see the results..
|
1565 |
+
|
1566 |
+
# if 'weights' in inputs:
|
1567 |
+
# weights = inputs['weights']
|
1568 |
+
# bsz = weights.view(-1).shape[0]
|
1569 |
+
# weights = (torch.ones(weights.shape)/bsz).to(weights.device)
|
1570 |
+
# inputs['weights'] = weights
|
1571 |
+
|
1572 |
+
with torch.no_grad():
|
1573 |
+
# outputs = model.forward_weighted(**inputs)
|
1574 |
+
outputs = model(**inputs, gpt2_model=self.gpt2)
|
1575 |
+
if has_labels:
|
1576 |
+
# The .mean() is to reduce in case of distributed training
|
1577 |
+
loss = outputs[0].mean().item()
|
1578 |
+
logits = outputs[1:]
|
1579 |
+
else:
|
1580 |
+
loss = None
|
1581 |
+
# Slicing so we get a tuple even if `outputs` is a `ModelOutput`.
|
1582 |
+
logits = outputs[:]
|
1583 |
+
if self.args.past_index >= 0:
|
1584 |
+
self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
|
1585 |
+
|
1586 |
+
if prediction_loss_only:
|
1587 |
+
return (loss, None, None)
|
1588 |
+
|
1589 |
+
logits = tuple(logit.detach() for logit in logits)
|
1590 |
+
if len(logits) == 1:
|
1591 |
+
logits = logits[0]
|
1592 |
+
|
1593 |
+
if has_labels:
|
1594 |
+
labels = tuple(inputs.get(name).detach() for name in self.args.label_names)
|
1595 |
+
if len(labels) == 1:
|
1596 |
+
labels = labels[0]
|
1597 |
+
else:
|
1598 |
+
labels = None
|
1599 |
+
|
1600 |
+
return (loss, logits, labels)
|
1601 |
+
|
1602 |
+
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
|
1603 |
+
"""
|
1604 |
+
For models that inherit from :class:`~transformers.PretrainedModel`, uses
|
1605 |
+
that method to compute the number of floating point operations for every backward + forward pass. If using
|
1606 |
+
another model, either implement such a method in the model or subclass and override this method.
|
1607 |
+
|
1608 |
+
Args:
|
1609 |
+
model (:obj:`nn.Module`):
|
1610 |
+
The model to evaluate.
|
1611 |
+
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
|
1612 |
+
The inputs and targets of the model.
|
1613 |
+
|
1614 |
+
Returns:
|
1615 |
+
:obj:`int`: The number of floating-point operations.
|
1616 |
+
"""
|
1617 |
+
|
1618 |
+
if isinstance(self.model, torch.nn.DataParallel) or isinstance(
|
1619 |
+
self.model, torch.nn.parallel.DistributedDataParallel
|
1620 |
+
):
|
1621 |
+
model = self.model.module
|
1622 |
+
else:
|
1623 |
+
model = self.model
|
1624 |
+
|
1625 |
+
if hasattr(model, "floating_point_ops"):
|
1626 |
+
return model.floating_point_ops(inputs)
|
1627 |
+
|
1628 |
+
else:
|
1629 |
+
return 0
|
dalle/utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .utils import *
|
2 |
+
from .config import *
|
3 |
+
from .sampling import *
|
dalle/utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (214 Bytes). View file
|
|
dalle/utils/__pycache__/config.cpython-38.pyc
ADDED
Binary file (7.78 kB). View file
|
|
dalle/utils/__pycache__/sampling.cpython-38.pyc
ADDED
Binary file (6.86 kB). View file
|
|
dalle/utils/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (3.62 kB). View file
|
|
dalle/utils/config.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------
|
2 |
+
# Minimal DALL-E
|
3 |
+
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------------
|
6 |
+
|
7 |
+
from typing import Optional, List
|
8 |
+
from dataclasses import dataclass, field
|
9 |
+
from omegaconf import OmegaConf
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class DataConfig:
|
14 |
+
dataset: Optional[str] = None
|
15 |
+
tokenizer_type: str = 'CharBPE'
|
16 |
+
context_length: int = 64
|
17 |
+
image_resolution: int = 256
|
18 |
+
transforms: str = 'dalle-vqvae'
|
19 |
+
bpe_pdrop: Optional[float] = None
|
20 |
+
|
21 |
+
|
22 |
+
@dataclass
|
23 |
+
class Stage1Hparams:
|
24 |
+
double_z: bool = False
|
25 |
+
z_channels: int = 256
|
26 |
+
resolution: int = 256
|
27 |
+
in_channels: int = 3
|
28 |
+
out_ch: int = 3
|
29 |
+
ch: int = 128
|
30 |
+
ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
|
31 |
+
num_res_blocks: int = 2
|
32 |
+
attn_resolutions: List[int] = field(default_factory=lambda: [16])
|
33 |
+
pdrop: float = 0.0
|
34 |
+
|
35 |
+
|
36 |
+
@dataclass
|
37 |
+
class Stage2Hparams:
|
38 |
+
embed_dim: int = 1536
|
39 |
+
n_layers: int = 42
|
40 |
+
n_heads: int = 24
|
41 |
+
n_dense_layers: int = 42
|
42 |
+
ctx_len_img: int = 256
|
43 |
+
ctx_len_txt: int = 64
|
44 |
+
embd_pdrop: float = 0.0
|
45 |
+
resid_pdrop: float = 0.0
|
46 |
+
attn_pdrop: float = 0.0
|
47 |
+
mlp_bias: bool = True
|
48 |
+
attn_bias: bool = True
|
49 |
+
gelu_use_approx: bool = False
|
50 |
+
use_head_txt: bool = True
|
51 |
+
n_classes: Optional[int] = None
|
52 |
+
|
53 |
+
|
54 |
+
@dataclass
|
55 |
+
class Stage1Config:
|
56 |
+
type: str = 'vqgan'
|
57 |
+
embed_dim: int = 256
|
58 |
+
n_embed: int = 16384
|
59 |
+
hparams: Stage1Hparams = Stage1Hparams()
|
60 |
+
|
61 |
+
|
62 |
+
@dataclass
|
63 |
+
class Stage2Config:
|
64 |
+
type: str = 'transformer1d'
|
65 |
+
vocab_size_txt: int = 16384
|
66 |
+
vocab_size_img: int = 16384
|
67 |
+
use_cls_cond: Optional[bool] = None
|
68 |
+
hparams: Stage2Hparams = Stage2Hparams()
|
69 |
+
|
70 |
+
|
71 |
+
@dataclass
|
72 |
+
class WarmupConfig:
|
73 |
+
epoch: int = 1
|
74 |
+
multiplier: int = 1
|
75 |
+
buffer_epoch: int = 0
|
76 |
+
min_lr: float = 0.0
|
77 |
+
mode: str = 'fix'
|
78 |
+
peak_lr: float = 1e-4
|
79 |
+
start_from_zero: bool = True
|
80 |
+
|
81 |
+
|
82 |
+
@dataclass
|
83 |
+
class OptConfig:
|
84 |
+
opt_type: str = 'adamW'
|
85 |
+
learning_rate: float = 5e-5
|
86 |
+
weight_decay: float = 1e-4
|
87 |
+
betas: List[float] = field(default_factory=lambda: [0.9, 0.99])
|
88 |
+
grad_clip_norm: float = 1.0
|
89 |
+
|
90 |
+
sched_type: str = 'cosine'
|
91 |
+
max_steps: int = 0
|
92 |
+
min_lr: float = 1e-6
|
93 |
+
|
94 |
+
|
95 |
+
@dataclass
|
96 |
+
class ExpConfig:
|
97 |
+
per_gpu_train_batch_size: int = 4
|
98 |
+
per_gpu_eval_batch_size: int = 32
|
99 |
+
num_train_epochs: int = 10
|
100 |
+
save_ckpt_freq: int = 1
|
101 |
+
test_freq: int = 10
|
102 |
+
use_amp: bool = True
|
103 |
+
|
104 |
+
|
105 |
+
@dataclass
|
106 |
+
class PrefixModelConfig:
|
107 |
+
model_name_or_path: Optional[str] = ''
|
108 |
+
prefix_model_name_or_path: str = ''
|
109 |
+
prefix_mode: str = 'activation'
|
110 |
+
tuning_mode: str = 'finetune'
|
111 |
+
top_k_layers: int = 2
|
112 |
+
parameterize_mode: str = 'mlp'
|
113 |
+
optim_prefix: bool = False
|
114 |
+
preseqlen: int = 10
|
115 |
+
prefix_dropout: float = 0.1
|
116 |
+
init_random: bool = False
|
117 |
+
hidden_dim_prefix: int = 512
|
118 |
+
lowdata: bool = False
|
119 |
+
lowdata_token: str = ''
|
120 |
+
init_shallow: bool = False
|
121 |
+
init_shallow_word: bool = False
|
122 |
+
teacher_dropout: float = 0.1
|
123 |
+
gumbel: bool = False
|
124 |
+
replay_buffer: bool = False
|
125 |
+
|
126 |
+
|
127 |
+
@dataclass
|
128 |
+
class PromptModelConfig:
|
129 |
+
model_name_or_path: Optional[str] = ''
|
130 |
+
prefix_model_name_or_path: str = ''
|
131 |
+
tuning_mode: str = 'prompt'
|
132 |
+
preseqlen: int = 10
|
133 |
+
prefix_dropout: float = 0.1
|
134 |
+
|
135 |
+
|
136 |
+
@dataclass
|
137 |
+
class StoryModelConfig:
|
138 |
+
model_name_or_path: Optional[str] = ''
|
139 |
+
prefix_model_name_or_path: str = ''
|
140 |
+
tuning_mode: str = 'story'
|
141 |
+
preseqlen: int = 10
|
142 |
+
prefix_dropout: float = 0.1
|
143 |
+
prompt: bool = False
|
144 |
+
story_len: int = 4
|
145 |
+
sent_embed: int = 256
|
146 |
+
condition: bool = False
|
147 |
+
clip_embed: bool = False
|
148 |
+
|
149 |
+
|
150 |
+
@dataclass
|
151 |
+
class DefaultConfig:
|
152 |
+
dataset: DataConfig = DataConfig()
|
153 |
+
stage1: Stage1Config = Stage1Config()
|
154 |
+
stage2: Stage2Config = Stage2Config()
|
155 |
+
|
156 |
+
|
157 |
+
@dataclass
|
158 |
+
class FineTuningConfig:
|
159 |
+
dataset: DataConfig = DataConfig()
|
160 |
+
stage1: Stage1Config = Stage1Config()
|
161 |
+
stage2: Stage2Config = Stage2Config()
|
162 |
+
optimizer: OptConfig = OptConfig()
|
163 |
+
experiment: ExpConfig = ExpConfig()
|
164 |
+
|
165 |
+
|
166 |
+
@dataclass
|
167 |
+
class PrefixTuningConfig:
|
168 |
+
dataset: DataConfig = DataConfig()
|
169 |
+
stage1: Stage1Config = Stage1Config()
|
170 |
+
stage2: Stage2Config = Stage2Config()
|
171 |
+
prefix: PrefixModelConfig = PrefixModelConfig()
|
172 |
+
optimizer: OptConfig = OptConfig()
|
173 |
+
experiment: ExpConfig = ExpConfig()
|
174 |
+
|
175 |
+
|
176 |
+
@dataclass
|
177 |
+
class PromptTuningConfig:
|
178 |
+
dataset: DataConfig = DataConfig()
|
179 |
+
stage1: Stage1Config = Stage1Config()
|
180 |
+
stage2: Stage2Config = Stage2Config()
|
181 |
+
prompt: PromptModelConfig = PromptModelConfig()
|
182 |
+
optimizer: OptConfig = OptConfig()
|
183 |
+
experiment: ExpConfig = ExpConfig()
|
184 |
+
|
185 |
+
|
186 |
+
@dataclass
|
187 |
+
class StoryConfig:
|
188 |
+
dataset: DataConfig = DataConfig()
|
189 |
+
stage1: Stage1Config = Stage1Config()
|
190 |
+
stage2: Stage2Config = Stage2Config()
|
191 |
+
story: StoryModelConfig = StoryModelConfig()
|
192 |
+
optimizer: OptConfig = OptConfig()
|
193 |
+
experiment: ExpConfig = ExpConfig()
|
194 |
+
|
195 |
+
|
196 |
+
def get_base_config(mode):
|
197 |
+
if mode == 'default':
|
198 |
+
return OmegaConf.structured(DefaultConfig)
|
199 |
+
elif mode == 'finetuning':
|
200 |
+
return OmegaConf.structured(FineTuningConfig)
|
201 |
+
elif mode == 'prefixtuning':
|
202 |
+
return OmegaConf.structured(PrefixTuningConfig)
|
203 |
+
elif mode == 'prompt_tuning':
|
204 |
+
return OmegaConf.structured(PromptTuningConfig)
|
205 |
+
elif mode == 'story':
|
206 |
+
return OmegaConf.structured(StoryConfig)
|
207 |
+
else:
|
208 |
+
raise ValueError
|
209 |
+
# return OmegaConf.structured(DefaultConfig if use_default else FineTuningConfig)
|
dalle/utils/sampling.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------
|
2 |
+
# Minimal DALL-E
|
3 |
+
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------------
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from typing import Optional
|
9 |
+
from tqdm import tqdm
|
10 |
+
from torch.nn import functional as F
|
11 |
+
|
12 |
+
|
13 |
+
torch.set_printoptions(precision=2, threshold=10)
|
14 |
+
def cutoff_topk_logits(logits: torch.FloatTensor, k: int) -> torch.FloatTensor:
|
15 |
+
if k is None:
|
16 |
+
return logits
|
17 |
+
else:
|
18 |
+
v, ix = torch.topk(logits, k)
|
19 |
+
out = logits.clone()
|
20 |
+
out[out < v[:, [-1]]] = -float('Inf')
|
21 |
+
return out
|
22 |
+
|
23 |
+
|
24 |
+
def cutoff_topp_probs(probs: torch.FloatTensor, p: float) -> torch.FloatTensor:
|
25 |
+
if p is None:
|
26 |
+
return probs
|
27 |
+
else:
|
28 |
+
sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
|
29 |
+
cum_probs = torch.cumsum(sorted_probs, dim=-1)
|
30 |
+
|
31 |
+
sorted_idx_remove_cond = cum_probs >= p
|
32 |
+
|
33 |
+
sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone()
|
34 |
+
sorted_idx_remove_cond[..., 0] = 0
|
35 |
+
|
36 |
+
indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond)
|
37 |
+
probs = probs.masked_fill(indices_to_remove, 0.0)
|
38 |
+
norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True)
|
39 |
+
return norm_probs
|
40 |
+
|
41 |
+
|
42 |
+
def get_positional_encoding(inputs: torch.LongTensor, mode: str = '1d') -> torch.LongTensor:
|
43 |
+
device = inputs.device
|
44 |
+
if mode == '1d':
|
45 |
+
B, N = inputs.shape
|
46 |
+
xs_pos = torch.arange(N, device=device).repeat((B, 1))
|
47 |
+
elif mode == '2d':
|
48 |
+
B, H, W = inputs.shape
|
49 |
+
xs_pos_h = torch.arange(H, device=device).repeat(B, W, 1).transpose(1, 2)
|
50 |
+
xs_pos_w = torch.arange(W, device=device).repeat(B, H, 1)
|
51 |
+
xs_pos = (xs_pos_h, xs_pos_w)
|
52 |
+
else:
|
53 |
+
raise ValueError('%s positional encoding invalid' % mode)
|
54 |
+
return xs_pos
|
55 |
+
|
56 |
+
|
57 |
+
@torch.no_grad()
|
58 |
+
def sampling(model: torch.nn.Module,
|
59 |
+
tokens: torch.LongTensor,
|
60 |
+
top_k: Optional[float] = None,
|
61 |
+
top_p: Optional[float] = None,
|
62 |
+
softmax_temperature: float = 1.0,
|
63 |
+
is_tqdm: bool = True,
|
64 |
+
use_fp16: bool = True,
|
65 |
+
max_seq_len: int = 256,
|
66 |
+
prompt: Optional[torch.tensor] = None,
|
67 |
+
pos_prompt: Optional[torch.Tensor] = None) -> torch.LongTensor:
|
68 |
+
|
69 |
+
code = None
|
70 |
+
past = None
|
71 |
+
|
72 |
+
pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
|
73 |
+
pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
|
74 |
+
|
75 |
+
for cnt, h in enumerate(pbar):
|
76 |
+
if code is None:
|
77 |
+
code_ = None
|
78 |
+
pos_enc_code_ = None
|
79 |
+
else:
|
80 |
+
code_ = code.clone().detach()
|
81 |
+
pos_enc_code_ = get_positional_encoding(code_, mode='1d')
|
82 |
+
code_ = code_[:, cnt-1].unsqueeze(-1)
|
83 |
+
pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
|
84 |
+
|
85 |
+
logits, present = model.sampling(images=code_,
|
86 |
+
texts=tokens,
|
87 |
+
pos_images=pos_enc_code_,
|
88 |
+
pos_texts=pos_enc_tokens,
|
89 |
+
use_fp16=use_fp16,
|
90 |
+
past=past,
|
91 |
+
prompt=prompt,
|
92 |
+
pos_prompt=pos_prompt)
|
93 |
+
|
94 |
+
logits = logits.to(dtype=torch.float32)
|
95 |
+
logits = logits / softmax_temperature
|
96 |
+
|
97 |
+
# print(len(present), present[0].shape)
|
98 |
+
present = torch.stack(present).clone().detach()
|
99 |
+
if past is None:
|
100 |
+
past = [present]
|
101 |
+
else:
|
102 |
+
past.append(present)
|
103 |
+
|
104 |
+
logits = cutoff_topk_logits(logits, top_k)
|
105 |
+
probs = F.softmax(logits, dim=-1)
|
106 |
+
probs = cutoff_topp_probs(probs, top_p)
|
107 |
+
# print(probs[0])
|
108 |
+
|
109 |
+
idx = torch.multinomial(probs, num_samples=1).clone().detach()
|
110 |
+
# print(idx)
|
111 |
+
code = idx if code is None else torch.cat([code, idx], axis=1)
|
112 |
+
|
113 |
+
del past
|
114 |
+
return code
|
115 |
+
|
116 |
+
|
117 |
+
@torch.no_grad()
|
118 |
+
def sampling_prefix(model: torch.nn.Module,
|
119 |
+
tokens: torch.LongTensor,
|
120 |
+
past: torch.FloatTensor,
|
121 |
+
top_k: Optional[float] = None,
|
122 |
+
top_p: Optional[float] = None,
|
123 |
+
softmax_temperature: float = 1.0,
|
124 |
+
is_tqdm: bool = True,
|
125 |
+
use_fp16: bool = True,
|
126 |
+
max_seq_len: int = 256,
|
127 |
+
labels = None) -> torch.LongTensor:
|
128 |
+
code = None
|
129 |
+
|
130 |
+
pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
|
131 |
+
pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
|
132 |
+
|
133 |
+
# print("Entering sampling_prefix; ", past.shape)
|
134 |
+
if past is not None:
|
135 |
+
past = [past]
|
136 |
+
|
137 |
+
for cnt, h in enumerate(pbar):
|
138 |
+
if code is None:
|
139 |
+
code_ = None
|
140 |
+
pos_enc_code_ = None
|
141 |
+
else:
|
142 |
+
code_ = code.clone().detach()
|
143 |
+
pos_enc_code_ = get_positional_encoding(code_, mode='1d')
|
144 |
+
code_ = code_[:, cnt-1].unsqueeze(-1)
|
145 |
+
pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
|
146 |
+
|
147 |
+
# print("Looop enter")
|
148 |
+
# print(cnt, past[0].shape)
|
149 |
+
# print("-------------------")
|
150 |
+
logits, present = model.sampling(images=code_,
|
151 |
+
texts=tokens,
|
152 |
+
pos_images=pos_enc_code_,
|
153 |
+
pos_texts=pos_enc_tokens,
|
154 |
+
use_fp16=use_fp16,
|
155 |
+
past=past)
|
156 |
+
logits = logits.to(dtype=torch.float32)
|
157 |
+
logits = logits / softmax_temperature
|
158 |
+
|
159 |
+
present = torch.stack(present).clone().detach()
|
160 |
+
|
161 |
+
# print('Present', present.shape)
|
162 |
+
|
163 |
+
if past is None:
|
164 |
+
past = [present]
|
165 |
+
else:
|
166 |
+
# print("Loop end")
|
167 |
+
# print(present.shape)
|
168 |
+
# print("-----------------")
|
169 |
+
|
170 |
+
# n_layers, temp, _, seq_len, n_dim = present.shape
|
171 |
+
# _, _, bs, n_heads, pre_seq_len, n_dim = past[0].shape
|
172 |
+
# assert temp == 2
|
173 |
+
# past.append(present.view(n_layers, temp, bs, n_heads, seq_len, n_dim))
|
174 |
+
|
175 |
+
past.append(present)
|
176 |
+
|
177 |
+
logits = cutoff_topk_logits(logits, top_k)
|
178 |
+
probs = F.softmax(logits, dim=-1)
|
179 |
+
probs = cutoff_topp_probs(probs, top_p)
|
180 |
+
print(torch.topk(probs, 5, dim=-1))
|
181 |
+
if labels is not None:
|
182 |
+
print(labels[cnt])
|
183 |
+
idx = torch.multinomial(probs, num_samples=1).clone().detach()
|
184 |
+
# print(idx)
|
185 |
+
code = idx if code is None else torch.cat([code, idx], axis=1)
|
186 |
+
|
187 |
+
del past
|
188 |
+
return code
|
189 |
+
|
190 |
+
|
191 |
+
@torch.no_grad()
|
192 |
+
def sampling_prefix_new(model: torch.nn.Module,
|
193 |
+
tokens: torch.LongTensor,
|
194 |
+
past: torch.FloatTensor,
|
195 |
+
top_k: Optional[float] = None,
|
196 |
+
top_p: Optional[float] = None,
|
197 |
+
softmax_temperature: float = 1.0,
|
198 |
+
is_tqdm: bool = True,
|
199 |
+
use_fp16: bool = True,
|
200 |
+
max_seq_len: int = 256) -> torch.LongTensor:
|
201 |
+
code = None
|
202 |
+
|
203 |
+
pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
|
204 |
+
pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
|
205 |
+
|
206 |
+
# print("Entering sampling_prefix; ", past.shape)
|
207 |
+
if past is not None:
|
208 |
+
past = [past]
|
209 |
+
|
210 |
+
for cnt, h in enumerate(pbar):
|
211 |
+
if code is None:
|
212 |
+
code_ = None
|
213 |
+
pos_enc_code_ = None
|
214 |
+
else:
|
215 |
+
code_ = code.clone().detach()
|
216 |
+
pos_enc_code_ = get_positional_encoding(code_, mode='1d')
|
217 |
+
# code_ = code_[:, cnt-1].unsqueeze(-1)
|
218 |
+
# pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
|
219 |
+
|
220 |
+
# print("Looop enter")
|
221 |
+
# print(cnt, past[0].shape)
|
222 |
+
# print("-------------------")
|
223 |
+
|
224 |
+
if cnt == 0:
|
225 |
+
logits, present = model.sampling(images=code_,
|
226 |
+
texts=tokens,
|
227 |
+
pos_images=pos_enc_code_,
|
228 |
+
pos_texts=pos_enc_tokens,
|
229 |
+
use_fp16=use_fp16,
|
230 |
+
past=past)
|
231 |
+
logits = logits.to(dtype=torch.float32)
|
232 |
+
logits = logits / softmax_temperature
|
233 |
+
|
234 |
+
present = torch.stack(present).clone().detach()
|
235 |
+
|
236 |
+
# print('Present', present.shape)
|
237 |
+
|
238 |
+
if past is None:
|
239 |
+
past = [present]
|
240 |
+
else:
|
241 |
+
pass
|
242 |
+
|
243 |
+
logits = cutoff_topk_logits(logits, top_k)
|
244 |
+
probs = F.softmax(logits, dim=-1)
|
245 |
+
probs = cutoff_topp_probs(probs, top_p)
|
246 |
+
# print(torch.topk(probs[0], 5))
|
247 |
+
idx = torch.multinomial(probs, num_samples=1).clone().detach()
|
248 |
+
# print(idx)
|
249 |
+
code = idx if code is None else torch.cat([code, idx], axis=1)
|
250 |
+
|
251 |
+
else:
|
252 |
+
pass
|
253 |
+
|
254 |
+
|
255 |
+
del past
|
256 |
+
return code
|
257 |
+
|
258 |
+
@torch.no_grad()
|
259 |
+
def sampling_conditional(model: torch.nn.Module,
|
260 |
+
cross_attention_idxs,
|
261 |
+
cross_attention_layers,
|
262 |
+
tokens: torch.LongTensor,
|
263 |
+
src_codes: torch.FloatTensor,
|
264 |
+
top_k: Optional[float] = None,
|
265 |
+
top_p: Optional[float] = None,
|
266 |
+
softmax_temperature: float = 1.0,
|
267 |
+
is_tqdm: bool = True,
|
268 |
+
use_fp16: bool = True,
|
269 |
+
max_seq_len: int = 256,
|
270 |
+
prompt: Optional[torch.tensor] = None,
|
271 |
+
pos_prompt: Optional[torch.Tensor] = None) -> torch.LongTensor:
|
272 |
+
|
273 |
+
code = None
|
274 |
+
past = None
|
275 |
+
|
276 |
+
pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
|
277 |
+
pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
|
278 |
+
|
279 |
+
src_pos_tokens = get_positional_encoding(src_codes, mode='1d')
|
280 |
+
src_tokens = model.tok_emb_img(src_codes)
|
281 |
+
src_tokens = src_tokens + model.pos_emb_img(src_pos_tokens)
|
282 |
+
|
283 |
+
for cnt, h in enumerate(pbar):
|
284 |
+
if code is None:
|
285 |
+
code_ = None
|
286 |
+
pos_enc_code_ = None
|
287 |
+
else:
|
288 |
+
code_ = code.clone().detach()
|
289 |
+
pos_enc_code_ = get_positional_encoding(code_, mode='1d')
|
290 |
+
code_ = code_[:, cnt-1].unsqueeze(-1)
|
291 |
+
pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
|
292 |
+
|
293 |
+
logits, present = model.sampling_with_context(images=code_,
|
294 |
+
cross_attention_idxs=cross_attention_idxs,
|
295 |
+
cross_attention_layers=cross_attention_layers,
|
296 |
+
texts=tokens,
|
297 |
+
pos_images=pos_enc_code_,
|
298 |
+
pos_texts=pos_enc_tokens,
|
299 |
+
source_image=src_tokens,
|
300 |
+
use_fp16=use_fp16,
|
301 |
+
past=past,
|
302 |
+
prompt=prompt,
|
303 |
+
pos_prompt=pos_prompt)
|
304 |
+
logits = logits.to(dtype=torch.float32)
|
305 |
+
logits = logits / softmax_temperature
|
306 |
+
|
307 |
+
present = torch.stack(present).clone().detach()
|
308 |
+
if past is None:
|
309 |
+
past = [present]
|
310 |
+
else:
|
311 |
+
past.append(present)
|
312 |
+
|
313 |
+
logits = cutoff_topk_logits(logits, top_k)
|
314 |
+
probs = F.softmax(logits, dim=-1)
|
315 |
+
probs = cutoff_topp_probs(probs, top_p)
|
316 |
+
|
317 |
+
idx = torch.multinomial(probs, num_samples=1).clone().detach()
|
318 |
+
code = idx if code is None else torch.cat([code, idx], axis=1)
|
319 |
+
|
320 |
+
del past
|
321 |
+
return code
|
322 |
+
|
323 |
+
|
324 |
+
@torch.no_grad()
|
325 |
+
def sampling_igpt(model: torch.nn.Module,
|
326 |
+
sos: torch.FloatTensor,
|
327 |
+
top_k: Optional[float] = None,
|
328 |
+
top_p: Optional[float] = None,
|
329 |
+
softmax_temperature: float = 1.0,
|
330 |
+
is_tqdm: bool = True,
|
331 |
+
use_fp16: bool = True,
|
332 |
+
max_seq_len: int = 256) -> torch.LongTensor:
|
333 |
+
code = None
|
334 |
+
past = None
|
335 |
+
pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
|
336 |
+
|
337 |
+
for cnt, h in enumerate(pbar):
|
338 |
+
if code is None:
|
339 |
+
code_ = None
|
340 |
+
pos_enc_code_ = None
|
341 |
+
else:
|
342 |
+
code_ = code.clone().detach()
|
343 |
+
pos_enc_code_ = get_positional_encoding(code_, mode='1d')
|
344 |
+
code_ = code_[:, cnt-1].unsqueeze(-1)
|
345 |
+
pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
|
346 |
+
|
347 |
+
logits, present = model.sampling(sos=sos,
|
348 |
+
codes=code_,
|
349 |
+
pos_codes=pos_enc_code_,
|
350 |
+
use_fp16=use_fp16,
|
351 |
+
past=past)
|
352 |
+
logits = logits.to(dtype=torch.float32)
|
353 |
+
logits = logits / softmax_temperature
|
354 |
+
|
355 |
+
present = torch.stack(present).clone().detach()
|
356 |
+
if past is None:
|
357 |
+
past = [present]
|
358 |
+
else:
|
359 |
+
past.append(present)
|
360 |
+
|
361 |
+
logits = cutoff_topk_logits(logits, top_k)
|
362 |
+
probs = F.softmax(logits, dim=-1)
|
363 |
+
probs = cutoff_topp_probs(probs, top_p)
|
364 |
+
|
365 |
+
idx = torch.multinomial(probs, num_samples=1).clone().detach()
|
366 |
+
code = idx if code is None else torch.cat([code, idx], axis=1)
|
367 |
+
|
368 |
+
del past
|
369 |
+
return code
|
dalle/utils/utils.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ------------------------------------------------------------------------------------
|
2 |
+
# Minimal DALL-E
|
3 |
+
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------------
|
6 |
+
|
7 |
+
import os
|
8 |
+
import random
|
9 |
+
import urllib
|
10 |
+
import hashlib
|
11 |
+
import tarfile
|
12 |
+
import torch
|
13 |
+
import clip
|
14 |
+
import numpy as np
|
15 |
+
from PIL import Image
|
16 |
+
from torch.nn import functional as F
|
17 |
+
from tqdm import tqdm
|
18 |
+
import torchvision.utils as vutils
|
19 |
+
import matplotlib.pyplot as plt
|
20 |
+
|
21 |
+
|
22 |
+
def set_seed(seed: int):
|
23 |
+
random.seed(seed)
|
24 |
+
np.random.seed(seed)
|
25 |
+
torch.manual_seed(seed)
|
26 |
+
torch.cuda.manual_seed_all(seed)
|
27 |
+
|
28 |
+
|
29 |
+
@torch.no_grad()
|
30 |
+
def clip_score(prompt: str,
|
31 |
+
images: np.ndarray,
|
32 |
+
model_clip: torch.nn.Module,
|
33 |
+
preprocess_clip,
|
34 |
+
device: str) -> np.ndarray:
|
35 |
+
images = [preprocess_clip(Image.fromarray((image*255).astype(np.uint8))) for image in images]
|
36 |
+
images = torch.stack(images, dim=0).to(device=device)
|
37 |
+
texts = clip.tokenize(prompt).to(device=device)
|
38 |
+
texts = torch.repeat_interleave(texts, images.shape[0], dim=0)
|
39 |
+
|
40 |
+
image_features = model_clip.encode_image(images)
|
41 |
+
text_features = model_clip.encode_text(texts)
|
42 |
+
|
43 |
+
scores = F.cosine_similarity(image_features, text_features).squeeze()
|
44 |
+
rank = torch.argsort(scores, descending=True).cpu().numpy()
|
45 |
+
return rank
|
46 |
+
|
47 |
+
|
48 |
+
def download(url: str, root: str) -> str:
|
49 |
+
os.makedirs(root, exist_ok=True)
|
50 |
+
filename = os.path.basename(url)
|
51 |
+
pathname = filename[:-len('.tar.gz')]
|
52 |
+
|
53 |
+
expected_md5 = url.split("/")[-2]
|
54 |
+
download_target = os.path.join(root, filename)
|
55 |
+
result_path = os.path.join(root, pathname)
|
56 |
+
|
57 |
+
if os.path.isfile(download_target) and (os.path.exists(result_path) and not os.path.isfile(result_path)):
|
58 |
+
return result_path
|
59 |
+
|
60 |
+
with urllib.request.urlopen(url) as source, open(download_target, 'wb') as output:
|
61 |
+
with tqdm(total=int(source.info().get('Content-Length')), ncols=80, unit='iB', unit_scale=True,
|
62 |
+
unit_divisor=1024) as loop:
|
63 |
+
while True:
|
64 |
+
buffer = source.read(8192)
|
65 |
+
if not buffer:
|
66 |
+
break
|
67 |
+
|
68 |
+
output.write(buffer)
|
69 |
+
loop.update(len(buffer))
|
70 |
+
|
71 |
+
if hashlib.md5(open(download_target, 'rb').read()).hexdigest() != expected_md5:
|
72 |
+
raise RuntimeError(f'Model has been downloaded but the md5 checksum does not not match')
|
73 |
+
|
74 |
+
with tarfile.open(download_target, 'r:gz') as f:
|
75 |
+
pbar = tqdm(f.getmembers(), total=len(f.getmembers()))
|
76 |
+
for member in pbar:
|
77 |
+
pbar.set_description(f'extracting: {member.name} (size:{member.size // (1024 * 1024)}MB)')
|
78 |
+
f.extract(member=member, path=root)
|
79 |
+
|
80 |
+
return result_path
|
81 |
+
|
82 |
+
|
83 |
+
def realpath_url_or_path(url_or_path: str, root: str = None) -> str:
|
84 |
+
if urllib.parse.urlparse(url_or_path).scheme in ('http', 'https'):
|
85 |
+
return download(url_or_path, root)
|
86 |
+
return url_or_path
|
87 |
+
|
88 |
+
|
89 |
+
def images_to_numpy(tensor):
|
90 |
+
generated = tensor.data.cpu().numpy().transpose(1,2,0)
|
91 |
+
generated[generated < -1] = -1
|
92 |
+
generated[generated > 1] = 1
|
93 |
+
generated = (generated + 1) / 2 * 255
|
94 |
+
return generated.astype('uint8')
|
95 |
+
|
96 |
+
|
97 |
+
def save_image(ground_truth, images, out_dir, batch_idx):
|
98 |
+
|
99 |
+
for i, im in enumerate(images):
|
100 |
+
if len(im.shape) == 3:
|
101 |
+
plt.imsave(os.path.join(out_dir, 'test_%s_%s.png' % (batch_idx, i)), im)
|
102 |
+
else:
|
103 |
+
bs = im.shape[0]
|
104 |
+
# plt.imsave()
|
105 |
+
for j in range(bs):
|
106 |
+
plt.imsave(os.path.join(out_dir, 'test_%s_%s_%s.png' % (batch_idx, i, j)), im[j])
|
107 |
+
|
108 |
+
|
109 |
+
# print("Ground truth Images shape: ", ground_truth.shape, len(images))
|
110 |
+
|
111 |
+
# images = vutils.make_grid(images, nrow=ground_truth.shape[0])
|
112 |
+
# images = images_to_numpy(images)
|
113 |
+
#
|
114 |
+
# if ground_truth is not None:
|
115 |
+
# ground_truth = vutils.make_grid(ground_truth, 5)
|
116 |
+
# ground_truth = images_to_numpy(ground_truth)
|
117 |
+
# print("Ground Truth shape, Generated Images shape: ", ground_truth.shape, images.shape)
|
118 |
+
# images = np.concatenate([ground_truth, images], axis=0)
|
119 |
+
#
|
120 |
+
# output = Image.fromarray(images)
|
121 |
+
# output.save('%s/fake_samples_epoch_%03d.png' % (out_dir, batch_idx))
|
122 |
+
|
123 |
+
# if texts is not None:
|
124 |
+
# fid = open('%s/fake_samples_epoch_%03d_%03d.txt' % (image_dir, epoch, idx), 'w')
|
125 |
+
# for idx in range(images.shape[0]):
|
126 |
+
# fid.write(str(idx) + '--------------------------------------------------------\n')
|
127 |
+
# for i in range(len(texts)):
|
128 |
+
# fid.write(texts[i][idx] + '\n')
|
129 |
+
# fid.write('\n\n')
|
130 |
+
# fid.close()
|
131 |
+
return
|
demo/Barney.png
ADDED
demo/Betty.png
ADDED
demo/Crong.png
ADDED
demo/Dino.png
ADDED
demo/Eddy.png
ADDED
demo/Fred.png
ADDED
demo/Harry.png
ADDED
demo/Loopy.png
ADDED
demo/MrSlate.png
ADDED
demo/Pebbles.png
ADDED
demo/Petty.png
ADDED
demo/Poby.png
ADDED
demo/Pororo.png
ADDED
demo/Rody.png
ADDED
demo/Tongtong.png
ADDED