Andrew DalPino
commited on
Commit
·
cac8fe7
0
Parent(s):
Initial commit
Browse files- .gitattributes +35 -0
- .gitignore +15 -0
- README.md +136 -0
- beam_search.py +104 -0
- data.py +254 -0
- dataset/.gitignore +2 -0
- generate.py +100 -0
- instruction-tune.py +197 -0
- model.py +499 -0
- model_sizing.ipynb +330 -0
- models/lightgpt-small.pt +3 -0
- out/.gitignore +2 -0
- pre-train.py +320 -0
- requirements.txt +7 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
.mypy_cache/
|
3 |
+
env/
|
4 |
+
build/
|
5 |
+
develop-eggs/
|
6 |
+
dist/
|
7 |
+
lib/
|
8 |
+
lib64/
|
9 |
+
wheels/
|
10 |
+
*.egg-info/
|
11 |
+
.installed.cfg
|
12 |
+
*.egg
|
13 |
+
.venv
|
14 |
+
venv/
|
15 |
+
ENV/
|
README.md
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# LightGPT
|
2 |
+
|
3 |
+
A lightweight generative pre-trained Transformer (GPT) for the people! A unique feature of LightGPT is that it gives you the ability to progressively trade off compute for additional memory-efficiency - allowing you to train large models on smaller consumer hardware. It also supports memory-efficient training over multiple GPUs or clusters of GPUs using PyTorch's Distributed Data Parallel (DDP) protocol with ZeRO Redundancy sharding. Unlike closed-source LLMs, LightGPT provides both the model weights *and* the code to train and fine-tune the model yourself.
|
4 |
+
|
5 |
+
## What makes LightGPT different?
|
6 |
+
|
7 |
+
- **Parameter-efficiency**: LightGPT aims to be a more parsimonious model by only training parameters that are absolutely necessary. As such, biases and positional embeddings have been completely removed from the neural network architecture. In addition, the token embeddings and output layer share their weights resulting in a further reduction in trainable parameters.
|
8 |
+
|
9 |
+
- **Training efficiency**: Compared to Adam, LightGPT's Adafactor optimizer reduces the number of training-time buffers from O(n*m) to O(n+m) for every trainable weight matrix with little difference in runtime or minima quality. In addition, with activation check-pointing, model, gradient, and optimizer state sharding, we can reduce the number of buffers needed during training by a factor of 10 or more.
|
10 |
+
|
11 |
+
- **Fully open-source**: Want to train your own LightGPT? Go right ahead! In addition to our model weights, we also release our training and inferencing code so you can train one yourself. With the power of open-source, we hope that others can learn from and continue improving LightGPT over time.
|
12 |
+
|
13 |
+
|
14 |
+
## Install Project Dependencies
|
15 |
+
|
16 |
+
Project dependencies are specified in the `requirements.txt` file. You can install them with [pip](https://pip.pypa.io/en/stable/) using the following command from the project root. I recommend using a virtual environment such as venv to keep package dependencies on your system tidy.
|
17 |
+
|
18 |
+
```
|
19 |
+
python -m venv ./.venv
|
20 |
+
|
21 |
+
source ./.venv/bin/activate
|
22 |
+
|
23 |
+
pip install -r requirements.txt
|
24 |
+
```
|
25 |
+
|
26 |
+
## Quick Start
|
27 |
+
|
28 |
+
If you'd just like to start training right away, the default settings should work on most single-GPU systems with 12G of VRAM or more.
|
29 |
+
|
30 |
+
```
|
31 |
+
python pre-train.py
|
32 |
+
```
|
33 |
+
|
34 |
+
> Note that it will take a while to download and pre-process the dataset the first time that the training script is run.
|
35 |
+
|
36 |
+
If you have a larger system you can increase the training load by increasing the capacity of the network and `batch_size` at runtime.
|
37 |
+
|
38 |
+
```
|
39 |
+
python pre-train.py --embedding_dimensions=1024 --num_hidden_layers=24 --batch_size=8
|
40 |
+
```
|
41 |
+
|
42 |
+
To distribute the training workload over a cluster of GPUs or multiple cluster nodes, use PyTorch's [torchrun](https://pytorch.org/docs/stable/elastic/run.html) extension to launch a distributed data parallel session.
|
43 |
+
|
44 |
+
```
|
45 |
+
torchrun --standalone --nnodes=1 --nproc-per-node=8 pre-train.py --batch_size=16 --gradient_accumulation_steps=32
|
46 |
+
```
|
47 |
+
|
48 |
+
> Note that when training in data-parallel mode it's important that the world size divides evenly into `gradient_accumulation_steps` for maximum performance. For example, if we have an 8 GPU cluster, we could perform 32 gradient accumulation steps in exactly 4 passes over the network.
|
49 |
+
|
50 |
+
After training, you can generate text from the model by running the `generate.py` script from the commandline with a prompt.
|
51 |
+
|
52 |
+
```
|
53 |
+
python generate.py
|
54 |
+
```
|
55 |
+
|
56 |
+
### Pre-training Arguments
|
57 |
+
|
58 |
+
| Argument | Default | Type | Description |
|
59 |
+
|---|---|---|---|
|
60 |
+
| --batch_size | 1 | int | The number of samples to pass through the network at a time. |
|
61 |
+
| --gradient_accumulation_steps | 128 | int | The number of batches to pass through the network before updating the weights. |
|
62 |
+
| --samples_per_epoch | 4096 | int | The number of training samples to pass through the network every epoch. |
|
63 |
+
| --learning_rate | 5e-4 | float | The global step size taken after every gradient accumulation step. |
|
64 |
+
| --max_gradient_norm | 1.0 | float | Clip gradients above this threshold before stepping. |
|
65 |
+
| --num_epochs | 2145 | int | The number of epochs to train for. |
|
66 |
+
| --eval_interval | 10 | int | Evaluate the model after this many epochs on the testing set. |
|
67 |
+
| --block_size | 1024 | int | The number of tokens within the context window for every sample. |
|
68 |
+
| --embedding_dimensions | 1024 | int | The dimensionality of the token embeddings. |
|
69 |
+
| --num_attention_heads | 16 | int | The number of attention heads within every block. |
|
70 |
+
| --num_hidden_layers | 24 | int | The number of attention/MLP blocks within the hidden layer of the network. |
|
71 |
+
| --dropout | 0.1 | float | The proportion of signals to send to zero during training as regularization. |
|
72 |
+
| --activation_checkpointing | False | bool | Should we use activation checkpointing? |
|
73 |
+
| --checkpoint_interval | 20 | int | Save the model parameters to disk every this many epochs. |
|
74 |
+
| --checkpoint_path | "./out/checkpoint.pt" | string | The path to the checkpoint file on disk. |
|
75 |
+
| --dataset_path | "./dataset" | string | The path to the dataset files on disk. |
|
76 |
+
| --num_dataset_processes | 8 | int | The number of processes (CPUs) to use to process the dataset. |
|
77 |
+
| --resume | False | bool | Should we resume training from the last checkpoint? |
|
78 |
+
| --device | "cuda" | string | The device to run the computation on. |
|
79 |
+
| --seed | None | int | The seed for the random number generator. |
|
80 |
+
|
81 |
+
### Instruction-tuning Arguments
|
82 |
+
|
83 |
+
| Argument | Default | Type | Description |
|
84 |
+
|---|---|---|---|
|
85 |
+
| --base_model_path | "./out/checkpoint.pt" | string | The path to the pre-trained model. |
|
86 |
+
| --batch_size | 1 | int | The number of samples to pass through the network at a time. |
|
87 |
+
| --gradient_accumulation_steps | 128 | int | The number of batches to pass through the network before updating the weights. |
|
88 |
+
| --learning_rate | 1e-2 | float | The global step size taken after every gradient accumulation step. |
|
89 |
+
| --mask_input | True | bool | Should we mask the input part of the sample i.e. only train on the output? |
|
90 |
+
| --rank | 8 | int | The rank of the LoRA decomposition matrices. |
|
91 |
+
| --alpha | 1.0 | float | The strength of the LoRA signal. |
|
92 |
+
| --dropout | 0.05 | float | The proportion of signals to send to zero during training as regularization. |
|
93 |
+
| --num_epochs | 4 | int | The number of epochs to train for. |
|
94 |
+
| --eval_interval | 1 | int | Evaluate the model after this many epochs on the testing set. |
|
95 |
+
| --checkpoint_interval | 1 | int | Save the model parameters to disk every this many epochs. |
|
96 |
+
| --checkpoint_path | "./out/lora_instruction.pt" | string | The path to the checkpoint file on disk. |
|
97 |
+
| --resume | False | bool | Should we resume training from the last checkpoint? |
|
98 |
+
| --device | "cuda" | string | The device to run the computation on. |
|
99 |
+
| --seed | None | int | The seed for the random number generator. |
|
100 |
+
|
101 |
+
### Generation Arguments
|
102 |
+
|
103 |
+
| Argument | Default | Type | Description |
|
104 |
+
|---|---|---|---|
|
105 |
+
| --checkpoint_path | "./out/checkpoint.pt" | string | The path to the checkpoint file on disk. |
|
106 |
+
| --lora_path | None | string | The path to the LoRA checkpoint. |
|
107 |
+
| --max_tokens | 1000 | int | The maximum number of tokens that the model should generate per sample. |
|
108 |
+
| --temperature | 1.0 | float | The amount of regularization applied to the candidate token probabilities. |
|
109 |
+
| --top_k | 500 | int | Only sample from this many candidate tokens with the highest probabilities. |
|
110 |
+
| --top_p | 0.9 | float | Of the `top_k` tokens, drop all but the `top_p` portion of the cumulative probability distribution. |
|
111 |
+
| --device | "cuda" | string | The device to run the computation on. |
|
112 |
+
| --seed | None | int | The seed for the random number generator. |
|
113 |
+
|
114 |
+
### Beam Search Arguments
|
115 |
+
|
116 |
+
| Argument | Default | Type | Description |
|
117 |
+
|---|---|---|---|
|
118 |
+
| --checkpoint_path | "./out/checkpoint.pt" | string | The path to the checkpoint file on disk. |
|
119 |
+
| --lora_path | None | string | The path to the LoRA checkpoint. |
|
120 |
+
| --max_tokens | 200 | int | The maximum number of tokens that the model should generate per sample. |
|
121 |
+
| --num_candidates | 3 | int | The number of candidate sequences to output. |
|
122 |
+
| --beam_width | 16 | int | The number of candidate sequences to keep track of during search. |
|
123 |
+
| --device | "cuda" | string | The device to run the computation on. |
|
124 |
+
| --seed | None | int | The seed for the random number generator. |
|
125 |
+
|
126 |
+
## References:
|
127 |
+
>- A. Radford, et al. Language Models are Unsupervised Multitask Learners, OpenAI, 2019.
|
128 |
+
>- T. Brown, et al. Language Models are Few-Shot Learners. OpenAI, 2020.
|
129 |
+
>- A. Kazemnejad, et al. The Impact of Positional Encoding on Length Generalization in Transformers, 37th Conference on Neural Information Processing Systems (NeurIPS 2023).
|
130 |
+
>- S. Rajbhandari, et al. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models, 2020.
|
131 |
+
>- J. R. Hermans, et al. Accumulated Gradient Normalization, JMLR: Workshop and Conference Proceedings, 2017.
|
132 |
+
>- T. Chen, et al. Training Deep Nets with Sublinear Memory Cost. MIT, 2019.
|
133 |
+
>- B. Zhang, et al. Root Mean Square Layer Normalization. 33rd Conference on Neural Information Processing Systems, NeurIPS 2019.
|
134 |
+
|
135 |
+
## License
|
136 |
+
The code is licensed [MIT](LICENSE) and the tutorial is licensed [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/).
|
beam_search.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
from os import path
|
4 |
+
from argparse import ArgumentParser
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from torch.cuda import is_available as cuda_is_available
|
9 |
+
|
10 |
+
from model import GPT, GPTWithLoRA
|
11 |
+
from data import Alpaca
|
12 |
+
|
13 |
+
import tiktoken
|
14 |
+
|
15 |
+
|
16 |
+
def main():
    """Interactively print beam search candidates for prompts read from stdin.

    Parses the command-line arguments, loads the model checkpoint and,
    when `--lora_path` is given, wraps the model in `GPTWithLoRA`, loads the
    LoRA checkpoint and calls `merge_lora_parameters()`. It then loops:
    read a prompt, run `model.beam_search`, print every returned candidate,
    and ask whether to go again.

    Raises:
        RuntimeError: If a CUDA device is requested but CUDA is unavailable.
    """
    parser = ArgumentParser(
        description="Generate text from the model given a prompt.",
    )

    parser.add_argument("--checkpoint_path", default="./out/checkpoint.pt", type=str)
    parser.add_argument("--lora_path", default=None, type=str)
    parser.add_argument("--max_tokens", default=200, type=int)
    parser.add_argument("--num_candidates", default=3, type=int)
    parser.add_argument("--beam_width", default=16, type=int)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--seed", default=None, type=int)

    args = parser.parse_args()

    if "cuda" in args.device and not cuda_is_available():
        raise RuntimeError("Cuda is not available.")

    torch.set_float32_matmul_precision("high")

    # Compare against None explicitly so that `--seed 0` is honoured; a plain
    # truthiness test would silently skip seeding for a seed of zero.
    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    tokenizer = tiktoken.get_encoding(Alpaca.ENCODING)

    checkpoint = torch.load(
        args.checkpoint_path, map_location=args.device, weights_only=True
    )

    model = GPT(**checkpoint["model_args"])

    # The model is compiled before the weights are loaded, matching the order
    # used by the original script.
    model = torch.compile(model)

    model.load_state_dict(checkpoint["model"])

    print("Model checkpoint loaded")

    if args.lora_path:
        checkpoint = torch.load(
            args.lora_path, map_location=args.device, weights_only=True
        )

        model = GPTWithLoRA(model, **checkpoint["lora_args"])

        model = torch.compile(model)

        # strict=False because the LoRA checkpoint is loaded on top of the
        # already-initialised wrapped model.
        # NOTE(review): presumably it holds only the LoRA weights - confirm.
        model.load_state_dict(checkpoint["lora"], strict=False)

        model.merge_lora_parameters()

        print("LoRA checkpoint loaded")

    model.to(args.device)

    model.eval()

    while True:
        prompt = input("Enter a prompt: ")

        # With a LoRA checkpoint the raw prompt is wrapped in the Alpaca
        # instruction template before it is tokenized.
        if args.lora_path:
            prompt = Alpaca.PROMPT_TEMPLATE.format(instruction=prompt)

        prompt = tokenizer.encode_ordinary(prompt)

        prompt = torch.tensor(prompt, dtype=torch.int64, device=args.device)

        candidates = model.beam_search(
            prompt,
            args.max_tokens,
            args.num_candidates,
            args.beam_width,
        )

        for i, candidate in enumerate(candidates, start=1):
            print(f"Sequence #{i}")

            out = tokenizer.decode(candidate.tokens.tolist()).strip()

            print(out, end="\n\n")

        print("\n")

        # Any answer containing the letter "y" continues the loop.
        if "y" not in input("Go again? (yes|no): ").lower():
            break


if __name__ == "__main__":
    main()
|
data.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
from os import path
|
4 |
+
from copy import deepcopy
|
5 |
+
|
6 |
+
from datasets import load_dataset
|
7 |
+
|
8 |
+
import tiktoken
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from torch import Tensor
|
15 |
+
from torch.utils.data import IterableDataset, Dataset
|
16 |
+
from torch.nn.utils.rnn import pad_sequence
|
17 |
+
|
18 |
+
from tqdm import tqdm
|
19 |
+
|
20 |
+
|
21 |
+
class Openwebtext(IterableDataset):
    """Iterable dataset of randomly positioned token windows over a token file.

    On construction, if either split file is missing under `root_path`, the
    dataset named by `DATASET_NAME` is loaded with `load_dataset`, split into
    train/test, tokenized with the `ENCODING` tokenizer and written to one
    flat uint16 binary file per split. The requested split is then opened as
    a read-only memory map, and every iteration yields `samples_per_epoch`
    random `(x, y)` windows where `y` is `x` shifted forward by one token.
    """

    DATASET_NAME = "openwebtext"

    FILE_PREFIX = DATASET_NAME

    TRAIN_FILENAME = f"{FILE_PREFIX}-train.bin"
    TEST_FILENAME = f"{FILE_PREFIX}-test.bin"

    TEST_SPLIT_PROPORTION = 0.005
    NUM_SHARDS = 1024

    ENCODING = "r50k_base"

    PADDING_INDEX = -100

    def __init__(
        self,
        root_path: str,
        train: bool = True,
        tokens_per_sample: int = 1024,
        samples_per_epoch: int = 4096,
        num_processes: int = 8,
    ):
        """Prepare the token files if needed and memory-map the chosen split.

        Args:
            root_path: Directory that holds (or will hold) the split files.
            train: Read the training split when True, else the testing split.
            tokens_per_sample: Number of tokens in every yielded window.
            samples_per_epoch: Number of windows yielded per iteration pass.
            num_processes: Process count passed to dataset loading and mapping.

        Raises:
            ValueError: If `tokens_per_sample` or `samples_per_epoch` is less
                than 1, or if the split file holds too few tokens to cut a
                single window plus its shifted label.
        """
        super().__init__()

        if tokens_per_sample < 1:
            raise ValueError("Tokens per sample must be greater than 0.")

        if samples_per_epoch < 1:
            raise ValueError("Samples per epoch must be greater than 0.")

        train_path = path.join(root_path, self.TRAIN_FILENAME)
        test_path = path.join(root_path, self.TEST_FILENAME)

        self.tokenizer = tiktoken.get_encoding(self.ENCODING)

        if not path.exists(train_path) or not path.exists(test_path):
            tokenized_splits = (
                load_dataset(self.DATASET_NAME, num_proc=num_processes, split="train")
                .train_test_split(test_size=self.TEST_SPLIT_PROPORTION, shuffle=True)
                .map(
                    self.tokenize,
                    desc="Tokenizing",
                    remove_columns=["text"],
                    num_proc=num_processes,
                )
            )

            for split, dataset in tokenized_splits.items():
                bin_path = path.join(root_path, f"{self.FILE_PREFIX}-{split}.bin")

                # Sum in uint64, then hand np.memmap a plain Python int so the
                # shape is an ordinary integer rather than a NumPy scalar.
                total_length = int(np.sum(dataset["length"], dtype=np.uint64))

                bin_out = np.memmap(
                    bin_path, dtype=np.uint16, mode="w+", shape=total_length
                )

                index = 0

                # Write shard by shard so only one shard's tokens are
                # concatenated in memory at a time.
                for i in tqdm(range(self.NUM_SHARDS), desc="Writing"):
                    batch = dataset.shard(
                        num_shards=self.NUM_SHARDS, index=i, contiguous=True
                    ).with_format("numpy")

                    token_batch = np.concatenate(batch["tokens"])

                    n = len(token_batch)

                    bin_out[index : index + n] = token_batch

                    index += n

                bin_out.flush()

        bin_file_path = path.join(
            root_path, self.TRAIN_FILENAME if train else self.TEST_FILENAME
        )

        memmap = np.memmap(bin_file_path, dtype=np.uint16, mode="r")

        self.memmap = memmap
        self.max_start = self._max_start(len(memmap), tokens_per_sample)
        self.tokens_per_sample = tokens_per_sample
        self.samples_per_epoch = samples_per_epoch

    @staticmethod
    def _max_start(num_tokens: int, tokens_per_sample: int) -> int:
        """Return the largest start offset that still leaves a full window.

        One extra token is reserved past the window because the label window
        is the sample window shifted forward by one position.

        Raises:
            ValueError: If `num_tokens` is smaller than `tokens_per_sample + 1`.
        """
        max_start = num_tokens - (tokens_per_sample + 1)

        if max_start < 0:
            raise ValueError(
                f"Dataset has {num_tokens} tokens but at least "
                f"{tokens_per_sample + 1} are required to draw a sample."
            )

        return max_start

    @property
    def vocabulary_size(self) -> int:
        """One more than the tokenizer's highest token value."""
        return self.tokenizer.max_token_value + 1

    @property
    def eos_index(self) -> int:
        """The tokenizer's end-of-text token, appended to every document."""
        return self.tokenizer.eot_token

    def tokenize(self, sample: dict) -> dict:
        """Encode `sample["text"]`, append the end-of-text token, and return
        the tokens together with their count."""
        tokens = self.tokenizer.encode_ordinary(sample["text"])

        tokens.append(self.tokenizer.eot_token)

        return {
            "tokens": tokens,
            "length": len(tokens),
        }

    def __iter__(self):
        """Yield `samples_per_epoch` random `(x, y)` int64 windows, where `y`
        is `x` shifted forward by one token."""
        for _ in range(self.samples_per_epoch):
            # random.randint is inclusive on both ends; max_start is chosen so
            # that the shifted label window still fits inside the file.
            start = random.randint(0, self.max_start)
            end = start + self.tokens_per_sample

            x = self.memmap[start:end]
            y = self.memmap[start + 1 : end + 1]

            # Tokens are stored as uint16; widen them to int64 before yielding.
            x = x.astype(np.int64)
            y = y.astype(np.int64)

            assert x.shape == y.shape, "Sample / label shape mismatch."

            yield x, y
|
138 |
+
|
139 |
+
|
140 |
+
class Alpaca(Dataset):
    """Map-style dataset of prompt/response token pairs built from the rows
    of the dataset named by `DATASET_NAME`.

    Every item is an `(x, y)` pair of int64 tensors where `y` holds the
    labels for `x` shifted forward by one position. When `mask_input` is set,
    the prompt part of the labels is filled with `PADDING_INDEX`.
    """

    DATASET_NAME = "tatsu-lab/alpaca"

    ENCODING = "r50k_base"

    PADDING_INDEX = -100

    PROMPT_TEMPLATE = (
        "Below is an instruction that describes a task. Write a response that "
        "appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Response:\n"
    )

    PROMPT_TEMPLATE_WITH_INPUT = (
        "Below is an instruction that describes a task, paired with an input "
        "that provides further context. Write a response that appropriately "
        "completes the request.\n\n"
        "### Input:\n{input}\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Response:\n"
    )

    RESPONSE_TEMPLATE = "{output}"

    def __init__(self, max_tokens_per_sample: int = 1024, mask_input: bool = True):
        """Load the training split and the tokenizer.

        Args:
            max_tokens_per_sample: Upper bound on the length of each sample.
            mask_input: Replace the prompt labels with `PADDING_INDEX` if True.

        Raises:
            ValueError: If `max_tokens_per_sample` is less than 1.
        """
        super().__init__()

        if max_tokens_per_sample < 1:
            raise ValueError(
                f"Max tokens per sample must be greater than 0, {max_tokens_per_sample} given."
            )

        self.dataset = load_dataset(self.DATASET_NAME, split="train")

        self.tokenizer = tiktoken.get_encoding(self.ENCODING)

        self.mask_input = mask_input
        self.max_tokens_per_sample = max_tokens_per_sample

    @property
    def vocabulary_size(self) -> int:
        """One more than the tokenizer's highest token value."""
        return self.tokenizer.max_token_value + 1

    @property
    def eos_index(self) -> int:
        """The tokenizer's end-of-text token."""
        return self.tokenizer.eot_token

    def collate(self, batch: list) -> tuple[Tensor, Tensor]:
        """Custom collate function adds left padding to batched samples."""

        samples = [x for x, _ in batch]
        targets = [y for _, y in batch]

        # Both tensors are padded the same way, so share the keyword arguments.
        pad_options = {
            "batch_first": True,
            "padding_value": self.PADDING_INDEX,
            "padding_side": "left",
        }

        x = pad_sequence(samples, **pad_options)
        y = pad_sequence(targets, **pad_options)

        assert x.shape == y.shape, "Sample / label batch shape mismatch."

        return x, y

    def __getitem__(self, index: int):
        """Return the `(x, y)` tensor pair for the row at `index`."""
        record = self.dataset[index]

        # Rows with a non-empty "input" field use the longer template.
        if len(record["input"]) > 0:
            prompt = self.PROMPT_TEMPLATE_WITH_INPUT.format(
                input=record["input"], instruction=record["instruction"]
            )
        else:
            prompt = self.PROMPT_TEMPLATE.format(instruction=record["instruction"])

        prompt_tokens = self.tokenizer.encode_ordinary(prompt)

        response = self.RESPONSE_TEMPLATE.format(output=record["output"])

        response_tokens = self.tokenizer.encode_ordinary(response)
        response_tokens.append(self.tokenizer.eot_token)

        if self.mask_input:
            prompt_labels = [self.PADDING_INDEX] * len(prompt_tokens)
        else:
            prompt_labels = prompt_tokens

        # List concatenation builds fresh lists, leaving the token lists intact.
        sample = prompt_tokens + response_tokens
        labels = prompt_labels + response_tokens

        # One token beyond the limit is kept so the shifted labels line up.
        end = min(len(sample), self.max_tokens_per_sample + 1)

        x = torch.tensor(sample[: end - 1], dtype=torch.int64)
        y = torch.tensor(labels[1:end], dtype=torch.int64)

        assert x.shape == y.shape, "Sample / label shape mismatch."

        return x, y

    def __len__(self):
        """Number of rows in the underlying dataset."""
        return len(self.dataset)
|
dataset/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
generate.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
from os import path
|
4 |
+
from argparse import ArgumentParser
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from torch.cuda import is_available as cuda_is_available
|
9 |
+
|
10 |
+
from model import GPT, GPTWithLoRA
|
11 |
+
from data import Alpaca
|
12 |
+
|
13 |
+
import tiktoken
|
14 |
+
|
15 |
+
|
16 |
+
def main():
    """Interactively stream generated text for prompts read from stdin.

    Parses the command-line arguments, loads the model checkpoint and,
    when `--lora_path` is given, wraps the model in `GPTWithLoRA`, loads the
    LoRA checkpoint and calls `merge_lora_parameters()`. It then loops:
    read a prompt, print every token yielded by `model.generate` as soon as
    it arrives, and ask whether to go again.

    Raises:
        RuntimeError: If a CUDA device is requested but CUDA is unavailable.
    """
    parser = ArgumentParser(
        description="Generate text from the model given a prompt.",
    )

    parser.add_argument("--checkpoint_path", default="./out/checkpoint.pt", type=str)
    parser.add_argument("--lora_path", default=None, type=str)
    parser.add_argument("--max_tokens", default=1000, type=int)
    parser.add_argument("--temperature", default=1.0, type=float)
    parser.add_argument("--top_k", default=500, type=int)
    parser.add_argument("--top_p", default=0.9, type=float)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--seed", default=None, type=int)

    args = parser.parse_args()

    if "cuda" in args.device and not cuda_is_available():
        raise RuntimeError("Cuda is not available.")

    torch.set_float32_matmul_precision("high")

    # Compare against None explicitly so that `--seed 0` is honoured; a plain
    # truthiness test would silently skip seeding for a seed of zero.
    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    tokenizer = tiktoken.get_encoding(Alpaca.ENCODING)

    checkpoint = torch.load(
        args.checkpoint_path, map_location=args.device, weights_only=True
    )

    model = GPT(**checkpoint["model_args"])

    # The model is compiled before the weights are loaded, matching the order
    # used by the original script.
    model = torch.compile(model)

    model.load_state_dict(checkpoint["model"])

    print("Model checkpoint loaded")

    if args.lora_path:
        checkpoint = torch.load(
            args.lora_path, map_location=args.device, weights_only=True
        )

        model = GPTWithLoRA(model, **checkpoint["lora_args"])

        model = torch.compile(model)

        # strict=False because the LoRA checkpoint is loaded on top of the
        # already-initialised wrapped model.
        # NOTE(review): presumably it holds only the LoRA weights - confirm.
        model.load_state_dict(checkpoint["lora"], strict=False)

        model.merge_lora_parameters()

        print("LoRA checkpoint loaded")

    model.to(args.device)

    model.eval()

    while True:
        prompt = input("Enter a prompt: ")

        # With a LoRA checkpoint the raw prompt is wrapped in the Alpaca
        # instruction template before it is tokenized.
        if args.lora_path:
            prompt = Alpaca.PROMPT_TEMPLATE.format(instruction=prompt)

        prompt = tokenizer.encode_ordinary(prompt)

        prompt = torch.tensor(prompt, dtype=torch.int64, device=args.device)

        for token in model.generate(
            prompt, args.max_tokens, args.temperature, args.top_k, args.top_p
        ):
            # A single token may decode to a partial UTF-8 sequence, so
            # undecodable bytes are replaced rather than raising.
            out = tokenizer.decode_single_token_bytes(token).decode(
                "utf-8", errors="replace"
            )

            # Flush on every token so the text streams to the terminal.
            print(out, end="", flush=True)

        print("\n")

        # Any answer containing the letter "y" continues the loop.
        if "y" not in input("Go again? (yes|no): ").lower():
            break


if __name__ == "__main__":
    main()
|
instruction-tune.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
from argparse import ArgumentParser
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from torch.optim import Adafactor
|
9 |
+
from torch.amp import autocast
|
10 |
+
from torch.cuda import is_available as cuda_is_available, is_bf16_supported
|
11 |
+
from torch.utils.data import random_split
|
12 |
+
|
13 |
+
from torchmetrics.text import Perplexity
|
14 |
+
|
15 |
+
from model import GPT, GPTWithLoRA
|
16 |
+
from data import Alpaca
|
17 |
+
|
18 |
+
import tiktoken
|
19 |
+
|
20 |
+
from tqdm import tqdm
|
21 |
+
|
22 |
+
|
23 |
+
def _parse_bool(value: str) -> bool:
    """Parse a command-line boolean such as "true", "no" or "0".

    Used as an argparse `type=` callable because `type=bool` treats every
    non-empty string (including "False") as True. The empty string maps to
    False so the only spelling that used to disable the flag keeps working.

    Raises:
        ValueError: if the text is not a recognised boolean; argparse turns
            this into a regular usage error.
    """

    normalized = value.strip().lower()

    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True

    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False

    raise ValueError(f"Expected a boolean value, {value!r} given.")


def main():
    """
    Instruction-tune a pre-trained GPT checkpoint with LoRA adapters.

    Loads the base model from `--base_model_path`, wraps it in `GPTWithLoRA`,
    trains on a 90/10 random split of the `Alpaca` dataset using gradient
    accumulation, reports test perplexity every `--eval_interval` epochs and
    writes the LoRA weights and optimizer state to `--checkpoint_path` every
    `--checkpoint_interval` epochs.
    """

    parser = ArgumentParser(description="Instruction-tune the foundation model.")

    parser.add_argument("--base_model_path", default="./out/checkpoint.pt", type=str)
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--gradient_accumulation_steps", default=128, type=int)
    parser.add_argument("--learning_rate", default=1e-2, type=float)
    parser.add_argument("--mask_input", default=True, type=_parse_bool)
    parser.add_argument("--rank", default=8, type=int)
    parser.add_argument("--alpha", default=1.0, type=float)
    parser.add_argument("--dropout", default=0.05, type=float)
    parser.add_argument("--num_epochs", default=4, type=int)
    parser.add_argument("--eval_interval", default=1, type=int)
    parser.add_argument("--checkpoint_interval", default=1, type=int)
    parser.add_argument(
        "--checkpoint_path", default="./out/lora_instruction.pt", type=str
    )
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--seed", default=None, type=int)

    args = parser.parse_args()

    if "cuda" in args.device and not cuda_is_available():
        raise RuntimeError("Cuda is not available.")

    torch.set_float32_matmul_precision("high")

    dtype = (
        torch.bfloat16
        if "cuda" in args.device and is_bf16_supported()
        else torch.float32
    )

    # autocast wants the bare device type ("cuda"), not an indexed device
    # string such as "cuda:0".
    forward_context = autocast(device_type=torch.device(args.device).type, dtype=dtype)

    # Compare against None so that a seed of 0 is honoured.
    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    checkpoint = torch.load(
        args.base_model_path, map_location=args.device, weights_only=True
    )

    model_args = checkpoint["model_args"]

    dataset = Alpaca(model_args["block_size"], args.mask_input)

    training, testing = random_split(dataset, (0.9, 0.1))

    train_loader = DataLoader(
        training,
        collate_fn=dataset.collate,
        batch_size=args.batch_size,
        pin_memory="cpu" not in args.device,
        shuffle=True,
    )
    test_loader = DataLoader(
        testing,
        collate_fn=dataset.collate,
        batch_size=args.batch_size,
        pin_memory="cpu" not in args.device,
        shuffle=False,
    )

    model = GPT(**model_args)

    # Compiled before loading the weights.
    # NOTE(review): presumably the base checkpoint was saved from a compiled
    # model, so its keys only match after compiling - confirm against pre-train.py.
    model = torch.compile(model)

    model.load_state_dict(checkpoint["model"])

    print("Model checkpoint loaded")

    lora_args = {
        "rank": args.rank,
        "alpha": args.alpha,
        "dropout": args.dropout,
    }

    model = GPTWithLoRA(model, **lora_args).to(args.device)

    print("Compiling model")
    model.compile()

    print(f"Model has {model.num_trainable_params:,} trainable parameters")

    optimizer = Adafactor(model.parameters(), lr=args.learning_rate)

    perplexity_metric = Perplexity(ignore_index=dataset.PADDING_INDEX).to(args.device)

    starting_epoch = 1

    if args.resume:
        checkpoint = torch.load(
            args.checkpoint_path, map_location=args.device, weights_only=True
        )

        # Only the LoRA tensors are stored, hence strict=False.
        model.load_state_dict(checkpoint["lora"], strict=False)
        optimizer.load_state_dict(checkpoint["optimizer"])
        starting_epoch += checkpoint["epoch"]

        print("Previous checkpoint resumed successfully")

    model.train()

    print("Instruction-tuning ...")

    for epoch in range(starting_epoch, args.num_epochs + 1):
        total_cross_entropy, total_batches = 0.0, 0

        for step, (x, y) in enumerate(
            tqdm(train_loader, desc=f"Epoch {epoch}", leave=False), start=1
        ):
            x = x.to(args.device, non_blocking=True)
            y = y.to(args.device, non_blocking=True)

            with forward_context:
                y_pred, loss = model(x, y)

            # Scale so the accumulated gradient is the mean over the
            # accumulation window rather than the sum.
            scaled_loss = loss / args.gradient_accumulation_steps

            scaled_loss.backward()

            # One loss value is added per batch, so batches are counted per
            # batch as well; otherwise the reported average is inflated.
            total_cross_entropy += loss.item()
            total_batches += 1

            if step % args.gradient_accumulation_steps == 0:
                optimizer.step()

                optimizer.zero_grad(set_to_none=True)

        # Guard against an empty training loader.
        average_cross_entropy = total_cross_entropy / max(total_batches, 1)

        print(
            f"Epoch {epoch}: Cross Entropy: {average_cross_entropy:.5f}",
        )

        if epoch % args.eval_interval == 0:
            model.eval()

            for x, y in tqdm(test_loader, desc="Testing", leave=False):
                x = x.to(args.device, non_blocking=True)
                y = y.to(args.device, non_blocking=True)

                with torch.no_grad():
                    y_pred, _ = model(x)

                perplexity_metric.update(y_pred, y)

            perplexity = perplexity_metric.compute()

            print(f"Perplexity: {perplexity:.3f}")

            perplexity_metric.reset()

            model.train()

        if epoch % args.checkpoint_interval == 0:
            checkpoint = {
                "epoch": epoch,
                "lora_args": lora_args,
                "lora": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }

            torch.save(checkpoint, args.checkpoint_path)

            print("Checkpoint saved")

    print("Done!")


if __name__ == "__main__":
    main()
|
model.py
ADDED
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import sqrt, exp
|
2 |
+
from dataclasses import dataclass, field
|
3 |
+
from functools import partial
|
4 |
+
from typing import Iterator, Self
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from torch import Tensor
|
9 |
+
from torch.nn import (
|
10 |
+
Module,
|
11 |
+
ModuleList,
|
12 |
+
Sequential,
|
13 |
+
Embedding,
|
14 |
+
MultiheadAttention,
|
15 |
+
Linear,
|
16 |
+
RMSNorm,
|
17 |
+
GELU,
|
18 |
+
Dropout1d,
|
19 |
+
CrossEntropyLoss,
|
20 |
+
Parameter,
|
21 |
+
Buffer,
|
22 |
+
)
|
23 |
+
|
24 |
+
from torch.nn.functional import softmax, log_softmax
|
25 |
+
from torch.nn.utils.parametrize import register_parametrization, remove_parametrizations
|
26 |
+
from torch.utils.checkpoint import checkpoint
|
27 |
+
|
28 |
+
|
29 |
+
class GPT(Module):
|
30 |
+
"""A generative pre-trained transformer."""
|
31 |
+
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
block_size: int = 1024,
|
35 |
+
embedding_dimensions: int = 1024,
|
36 |
+
num_heads: int = 16,
|
37 |
+
num_layers: int = 24,
|
38 |
+
dropout: float = 0.1,
|
39 |
+
activation_checkpointing: bool = False,
|
40 |
+
vocabulary_size: int = 50257,
|
41 |
+
padding_index: int = -100,
|
42 |
+
eos_index: int = 50256,
|
43 |
+
):
|
44 |
+
super().__init__()
|
45 |
+
|
46 |
+
if vocabulary_size <= 0:
|
47 |
+
raise ValueError(
|
48 |
+
f"Vocabulary size must be greater than 0, {vocabulary_size} given."
|
49 |
+
)
|
50 |
+
|
51 |
+
if num_layers <= 0:
|
52 |
+
raise ValueError(f"Num layers must be greater than 0, {num_layers} given.")
|
53 |
+
|
54 |
+
token_embeddings = Embedding(
|
55 |
+
vocabulary_size, embedding_dimensions, padding_idx=padding_index
|
56 |
+
)
|
57 |
+
|
58 |
+
output_layer = Linear(embedding_dimensions, vocabulary_size, bias=False)
|
59 |
+
|
60 |
+
token_embeddings.weight = output_layer.weight # Tie weights
|
61 |
+
|
62 |
+
self.token_embeddings = token_embeddings
|
63 |
+
|
64 |
+
causal_mask = torch.full((block_size, block_size), float("-inf"))
|
65 |
+
causal_mask = torch.triu(causal_mask, diagonal=1)
|
66 |
+
|
67 |
+
self.causal_mask = Buffer(causal_mask, persistent=False)
|
68 |
+
|
69 |
+
self.body = ModuleList(
|
70 |
+
[
|
71 |
+
CausalSelfAttentionBlock(
|
72 |
+
embedding_dimensions, block_size, num_heads, dropout
|
73 |
+
)
|
74 |
+
for _ in range(num_layers)
|
75 |
+
]
|
76 |
+
)
|
77 |
+
|
78 |
+
if activation_checkpointing:
|
79 |
+
self.checkpoint = partial(checkpoint, use_reentrant=False)
|
80 |
+
else:
|
81 |
+
self.checkpoint = lambda layer, x, attention_mask: layer(x, attention_mask)
|
82 |
+
|
83 |
+
self.output_norm = RMSNorm(embedding_dimensions)
|
84 |
+
self.output_layer = output_layer
|
85 |
+
|
86 |
+
self.loss_function = CrossEntropyLoss(ignore_index=padding_index)
|
87 |
+
|
88 |
+
self.vocabulary_size = vocabulary_size
|
89 |
+
self.block_size = block_size
|
90 |
+
self.eos_index = eos_index
|
91 |
+
|
92 |
+
@property
|
93 |
+
def num_trainable_params(self) -> int:
|
94 |
+
return sum(param.numel() for param in self.parameters() if param.requires_grad)
|
95 |
+
|
96 |
+
def forward(
|
97 |
+
self, x: Tensor, y: Tensor | None = None
|
98 |
+
) -> tuple[Tensor, Tensor | None]:
|
99 |
+
z = self.token_embeddings(x)
|
100 |
+
|
101 |
+
b, t = x.size()
|
102 |
+
|
103 |
+
causal_mask = self.causal_mask[:t, :t]
|
104 |
+
|
105 |
+
for layer in self.body:
|
106 |
+
z = self.checkpoint(layer, z, causal_mask)
|
107 |
+
|
108 |
+
z = self.output_norm(z)
|
109 |
+
z = self.output_layer(z)
|
110 |
+
|
111 |
+
if y is not None:
|
112 |
+
y_pred = z.view(-1, z.size(-1))
|
113 |
+
labels = y.view(-1)
|
114 |
+
|
115 |
+
loss = self.loss_function(y_pred, labels)
|
116 |
+
else:
|
117 |
+
loss = None
|
118 |
+
|
119 |
+
return z, loss
|
120 |
+
|
121 |
+
@torch.no_grad()
|
122 |
+
def generate(
|
123 |
+
self,
|
124 |
+
prompt: Tensor,
|
125 |
+
max_tokens: int = 500,
|
126 |
+
temperature: float = 1.0,
|
127 |
+
top_k: int = 500,
|
128 |
+
top_p: float = 0.9,
|
129 |
+
) -> Iterator:
|
130 |
+
"""
|
131 |
+
Given a prompt, sample the next {max_tokens} tokens from the model weighted
|
132 |
+
by their predicted probabilities.
|
133 |
+
"""
|
134 |
+
|
135 |
+
if max_tokens <= 0:
|
136 |
+
raise ValueError(f"Max tokens must be greater than 0, {max_tokens} given.")
|
137 |
+
|
138 |
+
if temperature <= 0:
|
139 |
+
raise ValueError(
|
140 |
+
f"Temperature must be greater than 0, {temperature} given."
|
141 |
+
)
|
142 |
+
|
143 |
+
if top_k <= 0 or top_k > self.vocabulary_size:
|
144 |
+
raise ValueError(
|
145 |
+
f"Top k must be between 1 and {self.vocabulary_size}, {top_k} given."
|
146 |
+
)
|
147 |
+
|
148 |
+
if top_p <= 0.0 or top_p > 1.0:
|
149 |
+
raise ValueError(f"Top p must be between 0 and 1, {top_p} given.")
|
150 |
+
|
151 |
+
context_window = prompt
|
152 |
+
|
153 |
+
for _ in range(max_tokens):
|
154 |
+
context_window = context_window[-self.block_size :]
|
155 |
+
|
156 |
+
y_pred, _ = self.forward(context_window.unsqueeze(0))
|
157 |
+
|
158 |
+
logits = y_pred[0, -1, :]
|
159 |
+
|
160 |
+
logits, indices = torch.topk(logits, top_k, sorted=True)
|
161 |
+
|
162 |
+
probabilities = softmax(logits, dim=0)
|
163 |
+
|
164 |
+
cumulative_probability_mass = torch.cumsum(probabilities, dim=0)
|
165 |
+
|
166 |
+
min_probability_mass = cumulative_probability_mass[0]
|
167 |
+
|
168 |
+
threshold_p = max(top_p, min_probability_mass.item())
|
169 |
+
|
170 |
+
selected_indices = cumulative_probability_mass <= threshold_p
|
171 |
+
|
172 |
+
logits = logits[selected_indices]
|
173 |
+
indices = indices[selected_indices]
|
174 |
+
|
175 |
+
logits /= temperature
|
176 |
+
|
177 |
+
probabilities = softmax(logits, dim=0)
|
178 |
+
|
179 |
+
offset = torch.multinomial(probabilities, num_samples=1).squeeze(0)
|
180 |
+
|
181 |
+
next_token = indices[offset]
|
182 |
+
|
183 |
+
if next_token == self.eos_index:
|
184 |
+
break
|
185 |
+
|
186 |
+
yield next_token
|
187 |
+
|
188 |
+
context_window = torch.cat((context_window, next_token.unsqueeze(0)))
|
189 |
+
|
190 |
+
@torch.no_grad()
|
191 |
+
def beam_search(
|
192 |
+
self,
|
193 |
+
prompt: Tensor,
|
194 |
+
max_tokens: int = 200,
|
195 |
+
num_candidates: int = 3,
|
196 |
+
beam_width: int = 16,
|
197 |
+
) -> list:
|
198 |
+
"""
|
199 |
+
Given a prompt, return the {num_candidates} highest probability sequences.
|
200 |
+
"""
|
201 |
+
|
202 |
+
if max_tokens <= 0:
|
203 |
+
raise ValueError(f"Max tokens must be greater than 0, {max_tokens} given.")
|
204 |
+
|
205 |
+
if num_candidates <= 0:
|
206 |
+
raise ValueError(
|
207 |
+
f"Num candidates must be greater than 0, {num_candidates} given."
|
208 |
+
)
|
209 |
+
|
210 |
+
if beam_width <= 0:
|
211 |
+
raise ValueError(f"Beam width must be greater than 0, {beam_width} given.")
|
212 |
+
|
213 |
+
@dataclass(order=True)
|
214 |
+
class Candidate:
|
215 |
+
log_probability: float
|
216 |
+
tokens: Tensor
|
217 |
+
|
218 |
+
@property
|
219 |
+
def priority(self) -> float:
|
220 |
+
return self.log_probability
|
221 |
+
|
222 |
+
sort_candidates = partial(
|
223 |
+
sorted,
|
224 |
+
key=lambda candidate: candidate.priority,
|
225 |
+
reverse=True,
|
226 |
+
)
|
227 |
+
|
228 |
+
candidates, completed = [], []
|
229 |
+
|
230 |
+
tokens = torch.tensor([], dtype=prompt.dtype).to(prompt.device)
|
231 |
+
|
232 |
+
candidates.append(Candidate(0.0, tokens))
|
233 |
+
|
234 |
+
while len(candidates) > 0:
|
235 |
+
candidate = candidates.pop()
|
236 |
+
|
237 |
+
if len(completed) >= num_candidates:
|
238 |
+
completed = sort_candidates(completed)
|
239 |
+
|
240 |
+
completed = completed[:num_candidates]
|
241 |
+
|
242 |
+
worst_candidate = completed[-1]
|
243 |
+
|
244 |
+
if candidate.log_probability < worst_candidate.log_probability:
|
245 |
+
break
|
246 |
+
|
247 |
+
if len(candidate.tokens) > 0 and candidate.tokens[-1] == self.eos_index:
|
248 |
+
candidate.tokens = candidate.tokens[:-1]
|
249 |
+
|
250 |
+
completed.append(candidate)
|
251 |
+
|
252 |
+
continue
|
253 |
+
|
254 |
+
if len(candidate.tokens) >= max_tokens:
|
255 |
+
completed.append(candidate)
|
256 |
+
|
257 |
+
continue
|
258 |
+
|
259 |
+
context_window = torch.cat((prompt, candidate.tokens))
|
260 |
+
|
261 |
+
context_window = context_window[-self.block_size :]
|
262 |
+
|
263 |
+
y_pred, _ = self.forward(context_window.unsqueeze(0))
|
264 |
+
|
265 |
+
logits = y_pred[0, -1, :]
|
266 |
+
|
267 |
+
logits, indices = torch.topk(logits, beam_width, sorted=False)
|
268 |
+
|
269 |
+
log_probabilities = log_softmax(logits, dim=0)
|
270 |
+
|
271 |
+
for log_probability, index in zip(log_probabilities, indices):
|
272 |
+
log_probability = candidate.log_probability + log_probability
|
273 |
+
|
274 |
+
tokens = torch.cat((candidate.tokens, index.unsqueeze(0)))
|
275 |
+
|
276 |
+
candidates.append(Candidate(log_probability, tokens))
|
277 |
+
|
278 |
+
candidates = sort_candidates(candidates)
|
279 |
+
|
280 |
+
candidates = candidates[:beam_width]
|
281 |
+
|
282 |
+
return completed
|
283 |
+
|
284 |
+
|
285 |
+
class GPTWithLoRA(Module):
|
286 |
+
"""
|
287 |
+
A wrapper for pre-trained GPT models that applies a LoRA reparameterization
|
288 |
+
to the intermediate layers of the network.
|
289 |
+
"""
|
290 |
+
|
291 |
+
def __init__(
|
292 |
+
self, model: GPT, rank: int = 8, alpha: float = 1.0, dropout: float = 0.05
|
293 |
+
):
|
294 |
+
super().__init__()
|
295 |
+
|
296 |
+
if rank <= 0:
|
297 |
+
raise ValueError(f"Rank must be greater than 0, {rank} given.")
|
298 |
+
|
299 |
+
if alpha <= 0.0:
|
300 |
+
raise ValueError(f"Alpha must be greater than 0, {alpha} given.")
|
301 |
+
|
302 |
+
for param in model.parameters():
|
303 |
+
param.requires_grad = False
|
304 |
+
|
305 |
+
for module in model.body:
|
306 |
+
out_features, in_features = module.attention.in_proj_weight.shape
|
307 |
+
|
308 |
+
register_parametrization(
|
309 |
+
module.attention,
|
310 |
+
"in_proj_weight",
|
311 |
+
LoRA(in_features, out_features, rank, alpha, dropout),
|
312 |
+
)
|
313 |
+
|
314 |
+
out_features, in_features = module.attention.out_proj.weight.shape
|
315 |
+
|
316 |
+
register_parametrization(
|
317 |
+
module.attention.out_proj,
|
318 |
+
"weight",
|
319 |
+
LoRA(in_features, out_features, rank, alpha, dropout),
|
320 |
+
)
|
321 |
+
|
322 |
+
for layer in module.mlp.layers:
|
323 |
+
if isinstance(layer, Linear):
|
324 |
+
register_parametrization(
|
325 |
+
layer,
|
326 |
+
"weight",
|
327 |
+
LoRA.from_linear(layer, rank, alpha, dropout),
|
328 |
+
)
|
329 |
+
|
330 |
+
self.model = model
|
331 |
+
|
332 |
+
@property
|
333 |
+
def num_trainable_params(self) -> int:
|
334 |
+
return self.model.num_trainable_params
|
335 |
+
|
336 |
+
def state_dict(self):
|
337 |
+
return {
|
338 |
+
name: module
|
339 |
+
for name, module in super().state_dict().items()
|
340 |
+
if "lora" in name
|
341 |
+
}
|
342 |
+
|
343 |
+
def merge_lora_parameters(self):
|
344 |
+
"""Merge the LoRA parameters with the original parameters."""
|
345 |
+
|
346 |
+
for module in self.model.modules():
|
347 |
+
if hasattr(module, "parametrizations"):
|
348 |
+
lora_params = [name for name in module.parametrizations.keys()]
|
349 |
+
|
350 |
+
for name in lora_params:
|
351 |
+
remove_parametrizations(module, name, leave_parametrized=True)
|
352 |
+
|
353 |
+
def forward(
|
354 |
+
self, x: Tensor, y: Tensor | None = None
|
355 |
+
) -> tuple[Tensor, Tensor | None]:
|
356 |
+
return self.model.forward(x, y)
|
357 |
+
|
358 |
+
def generate(
|
359 |
+
self,
|
360 |
+
prompt: Tensor,
|
361 |
+
max_tokens: int = 500,
|
362 |
+
temperature: float = 1.0,
|
363 |
+
top_k: int = 500,
|
364 |
+
top_p: float = 0.9,
|
365 |
+
) -> Iterator:
|
366 |
+
return self.model.generate(prompt, max_tokens, temperature, top_k)
|
367 |
+
|
368 |
+
def beam_search(
|
369 |
+
self,
|
370 |
+
prompt: Tensor,
|
371 |
+
max_tokens: int = 200,
|
372 |
+
num_candidates: int = 3,
|
373 |
+
beam_width: int = 16,
|
374 |
+
) -> list:
|
375 |
+
return self.model.beam_search(prompt, max_tokens, num_candidates, beam_width)
|
376 |
+
|
377 |
+
|
378 |
+
class CausalSelfAttentionBlock(Module):
|
379 |
+
"""Causal self-attention block with residual connections."""
|
380 |
+
|
381 |
+
def __init__(
|
382 |
+
self, embedding_dimensions: int, block_size: int, num_heads: int, dropout: float
|
383 |
+
):
|
384 |
+
super().__init__()
|
385 |
+
|
386 |
+
if embedding_dimensions <= 0:
|
387 |
+
raise ValueError(
|
388 |
+
f"Embedding dimensions must be greater than 0, {embedding_dimensions} given."
|
389 |
+
)
|
390 |
+
|
391 |
+
if block_size <= 0:
|
392 |
+
raise ValueError(f"Block size must be greater than 0, {block_size} given.")
|
393 |
+
|
394 |
+
if num_heads <= 0:
|
395 |
+
raise ValueError(f"Num heads must be greater than 0, {num_heads} given.")
|
396 |
+
|
397 |
+
if dropout < 0 or dropout > 1:
|
398 |
+
raise ValueError(f"Dropout must be between 0 and 1, {dropout} given")
|
399 |
+
|
400 |
+
self.norm1 = RMSNorm(embedding_dimensions)
|
401 |
+
self.attention = MultiheadAttention(
|
402 |
+
embedding_dimensions,
|
403 |
+
num_heads,
|
404 |
+
batch_first=True,
|
405 |
+
dropout=dropout,
|
406 |
+
bias=False,
|
407 |
+
)
|
408 |
+
|
409 |
+
self.norm2 = RMSNorm(embedding_dimensions)
|
410 |
+
self.mlp = MLP(embedding_dimensions, 4 * embedding_dimensions, dropout)
|
411 |
+
|
412 |
+
def forward(self, x: Tensor, attention_mask: Tensor) -> Tensor:
|
413 |
+
z = self.norm1(x)
|
414 |
+
z, _ = self.attention(z, z, z, attn_mask=attention_mask, is_causal=True)
|
415 |
+
|
416 |
+
z = x + z # Residual connection
|
417 |
+
|
418 |
+
x = z
|
419 |
+
|
420 |
+
z = self.norm2(x)
|
421 |
+
z = self.mlp(z)
|
422 |
+
|
423 |
+
z = x + z # Residual connection
|
424 |
+
|
425 |
+
return z
|
426 |
+
|
427 |
+
|
428 |
+
class MLP(Module):
    """
    A two-layer fully-connected network with dropout.

    The layers are bias-free linear maps with a GELU in between; dropout is
    applied to the output of the second layer.
    """

    def __init__(
        self, embedding_dimensions: int, hidden_dimensions: int, dropout: float
    ):
        """
        Args:
            embedding_dimensions: size of the input and output features.
            hidden_dimensions: size of the intermediate features.
            dropout: probability handed to the `Dropout1d` layer.

        Raises:
            ValueError: if either dimension is not positive.
        """

        super().__init__()

        checked_sizes = (
            ("Embedding dimensions", embedding_dimensions),
            ("Hidden dimensions", hidden_dimensions),
        )

        for label, size in checked_sizes:
            if size <= 0:
                raise ValueError(f"{label} must be greater than 0, {size} given.")

        expand = Linear(embedding_dimensions, hidden_dimensions, bias=False)
        contract = Linear(hidden_dimensions, embedding_dimensions, bias=False)

        # Positions 0 and 2 hold the linear layers.
        self.layers = Sequential(expand, GELU(), contract)

        # NOTE(review): Dropout1d drops whole channels rather than individual
        # elements - confirm this is preferred over plain Dropout here.
        self.dropout = Dropout1d(p=dropout)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the two linear layers, then dropout."""

        hidden = self.layers(x)

        return self.dropout(hidden)
|
456 |
+
|
457 |
+
|
458 |
+
class LoRA(Module):
|
459 |
+
"""Rank decomposition transformation."""
|
460 |
+
|
461 |
+
@classmethod
|
462 |
+
def from_linear(
|
463 |
+
cls, linear: Linear, rank: int, alpha: float, dropout: float
|
464 |
+
) -> Self:
|
465 |
+
out_features, in_features = linear.weight.shape
|
466 |
+
|
467 |
+
return cls(in_features, out_features, rank, alpha, dropout)
|
468 |
+
|
469 |
+
def __init__(
|
470 |
+
self,
|
471 |
+
in_features: int,
|
472 |
+
out_features: int,
|
473 |
+
rank: int,
|
474 |
+
alpha: float,
|
475 |
+
dropout: float,
|
476 |
+
):
|
477 |
+
super().__init__()
|
478 |
+
|
479 |
+
if rank <= 0:
|
480 |
+
raise ValueError(f"Rank must be greater than 0, {rank} given.")
|
481 |
+
|
482 |
+
if alpha <= 0.0:
|
483 |
+
raise ValueError(f"Alpha must be greater than 0, {alpha} given.")
|
484 |
+
|
485 |
+
std_dev = 1.0 / sqrt(rank)
|
486 |
+
|
487 |
+
self.lora_a = Parameter(torch.randn(rank, in_features) * std_dev)
|
488 |
+
self.lora_b = Parameter(torch.zeros(out_features, rank))
|
489 |
+
|
490 |
+
self.dropout = Dropout1d(p=dropout)
|
491 |
+
|
492 |
+
self.alpha = alpha
|
493 |
+
|
494 |
+
def forward(self, x: Tensor) -> Tensor:
|
495 |
+
z = self.lora_b @ self.dropout(self.lora_a)
|
496 |
+
|
497 |
+
z *= self.alpha
|
498 |
+
|
499 |
+
return x + z
|
model_sizing.ipynb
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 63,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"block_size = 1024\n",
|
10 |
+
"vocabulary_size = 50257\n",
|
11 |
+
"embedding_dimensions = 1024\n",
|
12 |
+
"num_hidden_layers = 32"
|
13 |
+
]
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"cell_type": "markdown",
|
17 |
+
"metadata": {},
|
18 |
+
"source": [
|
19 |
+
"First, we'll estimate the total number of parameters in the network."
|
20 |
+
]
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"cell_type": "code",
|
24 |
+
"execution_count": 64,
|
25 |
+
"metadata": {},
|
26 |
+
"outputs": [
|
27 |
+
{
|
28 |
+
"data": {
|
29 |
+
"image/png": "",
|
30 |
+
"text/plain": [
|
31 |
+
"<Figure size 640x480 with 1 Axes>"
|
32 |
+
]
|
33 |
+
},
|
34 |
+
"metadata": {},
|
35 |
+
"output_type": "display_data"
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"name": "stdout",
|
39 |
+
"output_type": "stream",
|
40 |
+
"text": [
|
41 |
+
"Token Embeddings 51,463,168 11.33%\n",
|
42 |
+
"Attention 134,217,728 29.55%\n",
|
43 |
+
"MLP 268,435,456 59.10%\n",
|
44 |
+
"RMS Norm 66,560 0.01%\n",
|
45 |
+
"Output Layer 0 0.00%\n",
|
46 |
+
"\n",
|
47 |
+
"\n",
|
48 |
+
"Total parameters: 454,182,912\n"
|
49 |
+
]
|
50 |
+
}
|
51 |
+
],
|
52 |
+
"source": [
|
53 |
+
"import matplotlib.pyplot as plt\n",
|
54 |
+
"\n",
|
55 |
+
"parameter_counts = {\n",
|
56 |
+
" \"Token Embeddings\": vocabulary_size * embedding_dimensions,\n",
|
57 |
+
" \"Attention\": (embedding_dimensions ** 2 + embedding_dimensions * 3 * embedding_dimensions) * num_hidden_layers,\n",
|
58 |
+
" \"MLP\": embedding_dimensions * 4 * embedding_dimensions * 2 * num_hidden_layers,\n",
|
59 |
+
" \"RMS Norm\": embedding_dimensions * num_hidden_layers * 2 + embedding_dimensions,\n",
|
60 |
+
" \"Output Layer\": 0, # Tied to token embeddings\n",
|
61 |
+
"}\n",
|
62 |
+
"\n",
|
63 |
+
"plt.bar(parameter_counts.keys(), parameter_counts.values())\n",
|
64 |
+
"\n",
|
65 |
+
"plt.title(\"Model Parameters\")\n",
|
66 |
+
"plt.ylabel(\"# of Parameters\")\n",
|
67 |
+
"plt.xticks(rotation=45)\n",
|
68 |
+
"\n",
|
69 |
+
"plt.show()\n",
|
70 |
+
"\n",
|
71 |
+
"total_parameter_count = sum(parameter_counts.values())\n",
|
72 |
+
"\n",
|
73 |
+
"for name, count in parameter_counts.items():\n",
|
74 |
+
" print(f\"{name:20s} {count:20,d} {count / total_parameter_count * 100:10.2f}%\")\n",
|
75 |
+
"\n",
|
76 |
+
"print(\"\\n\")\n",
|
77 |
+
"\n",
|
78 |
+
"print(f\"Total parameters: {total_parameter_count:,}\")"
|
79 |
+
]
|
80 |
+
},
|
81 |
+
{
|
82 |
+
"cell_type": "markdown",
|
83 |
+
"metadata": {},
|
84 |
+
"source": [
|
85 |
+
"Next, we'll estimate the size of the model in memory and on disk. Note that this does not include any intermediate variables that get memorized during training such as activations, gradients, optimizer state, and temporary buffers. Actual memory consumption will likely be much higher."
|
86 |
+
]
|
87 |
+
},
|
88 |
+
{
|
89 |
+
"cell_type": "code",
|
90 |
+
"execution_count": 65,
|
91 |
+
"metadata": {},
|
92 |
+
"outputs": [
|
93 |
+
{
|
94 |
+
"name": "stdout",
|
95 |
+
"output_type": "stream",
|
96 |
+
"text": [
|
97 |
+
"Total gigabytes: 1.82\n"
|
98 |
+
]
|
99 |
+
}
|
100 |
+
],
|
101 |
+
"source": [
|
102 |
+
"bytes_per_parameter = 32 // 8 # Assuming 32-bit floating point\n",
|
103 |
+
"\n",
|
104 |
+
"total_bytes = total_parameter_count * bytes_per_parameter\n",
|
105 |
+
"\n",
|
106 |
+
"total_gigabytes = total_bytes / 1e9\n",
|
107 |
+
"\n",
|
108 |
+
"print(f\"Total gigabytes: {total_gigabytes:,.2f}\")"
|
109 |
+
]
|
110 |
+
},
|
111 |
+
{
|
112 |
+
"cell_type": "markdown",
|
113 |
+
"metadata": {},
|
114 |
+
"source": [
|
115 |
+
"Next, we'll estimate the maximum number of floating point operations (FLOPs) required to perform a full forward pass of the network on a single sample."
|
116 |
+
]
|
117 |
+
},
|
118 |
+
{
|
119 |
+
"cell_type": "code",
|
120 |
+
"execution_count": 66,
|
121 |
+
"metadata": {},
|
122 |
+
"outputs": [
|
123 |
+
{
|
124 |
+
"data": {
|
125 |
+
"image/png": "",
|
126 |
+
"text/plain": [
|
127 |
+
"<Figure size 640x480 with 1 Axes>"
|
128 |
+
]
|
129 |
+
},
|
130 |
+
"metadata": {},
|
131 |
+
"output_type": "display_data"
|
132 |
+
},
|
133 |
+
{
|
134 |
+
"name": "stdout",
|
135 |
+
"output_type": "stream",
|
136 |
+
"text": [
|
137 |
+
"attention 412,316,860,416 38.63%\n",
|
138 |
+
"mlp 549,756,993,536 51.50%\n",
|
139 |
+
"rms_norm 236,544 0.00%\n",
|
140 |
+
"output_layer 105,396,568,064 9.87%\n",
|
141 |
+
"\n",
|
142 |
+
"\n",
|
143 |
+
"Total forward FLOPs: 1,067,470,658,560\n"
|
144 |
+
]
|
145 |
+
}
|
146 |
+
],
|
147 |
+
"source": [
|
148 |
+
"ops_per_matmul = 2 # Multiply + accumulate (MAC)\n",
|
149 |
+
"ops_per_activation = 9 # Assuming GELU\n",
|
150 |
+
"ops_per_rms_norm = 7 # y = (x / sqrt(rms[x] + epsilon)) * gamma\n",
|
151 |
+
"\n",
|
152 |
+
"# K, Q, V projections\n",
|
153 |
+
"attention = ops_per_matmul * block_size * embedding_dimensions * 3 * embedding_dimensions\n",
|
154 |
+
"\n",
|
155 |
+
"# Attention logits\n",
|
156 |
+
"attention += 2 * ops_per_matmul * block_size ** 2 * embedding_dimensions\n",
|
157 |
+
"\n",
|
158 |
+
"# Output projection\n",
|
159 |
+
"attention += ops_per_matmul * block_size * embedding_dimensions ** 2\n",
|
160 |
+
"\n",
|
161 |
+
"attention *= num_hidden_layers\n",
|
162 |
+
"\n",
|
163 |
+
"# Linear transformations\n",
|
164 |
+
"mlp = 2 * ops_per_matmul * block_size * embedding_dimensions * 4 * embedding_dimensions\n",
|
165 |
+
"\n",
|
166 |
+
"# Non-linear activations\n",
|
167 |
+
"mlp += ops_per_activation * 4 * embedding_dimensions\n",
|
168 |
+
"\n",
|
169 |
+
"mlp *= num_hidden_layers\n",
|
170 |
+
"\n",
|
171 |
+
"rms_norm = ops_per_rms_norm * embedding_dimensions * (num_hidden_layers + 1)\n",
|
172 |
+
"\n",
|
173 |
+
"output_layer = ops_per_matmul * block_size * embedding_dimensions * vocabulary_size\n",
|
174 |
+
"\n",
|
175 |
+
"flops = {\n",
|
176 |
+
" \"attention\": attention,\n",
|
177 |
+
" \"mlp\": mlp,\n",
|
178 |
+
" \"rms_norm\": rms_norm,\n",
|
179 |
+
" \"output_layer\": output_layer,\n",
|
180 |
+
"}\n",
|
181 |
+
"\n",
|
182 |
+
"plt.bar(flops.keys(), flops.values())\n",
|
183 |
+
"\n",
|
184 |
+
"plt.title(\"Model Operations\")\n",
|
185 |
+
"plt.ylabel(\"# of FLOPs\")\n",
|
186 |
+
"plt.xticks(rotation=45)\n",
|
187 |
+
"\n",
|
188 |
+
"plt.show()\n",
|
189 |
+
"\n",
|
190 |
+
"total_forward_flops = sum(flops.values())\n",
|
191 |
+
"\n",
|
192 |
+
"for name, count in flops.items():\n",
|
193 |
+
" print(f\"{name:20s} {count:20,d} {count / total_forward_flops * 100:10.2f}%\")\n",
|
194 |
+
"\n",
|
195 |
+
"print(\"\\n\")\n",
|
196 |
+
"\n",
|
197 |
+
"print(f\"Total forward FLOPs: {total_forward_flops:,}\")"
|
198 |
+
]
|
199 |
+
},
|
200 |
+
{
|
201 |
+
"cell_type": "markdown",
|
202 |
+
"metadata": {},
|
203 |
+
"source": [
|
204 |
+
"Next, we'll estimate the number of FLOPs for the backward pass. For this we use a simple heuristic of 2X the forward pass."
|
205 |
+
]
|
206 |
+
},
|
207 |
+
{
|
208 |
+
"cell_type": "code",
|
209 |
+
"execution_count": 67,
|
210 |
+
"metadata": {},
|
211 |
+
"outputs": [
|
212 |
+
{
|
213 |
+
"name": "stdout",
|
214 |
+
"output_type": "stream",
|
215 |
+
"text": [
|
216 |
+
"Total backward FLOPs: 2,134,941,317,120\n"
|
217 |
+
]
|
218 |
+
}
|
219 |
+
],
|
220 |
+
"source": [
|
221 |
+
"total_backward_flops = 2 * total_forward_flops\n",
|
222 |
+
"\n",
|
223 |
+
"print(f\"Total backward FLOPs: {total_backward_flops:,}\")"
|
224 |
+
]
|
225 |
+
},
|
226 |
+
{
|
227 |
+
"cell_type": "markdown",
|
228 |
+
"metadata": {},
|
229 |
+
"source": [
|
230 |
+
"We'll do the same for the total FLOPs per roundtrip."
|
231 |
+
]
|
232 |
+
},
|
233 |
+
{
|
234 |
+
"cell_type": "code",
|
235 |
+
"execution_count": 68,
|
236 |
+
"metadata": {},
|
237 |
+
"outputs": [
|
238 |
+
{
|
239 |
+
"name": "stdout",
|
240 |
+
"output_type": "stream",
|
241 |
+
"text": [
|
242 |
+
"Total roundtrip FLOPs: 3,202,411,975,680\n"
|
243 |
+
]
|
244 |
+
}
|
245 |
+
],
|
246 |
+
"source": [
|
247 |
+
"total_roundtrip_flops = total_forward_flops + total_backward_flops\n",
|
248 |
+
"\n",
|
249 |
+
"print(f\"Total roundtrip FLOPs: {total_roundtrip_flops:,}\")"
|
250 |
+
]
|
251 |
+
},
|
252 |
+
{
|
253 |
+
"cell_type": "markdown",
|
254 |
+
"metadata": {},
|
255 |
+
"source": [
|
256 |
+
"Now, let's estimate how long it would take to train over every sample in the Openwebtext training set at least once in expectation using a few well-known Nvidia GPUs as benchmarks. Note that these results shown here are a best-case scenario and neglect to factor in overhead such as moving data to and from VRAM."
|
257 |
+
]
|
258 |
+
},
|
259 |
+
{
|
260 |
+
"cell_type": "code",
|
261 |
+
"execution_count": 69,
|
262 |
+
"metadata": {},
|
263 |
+
"outputs": [
|
264 |
+
{
|
265 |
+
"name": "stdout",
|
266 |
+
"output_type": "stream",
|
267 |
+
"text": [
|
268 |
+
"Total tokens: 8,994,885,755\n",
|
269 |
+
"Epochs required: 2,145\n",
|
270 |
+
"\n",
|
271 |
+
"RTX A2000: 513.19 seconds/epoch, 12.74 days required\n",
|
272 |
+
"A100 SXM: 52.55 seconds/epoch, 1.30 days required\n",
|
273 |
+
"HGX B100: 1.17 seconds/epoch, 0.03 days required\n"
|
274 |
+
]
|
275 |
+
}
|
276 |
+
],
|
277 |
+
"source": [
|
278 |
+
"RTX_A2000_BF16_FLOPS_PER_SECOND = 63.9e12\n",
|
279 |
+
"A100_SXM_BF16_FLOPS_PER_SECOND = 624.0e12\n",
|
280 |
+
"HGX_B100_BF16_FLOPS_PER_SECOND = 28000e12\n",
|
281 |
+
"\n",
|
282 |
+
"ESTIMATED_FLOPS_UTILIZATION = 0.4\n",
|
283 |
+
"\n",
|
284 |
+
"num_training_tokens = 8994885755\n",
|
285 |
+
"samples_per_epoch = 4096\n",
|
286 |
+
"\n",
|
287 |
+
"num_epochs_required = round(num_training_tokens / (samples_per_epoch * block_size))\n",
|
288 |
+
"\n",
|
289 |
+
"print(f\"Total tokens: {num_training_tokens:,}\")\n",
|
290 |
+
"print(f\"Epochs required: {num_epochs_required:,}\", end=\"\\n\\n\")\n",
|
291 |
+
"\n",
|
292 |
+
"gpus = {\n",
|
293 |
+
" \"RTX A2000\": RTX_A2000_BF16_FLOPS_PER_SECOND,\n",
|
294 |
+
" \"A100 SXM\": A100_SXM_BF16_FLOPS_PER_SECOND,\n",
|
295 |
+
" \"HGX B100\": HGX_B100_BF16_FLOPS_PER_SECOND,\n",
|
296 |
+
"}\n",
|
297 |
+
"\n",
|
298 |
+
"for name, flops_per_second in gpus.items():\n",
|
299 |
+
" flops_per_second *= ESTIMATED_FLOPS_UTILIZATION\n",
|
300 |
+
"\n",
|
301 |
+
" seconds_per_epoch = samples_per_epoch * total_roundtrip_flops / flops_per_second\n",
|
302 |
+
"\n",
|
303 |
+
" days_required = num_epochs_required * seconds_per_epoch / 60 / 60 / 24\n",
|
304 |
+
"\n",
|
305 |
+
" print(f\"{name}: {seconds_per_epoch:.2f} seconds/epoch, {days_required:,.2f} days required\")"
|
306 |
+
]
|
307 |
+
}
|
308 |
+
],
|
309 |
+
"metadata": {
|
310 |
+
"kernelspec": {
|
311 |
+
"display_name": ".venv",
|
312 |
+
"language": "python",
|
313 |
+
"name": "python3"
|
314 |
+
},
|
315 |
+
"language_info": {
|
316 |
+
"codemirror_mode": {
|
317 |
+
"name": "ipython",
|
318 |
+
"version": 3
|
319 |
+
},
|
320 |
+
"file_extension": ".py",
|
321 |
+
"mimetype": "text/x-python",
|
322 |
+
"name": "python",
|
323 |
+
"nbconvert_exporter": "python",
|
324 |
+
"pygments_lexer": "ipython3",
|
325 |
+
"version": "3.12.3"
|
326 |
+
}
|
327 |
+
},
|
328 |
+
"nbformat": 4,
|
329 |
+
"nbformat_minor": 2
|
330 |
+
}
|
models/lightgpt-small.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c9dfe5de2f2668272d38a7b2d29bc904229a359e9e970585eb7457ee4cb1ef8c
|
3 |
+
size 1819529541
|
out/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
pre-train.py
ADDED
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import random
|
3 |
+
import signal
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
from os import path, environ
|
7 |
+
from argparse import ArgumentParser
|
8 |
+
from contextlib import nullcontext
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from torch.utils.data import DataLoader
|
13 |
+
from torch.optim import Adafactor
|
14 |
+
from torch.amp import autocast
|
15 |
+
from torch.cuda import set_device, is_available as cuda_is_available, is_bf16_supported
|
16 |
+
from torch.nn.utils import clip_grad_norm_
|
17 |
+
from torch.distributed import init_process_group, destroy_process_group
|
18 |
+
from torch.distributed.optim import ZeroRedundancyOptimizer
|
19 |
+
from torch.nn.parallel import DistributedDataParallel
|
20 |
+
|
21 |
+
from torchmetrics.text import Perplexity
|
22 |
+
|
23 |
+
from model import GPT
|
24 |
+
from data import Openwebtext
|
25 |
+
|
26 |
+
from tqdm import tqdm
|
27 |
+
|
28 |
+
# Process topology as exported by the distributed launcher. Each value falls
# back to -1 when the corresponding environment variable is not set.
RANK, LOCAL_RANK, WORLD_SIZE = (
    int(environ.get(variable, -1))
    for variable in ("RANK", "LOCAL_RANK", "WORLD_SIZE")
)

# More than one participating process means we run distributed data parallel.
IS_DDP = WORLD_SIZE > 1

# A non-distributed run is always its own master; otherwise rank 0 is.
IS_MASTER = not IS_DDP or RANK == 0

# Communication backend handed to init_process_group (nccl, gloo, etc.).
DDP_BACKEND = "nccl"
|
37 |
+
|
38 |
+
|
39 |
+
def main():
    """Pre-train the GPT on Openwebtext, optionally across several DDP ranks.

    Parses hyper-parameters from the command line, builds the training and
    testing datasets, the model and the optimizer, then runs the training
    loop with gradient accumulation, periodic evaluation on the master rank,
    and periodic checkpointing.

    Raises:
        ValueError: If a command-line argument is out of range.
        RuntimeError: If a CUDA device is requested but CUDA is unavailable.
    """
    parser = ArgumentParser(description="Pre-train the GPT.")

    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--gradient_accumulation_steps", default=128, type=int)
    parser.add_argument("--samples_per_epoch", default=4096, type=int)
    parser.add_argument("--learning_rate", default=1e-2, type=float)
    parser.add_argument("--max_gradient_norm", default=1.0, type=float)
    parser.add_argument("--dropout", default=0.1, type=float)
    parser.add_argument("--num_epochs", default=2140, type=int)
    parser.add_argument("--block_size", default=1024, type=int)
    parser.add_argument("--embedding_dimensions", default=1024, type=int)
    parser.add_argument("--num_attention_heads", default=16, type=int)
    parser.add_argument("--num_hidden_layers", default=32, type=int)
    parser.add_argument("--activation_checkpointing", action="store_true")
    parser.add_argument("--eval_interval", default=10, type=int)
    parser.add_argument("--checkpoint_interval", default=20, type=int)
    parser.add_argument("--checkpoint_path", default="./out/checkpoint.pt", type=str)
    parser.add_argument("--checkpoint_history", action="store_true")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--dataset_path", default="./dataset", type=str)
    parser.add_argument("--num_dataset_processes", default=8, type=int)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--seed", default=None, type=int)

    args = parser.parse_args()

    if args.batch_size < 1:
        raise ValueError(f"Batch size must be greater than 0, {args.batch_size} given.")

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            f"Gradient accumulation steps must be greater than 0, {args.gradient_accumulation_steps} given."
        )

    if args.learning_rate < 0:
        raise ValueError(
            f"Learning rate must be a positive value, {args.learning_rate} given."
        )

    if args.num_epochs < 1:
        raise ValueError(f"Must train for at least 1 epoch, {args.num_epochs} given.")

    if args.eval_interval < 1:
        raise ValueError(
            f"Eval interval must be greater than 0, {args.eval_interval} given."
        )

    if args.checkpoint_interval < 1:
        raise ValueError(
            f"Checkpoint interval must be greater than 0, {args.checkpoint_interval} given."
        )

    if IS_DDP:
        init_process_group(backend=DDP_BACKEND, world_size=WORLD_SIZE)

        # Each process drives the GPU matching its local rank.
        args.device = f"cuda:{LOCAL_RANK}"

        set_device(args.device)

        # Offset the seed by the rank (presumably so that ranks do not all draw
        # the same samples). Compare against None so a seed of 0 is honored.
        if args.seed is not None:
            args.seed += RANK

        if args.gradient_accumulation_steps % WORLD_SIZE != 0:
            warnings.warn(
                "Number of gradient accumulation steps does not "
                "divide evenly into the world size."
            )

        # The accumulation window is split across the participating ranks.
        args.gradient_accumulation_steps //= WORLD_SIZE

        assert (
            args.gradient_accumulation_steps > 0
        ), "World size is larger than the number of gradient accumulation steps."

        if args.samples_per_epoch % WORLD_SIZE != 0:
            warnings.warn(
                "Number of samples per epoch does not "
                "divide evenly into the world size."
            )

        # Likewise, every rank processes its share of the epoch's samples.
        args.samples_per_epoch //= WORLD_SIZE

        assert (
            args.samples_per_epoch > 0
        ), "World size is larger than the number of samples per epoch."

    torch.set_float32_matmul_precision("high")

    if "cuda" in args.device and not cuda_is_available():
        raise RuntimeError("Cuda is not available.")

    dtype = (
        torch.bfloat16
        if "cuda" in args.device and is_bf16_supported()
        else torch.float32
    )

    # autocast expects a bare device type ("cuda", "cpu"), not an indexed
    # device string such as "cuda:0", which is what DDP runs assign above.
    device_type = torch.device(args.device).type

    forward_context = autocast(device_type=device_type, dtype=dtype)

    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    training = Openwebtext(
        root_path=args.dataset_path,
        train=True,
        tokens_per_sample=args.block_size,
        samples_per_epoch=args.samples_per_epoch,
        num_processes=args.num_dataset_processes,
    )
    testing = Openwebtext(
        root_path=args.dataset_path,
        train=False,
        tokens_per_sample=args.block_size,
        samples_per_epoch=args.samples_per_epoch,
        num_processes=args.num_dataset_processes,
    )

    # Pinned host memory only helps when batches are copied to an accelerator.
    pin_memory = "cpu" not in args.device

    train_loader = DataLoader(
        training, batch_size=args.batch_size, pin_memory=pin_memory
    )
    test_loader = DataLoader(
        testing, batch_size=args.batch_size, pin_memory=pin_memory
    )

    model_args = {
        "block_size": args.block_size,
        "embedding_dimensions": args.embedding_dimensions,
        "num_heads": args.num_attention_heads,
        "num_layers": args.num_hidden_layers,
        "dropout": args.dropout,
        "vocabulary_size": training.vocabulary_size,
        "padding_index": training.PADDING_INDEX,
        "eos_index": training.eos_index,
    }

    # Keep a handle on the bare module: the DDP wrapper does not forward
    # custom attributes such as `num_trainable_params`.
    gpt = GPT(**model_args, activation_checkpointing=args.activation_checkpointing)

    # DDP with `device_ids` requires the parameters to already live on that
    # device, so move the module before wrapping it.
    gpt = gpt.to(args.device)

    model = gpt

    if IS_DDP:
        model = DistributedDataParallel(model, device_ids=[LOCAL_RANK])

    print("Compiling model")
    model = torch.compile(model)

    if IS_DDP:
        optimizer = ZeroRedundancyOptimizer(
            model.parameters(),
            optimizer_class=Adafactor,
            lr=args.learning_rate,
        )
    else:
        optimizer = Adafactor(model.parameters(), lr=args.learning_rate)

    starting_epoch = 1

    if args.resume:
        checkpoint = torch.load(
            args.checkpoint_path, map_location="cpu", weights_only=True
        )  # Always load into CPU RAM first to prevent CUDA out-of-memory errors.

        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        starting_epoch += checkpoint["epoch"]

        model = model.to(args.device)

        print("Previous checkpoint resumed successfully")

    model.train()

    print(f"Model has {gpt.num_trainable_params:,} trainable parameters")

    perplexity_metric = Perplexity(ignore_index=training.PADDING_INDEX).to(args.device)

    signal.signal(signal.SIGTERM, on_sigterm)

    print("Pre-training ...")

    for epoch in range(starting_epoch, args.num_epochs + 1):
        total_cross_entropy, total_gradient_norm = 0.0, 0.0
        total_batches, total_steps = 0, 0

        for step, (x, y) in enumerate(
            tqdm(train_loader, desc=f"Epoch {epoch}", leave=False), start=1
        ):
            x = x.to(args.device, non_blocking=True)
            y = y.to(args.device, non_blocking=True)

            with forward_context:
                y_pred, loss = model(x, y)

                # Scale so the accumulated gradient is the window's average.
                scaled_loss = loss / args.gradient_accumulation_steps

            sync_and_step = step % args.gradient_accumulation_steps == 0

            # Skip the gradient all-reduce on steps that only accumulate.
            backward_context = (
                model.no_sync() if IS_DDP and not sync_and_step else nullcontext()
            )

            with backward_context:
                scaled_loss.backward()

            total_cross_entropy += loss.item()

            if sync_and_step:
                norm = clip_grad_norm_(model.parameters(), args.max_gradient_norm)

                optimizer.step()

                optimizer.zero_grad(set_to_none=True)

                total_gradient_norm += norm.item()
                total_steps += 1

            total_batches += 1

        # Guard the averages: no optimizer step happens when an epoch holds
        # fewer batches than the accumulation window (and none if it is empty).
        average_cross_entropy = total_cross_entropy / max(total_batches, 1)
        average_gradient_norm = total_gradient_norm / max(total_steps, 1)

        print(
            f"Epoch {epoch}:",
            f"Cross Entropy: {average_cross_entropy:.5f},",
            f"Gradient Norm: {average_gradient_norm:.4f}",
        )

        if epoch % args.eval_interval == 0 and IS_MASTER:
            model.eval()

            for x, y in tqdm(test_loader, desc="Testing", leave=False):
                x = x.to(args.device, non_blocking=True)
                y = y.to(args.device, non_blocking=True)

                with torch.no_grad():
                    y_pred, _ = model(x)

                perplexity_metric.update(y_pred, y)

            perplexity = perplexity_metric.compute()

            print(f"Perplexity: {perplexity:.3f}")

            perplexity_metric.reset()

            model.train()

        if epoch % args.checkpoint_interval == 0:
            # ZeroRedundancyOptimizer shards its state across ranks. Every rank
            # must take part in consolidating it onto rank 0 before rank 0 may
            # call state_dict().
            if IS_DDP:
                optimizer.consolidate_state_dict(to=0)

            if IS_MASTER:
                checkpoint = {
                    "epoch": epoch,
                    "model_args": model_args,
                    "model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                }

                if args.checkpoint_history:
                    # Keep one file per checkpointed epoch instead of overwriting.
                    root, ext = path.splitext(args.checkpoint_path)

                    checkpoint_path = f"{root}-{epoch}{ext}"
                else:
                    checkpoint_path = args.checkpoint_path

                torch.save(checkpoint, checkpoint_path)

                print("Checkpoint saved")

    if IS_DDP:
        destroy_process_group()

    print("Done!")
|
308 |
+
|
309 |
+
|
310 |
+
def on_sigterm(signum, frame):
    """Handle SIGTERM by releasing the process group (if any) and exiting.

    Args:
        signum: Number of the signal that was received (unused).
        frame: Stack frame that was interrupted by the signal (unused).

    Raises:
        SystemExit: Always, with exit status 0.
    """
    print("Hold on, attempting to exit gracefully.")

    # Only distributed runs have a process group to tear down.
    if IS_DDP:
        destroy_process_group()

    # Same effect as sys.exit(0), which itself raises SystemExit.
    raise SystemExit(0)
|
317 |
+
|
318 |
+
|
319 |
+
if __name__ == "__main__":
|
320 |
+
main()
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
datasets==3.0.2
|
2 |
+
numpy==1.26.4
|
3 |
+
torch==2.5.1
|
4 |
+
torchmetrics==1.5.1
|
5 |
+
tiktoken==0.8.0
|
6 |
+
tqdm==4.66.6
|
7 |
+
matplotlib==3.9.2
|