nousr commited on
Commit
f3bcd49
·
1 Parent(s): ee215f0

overhaul repo

Browse files
README.md CHANGED
@@ -12,24 +12,75 @@ Models are diffusion trainers from https://github.com/lucidrains/DALLE2-pytorch
12
  Join Us!: https://discord.gg/uPMftTmrvS
13
 
14
  ---
 
 
15
 
16
- # Models
17
- The repo currently has many models, I recommend using the latest EMA model checkpoints as they are the best performing models right now.
18
 
19
- > **_DISCLAIMER_**: **I will be removing many of the older models**. They were trained on older versions of *DALLE2 PyTorch* and massively under perform compared to recent models. **If for whatever reason you want an old model please make a backup** (you have 7 days from this README commit timestamp).
20
 
21
- ### Loading the models might look something like this:
22
 
23
- > Note: This repo's documentation will get an overhaul \~soon\~. If you're reading this, and having issues loading checkpoints, please reach out on LAION.
25
  ```python
26
  import torch
27
  from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
28
  from dalle2_pytorch.trainer import DiffusionPriorTrainer
29
 
30
- def load_diffusion_model(dprior_path, device):
31
-
32
- # If you are getting issues with size mismatches, it's likely this configuration
33
  prior_network = DiffusionPriorNetwork(
34
  dim=768,
35
  depth=24,
@@ -44,8 +95,7 @@ def load_diffusion_model(dprior_path, device):
44
  num_timesteps=1000,
45
  ff_mult=4
46
  )
47
-
48
- # currently, only ViT-L/14 models are being trained
49
  diffusion_prior = DiffusionPrior(
50
  net=prior_network,
51
  clip=OpenAIClipAdapter("ViT-L/14"),
@@ -56,10 +106,7 @@ def load_diffusion_model(dprior_path, device):
56
  condition_on_text_encodings=True,
57
 
58
  )
59
-
60
- # this will load the entire trainer
61
- # If you only want EMA weights for inference you will need to extract them yourself for now
62
- # (if you beat me to writing a nice function for that please make a PR on Github!)
63
  trainer = DiffusionPriorTrainer(
64
  diffusion_prior=diffusion_prior,
65
  lr=1.1e-4,
@@ -71,8 +118,80 @@ def load_diffusion_model(dprior_path, device):
71
  device=device,
72
  accelerator=None,
73
  )
74
-
75
  trainer.load(dprior_path)
76
-
77
  return trainer
78
- ```
12
  Join Us!: https://discord.gg/uPMftTmrvS
13
 
14
  ---
15
+ # Diffusion Prior
16
+ This readme serves as an introduction to the diffusion prior.
17
 
18
+ ## Intro
 
19
 
20
+ A properly trained prior will allow you to translate between two embedding spaces. If you know *a priori* that two embedding spaces are connected in some way, then the ability to translate between them can be extremely helpful.
21
 
22
+ ### Motivation
23
 
24
+ Before we dive into the model, let’s look at a quick example of where the model may be helpful.
25
+
26
+ For demonstration purposes we will imagine that we wish to generate images from text using CLIP and a Decoder.
27
+
28
+ > [CLIP](https://openai.com/blog/clip/) is a contrastive model that learns to maximize the cosine similarity between a given image and caption. However, there is no guarantee that these embeddings lie in the same space: while the generated embeddings are ***close***, the image and text embeddings occupy two disjoint sets.
29
 
30
+ ```python
31
+ # Load Models
32
+ clip_model, preprocess = clip.load("ViT-L/14")  # clip.load() returns (model, preprocess)
33
+ decoder = Decoder(checkpoint="best.pth") # A decoder trained on CLIP Image embeddings
34
+
35
+ # Retrieve prompt from user and encode with CLIP
36
+ prompt = "A corgi wearing sunglasses"
37
+ tokenized_text = clip.tokenize(prompt)
38
+ text_embedding = clip_model.encode_text(tokenized_text)
39
+
40
+ # Now, pass the text embedding to the decoder
41
+ predicted_image = decoder.sample(text_embedding)
42
+ ```
43
+
44
+ > **Question**: *Can you spot the issue here?*
45
+ >
46
+ > **Answer**: *We’re trying to generate an image from a text embedding!*
47
+
48
+ Unfortunately, we run into the issue mentioned previously: the image embeddings and the text embeddings are not interchangeable! Now let's look at a better solution.
49
+
50
+ ```python
51
+ # Load Models
52
+ prior = Prior(checkpoint="prior.pth") # A prior trained to map: text -> CLIP text embedding -> CLIP image embedding
53
+ decoder = Decoder(checkpoint="decoder.pth") # A decoder trained on CLIP Image embeddings
54
+
55
+ # Retrieve prompt from user and encode with a prior
56
+ prompt = "A corgi wearing sunglasses"
57
+ tokenized_text = clip.tokenize(prompt)
58
+ predicted_image_embedding = prior.sample(tokenized_text) # <-- now we get an embedding in the same space as images!
59
+
60
+ # Now, pass the predicted image embedding to the decoder
61
+ predicted_image = decoder.sample(predicted_image_embedding)
62
+ ```
63
+
64
+ With the prior we are able to successfully generate embeddings *within* CLIP's image space! For this reason, the decoder will perform much better as it receives input that is much closer to its training data.
65
+
66
+ > **You may be asking yourself the following question:**
67
+ >
68
+ > *"Why don't you just train the decoder on clip text embeddings instead of image embeddings?"*
69
+ >
70
+ > OpenAI covers this topic in their [DALLE-2 paper](https://arxiv.org/abs/2204.06125). The TL;DR is *"it doesn't work as well as decoders trained on image embeddings"*...also...it's just an example :smile:
71
+
72
+ ## Usage
73
+
74
+ Utilizing a pre-trained prior is quite simple.
75
+
76
+ ### Loading Checkpoints
77
  ```python
78
  import torch
79
  from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
80
  from dalle2_pytorch.trainer import DiffusionPriorTrainer
81
 
82
+ def load_diffusion_model(dprior_path, device):
83
+
 
84
  prior_network = DiffusionPriorNetwork(
85
  dim=768,
86
  depth=24,
 
95
  num_timesteps=1000,
96
  ff_mult=4
97
  )
98
+
 
99
  diffusion_prior = DiffusionPrior(
100
  net=prior_network,
101
  clip=OpenAIClipAdapter("ViT-L/14"),
 
106
  condition_on_text_encodings=True,
107
 
108
  )
109
+
 
 
 
110
  trainer = DiffusionPriorTrainer(
111
  diffusion_prior=diffusion_prior,
112
  lr=1.1e-4,
 
118
  device=device,
119
  accelerator=None,
120
  )
121
+
122
  trainer.load(dprior_path)
123
+
124
  return trainer
125
+ ```
126
+
127
+ Here we instantiate a model that matches the configuration it was trained with, and then load the weights (*just like any other PyTorch model!*).
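+
+ As a hypothetical usage example (the checkpoint path is a placeholder, and exactly which attribute of the returned trainer you want may differ between DALLE2-pytorch versions), putting the loader to work might look like this:
+
+ ```python
+ import clip  # OpenAI's clip package, used for tokenization in the next section
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ trainer = load_diffusion_model("checkpoints/prior.pth", device)  # placeholder path
+
+ # the trainer wraps the diffusion prior itself; depending on your DALLE2-pytorch
+ # version you may instead want its EMA copy of the weights for inference
+ prior = trainer.diffusion_prior
+ ```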
128
+
129
+ ### Sampling
130
+ Once we have a pre-trained model, generating embeddings is quite simple!
131
+ ```python
132
+ # tokenize the text
133
+ tokenized_text = clip.tokenize("<your amazing prompt>")
134
+ # predict an embedding
135
+ predicted_embedding = prior.sample(tokenized_text, n_samples_per_batch=2, cond_scale=1.0)
136
+ ```
137
+
138
+ The tensor returned from `.sample()` has the same shape as your training embeddings along the non-batch dimension(s). For example, a prior trained on `ViT-L/14` embeddings will predict an embedding of shape `(1, 768)`.
139
+
140
+ > For CLIP priors, this is quite handy as it means that you can use `prior.sample(tokenized_text)` as a drop-in replacement for `clip.encode_text()`.
141
+
142
+ **Some things to note:**
143
+ * It is possible to specify the number of candidate embeddings to sample (the default suggested by OpenAI is `n=2`). Put simply, the idea is to avoid getting unlucky with a single bad generation by sampling two candidates and keeping the one with the higher cosine similarity to the prompt (see the sketch after this list).
144
+ * You may specify a higher conditioning scale than the default (`1.0`). It is unclear whether OpenAI uses a higher value for the prior specifically, or only on the decoder. Local testing has shown poor results with anything higher than `1.0` but *ymmv*.
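+
+ To make the re-ranking idea concrete, here is a minimal sketch of doing it by hand (purely illustrative: `prior.sample()` already handles this internally when you raise the candidate count, and `clip_model`/`tokenized_text` are assumed from the earlier snippets):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ # sample two candidate image embeddings for the same prompt
+ candidates = torch.cat([prior.sample(tokenized_text) for _ in range(2)], dim=0)  # (2, 768)
+
+ # score each candidate against the CLIP text embedding and keep the best one
+ text_embedding = clip_model.encode_text(tokenized_text)    # (1, 768)
+ scores = F.cosine_similarity(candidates, text_embedding)   # (2,)
+ best_embedding = candidates[scores.argmax()].unsqueeze(0)  # (1, 768)
+ ```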
145
+
146
+ ---
147
+
148
+ ## Training
149
+
150
+ ### Overview
151
+
152
+ Training the prior is a relatively straightforward process thanks to the Trainer base class. The major step required of you is preparing a dataset in the format that EmbeddingReader expects. Having pre-computed embeddings massively increases training efficiency and is generally recommended, as you will likely benefit from having them on hand for other tasks as well. Once you have a dataset, you are ready to move on to configuration.
153
+
154
+ ## Dataset
155
+
156
+ To train the prior, it is highly recommended to use precomputed embeddings for the images. To obtain these for a custom dataset, you can leverage [img2dataset](https://github.com/rom1504/img2dataset) to pull images from a list of URLs and [clip_retrieval](https://github.com/rom1504/clip-retrieval#clip-inference) to generate the actual embeddings that can be used in the prior's dataloader.
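+
+ As a rough sketch of how such precomputed embeddings can then be read back with rom1504's `embedding_reader` package (the folder name is a placeholder for a hypothetical clip_retrieval output, and exact arguments may differ across versions):
+
+ ```python
+ from embedding_reader import EmbeddingReader
+
+ # placeholder folder produced by a hypothetical clip_retrieval inference run
+ image_reader = EmbeddingReader(embeddings_folder="embeddings/img_emb", file_format="npy")
+
+ # iterate over the precomputed image embeddings in large batches
+ for embeddings, metadata in image_reader(batch_size=10**5, start=0, end=image_reader.count):
+     print(embeddings.shape)  # e.g. (100000, 768) for ViT-L/14
+     break
+ ```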
157
+
158
+ ## Configuration
159
+
160
+ The configuration file allows you to easily track and reproduce experiments. It is a simple JSON file that specifies the architecture, dataset, and training parameters. For more information and specifics, please see the configuration README.
161
+
162
+ ## Distributed Training
163
+
164
+ If you would like to train in a distributed manner, we have opted to leverage Hugging Face's Accelerate library. Accelerate makes it extremely simple to distribute work across multiple GPUs and nodes. All that is required of you is to run its simple CLI configuration tool ([more information here](https://huggingface.co/docs/accelerate/accelerator)).
165
+
166
+ ## Evaluation
167
+
168
+ There are a variety of metrics available to you when training the prior. You can read a brief description of each in the table below:
169
+ | Metric | Description | Comments |
170
+ | ----------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
171
+ | Online Model Validation | The validation loss associated with your online model. | Ideally validation loss will be as low as possible. Using L2 loss, values as low as `0.1` and lower are possible after around 1 Billion samples seen. |
172
+ | EMA Validation | This metric measures the validation loss associated with your EMA model. | This will likely lag behind your "online" model's validation loss, but should outperform it in the long term. |
173
+ | Baseline Similarity | Baseline similarity refers to the similarity between your dataset's prompt embeddings and their associated image embeddings. This serves as a guide for your prior's performance in cosine similarity. | Generally, `0.3` is considered a good cosine similarity between a caption and its image. |
174
+ | Similarity With Original Image | This metric measures the cosine similarity between your prior's predicted image embedding and the embedding of the actual image that the caption was associated with. This is useful for determining whether your prior is predicting embeddings with the right contents. | Values around `0.75`+ are obtainable. This metric should improve rapidly in the early stages of training and plateau with diminishing increases over time. If it takes hundreds of millions of samples to reach above `0.5`/`0.6` similarity, then you are likely suffering from some kind of training error or inefficiency (i.e. not using EMA). |
175
+ | Difference From Baseline Similarity | Sometimes it's useful to visualize a metric in another light. This metric shows how your prior's predicted image embeddings match up with the baseline similarity measured in your dataset. | This value should float around `0.0`, with some room for variation. After a billion samples seen, values are within about +/-`0.01` of `0.0`. If this climbs too high (roughly >`0.02`), it may be a sign that your model is overfitting somehow. |
176
+ | Similarity With Text | This metric is your bread-and-butter cosine similarity between the predicted image embedding and the original caption given to the prior. Monitoring this metric will be one of your main focuses and it is probably the second most important metric behind your loss. | As mentioned, this value should be close to the baseline similarity. We have observed early rapid increases with diminishing returns as the prior learns to generate valid image embeddings. If this value increases too far beyond the baseline similarity, it could be an indication that your model is overfitting. |
177
+ | Similarity With Unrelated Caption | This metric attempts to expose an overfit prior by feeding it arbitrary prompts (from your dataset) and then measuring the similarity of the predicted embedding with some other, unrelated image. | Early on we found that a poorly trained/modeled prior could effectively fool CLIP into reporting a high cosine similarity even when the caption and image were completely unrelated. With this in mind, a low value is ideal; anything below `0.1` is probably safe. |
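+
+ To make a couple of these definitions concrete, here is an illustrative sketch (not the training script's actual code; the tensor names are assumptions) of how the similarity-based metrics above can be computed from batches of embeddings:
+
+ ```python
+ import torch.nn.functional as F
+
+ # text_embeds, image_embeds: CLIP embeddings from your validation set, shape (N, 768)
+ # predicted_embeds: the prior's predicted image embeddings for the same captions, shape (N, 768)
+ baseline_similarity = F.cosine_similarity(text_embeds, image_embeds).mean()
+ similarity_with_text = F.cosine_similarity(text_embeds, predicted_embeds).mean()
+
+ # "Difference From Baseline Similarity" should hover around 0.0
+ difference_from_baseline = similarity_with_text - baseline_similarity
+ ```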
178
+
179
+ ## Launching the script
180
+
181
+ Now that you’ve done all the prep it’s time for the easy part! 🚀
182
+
183
+ To actually launch the script, you will either use `accelerate launch train_diffusion_prior.py --config_path <path to your config>` to launch with distributed training via Hugging Face Accelerate, or `python train_diffusion_prior.py` if you would like to train on your GPU/CPU without Accelerate.
184
+
185
+ ## Checkpointing
186
+
187
+ Checkpoints will be saved to the directory specified in your configuration file.
188
+
189
+ Additionally, a final checkpoint is saved before running the test split. This file will be saved to the same directory and titled `latest.pth`. This is to avoid problems where your `save_every` interval does not line up with the number of steps required to do a complete pass through the data.
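+
+ Resuming from that final checkpoint is then just another call to the loader shown earlier (the directory below is a placeholder for whatever you set in your configuration file):
+
+ ```python
+ trainer = load_diffusion_model("checkpoints/latest.pth", device)  # placeholder directory
+ ```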
190
+
191
+ ## Things To Keep In Mind
192
+
193
+ The prior has not been trained for tasks other than the traditional CLIP embedding translation…at least not yet.
194
+
195
+ As we finalize the replication of unCLIP, there will almost assuredly be experiments attempting to apply the prior network to other tasks.
196
+
197
+ With that in mind, you are more or less a pioneer in embedding-translation if you are reading this and attempting something you don’t see documentation for!
old_models/medium-half-lr.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c5a69ff7cdcb59d1aba8a2804337521221c74a26c1e90efa69ee0ab2c8bb8836
3
- size 3555893015
 
 
 
 
old_models/vit-b-20k.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4c0bd55cbac72a762b7a076d1aa2bab7d6f87372787fb1e63f5a60647fb6d729
3
- size 958496157
 
 
 
 
old_models/vit-l-100k.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0eba6d7a2fda049f6f25d8519936b8b10f4d8460833527eb658d2c0139e6f23f
3
- size 2200498685
 
 
 
 
old_models/vit-l-10k.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:268684b64d129c5bbc92ba9da117ee6db6a54276c71e177b2cf8f238136f3ab2
3
- size 2200498429
 
 
 
 
old_models/vit-l-20k.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a56514627cfb916f1be24467c1b4806722d62a49bd3daf3cb765507626ddc25a
3
- size 2200498429
 
 
 
 
old_models/vit-l-50k.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0ef4235284fd3e69b26ce9c3d535835b30197493a6dea155e733e830f9440704
3
- size 2200498429
 
 
 
 
old_models/vit-l-60k.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:fb108c627868d0650c01d896967cf8c3fc2f113a946796a7148e1a64c27fc1f3
3
- size 2200498429
 
 
 
 
old_models/vit-l-70k.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e357efc561a6629fb7be375da2f753aa1b3ed772d9dab56895bf52d0dafa8c23
3
- size 2200498685
 
 
 
 
old_models/vit-l-87k.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:da0e3ba64717705fc0439dfd5aa790a8c8da068a423fdfc75d9db487e27a2146
3
- size 2200498685
 
 
 
 
vit-l-14/aesthetic/coming_soon ADDED
File without changes
vit-l-14/{ema080M.pth → laion2b/ema080M.pth} RENAMED
File without changes
vit-l-14/{ema136M.pth → laion2b/ema136M.pth} RENAMED
File without changes
vit-l-14/{ema160M.pth → laion2b/ema160M.pth} RENAMED
File without changes
vit-l-14/{ema200M.pth → laion2b/ema200M.pth} RENAMED
File without changes
vit-l-14/{ema224M.pth → laion2b/ema224M.pth} RENAMED
File without changes
vit-l-14/{ema375M.pth → laion2b/ema375M.pth} RENAMED
File without changes
vit-l-14/{ema400M.pth → laion2b/ema400M.pth} RENAMED
File without changes
vit-l-14/{ema465M.pth → laion2b/ema465M.pth} RENAMED
File without changes
vit-l-14/{ema540M.pth → laion2b/ema540M.pth} RENAMED
File without changes
vit-l-14/{ema615M.pth → laion2b/ema615M.pth} RENAMED
File without changes
vit-l-14/{ema855M.pth → laion2b/ema855M.pth} RENAMED
File without changes