Commit f28a628 · Andrew DalPino
Parent(s): 160e81f

Add FSDP

Files changed:
- README.md +8 -6
- beam_search.py +1 -1
- model.py +1 -2
- pre-train.py +32 -17
README.md CHANGED

@@ -9,19 +9,20 @@ metrics:
 - perplexity
 pipeline_tag: text-generation
 tags:
-- …
+- LightGPT
+- Open-source
 ---
 # LightGPT

-LightGPT is a lightweight generative pre-trained Transformer (GPT) model for the people! Built using pure PyTorch, LightGPT can generate text, answer questions, summarize documents, and more …
+LightGPT is a lightweight generative pre-trained Transformer (GPT) model for the people! Built using pure PyTorch, LightGPT can generate text, answer questions, summarize documents, and more. A unique feature of LightGPT is that it allows you to train larger models on smaller hardware by taking advantage of memory optimizations wherever possible.

-## …
+## Features

-- **Parameter-efficiency**: LightGPT aims to be a more parsimonious model by only training parameters that are absolutely necessary. As such, biases and positional embeddings have been completely removed from the …
+- **Parameter-efficiency**: LightGPT aims to be a more parsimonious model by training only the parameters that are absolutely necessary. As such, biases and positional embeddings have been completely removed from the architecture. In addition, the token embeddings and the output layer share weight matrices, resulting in a buy-one-get-one-free deal on trainable parameters.

-- **Low …
+- **Low Memory Utilization**: LightGPT employs a number of training-time optimizations that conserve precious VRAM. With zero-redundancy distributed pre-training using fully-sharded data parallel (FSDP), activation checkpointing, and automatic mixed precision, you'll be able to train larger models while accepting only a modest amount of communication and computational overhead.

-- **Fully …
+- **Fully Open-source**: Unlike closed-source LLMs, LightGPT provides both the model weights *and* the source code to train, fine-tune, and generate text from the model on your own hardware. With the help of the open-source software community, we aim to democratize AI and continually improve the models.

 ## Install Project Dependencies

@@ -100,6 +101,7 @@ Soon ...
 | --num_hidden_layers | 24 | int | The number of attention/MLP blocks within the hidden layer of the network. |
 | --dropout | 0.1 | float | The proportion of signals to send to zero during training as regularization. |
 | --activation_checkpointing | False | bool | Should we use activation checkpointing? |
+| --ddp_sharding_level | 2 | int | The level of sharding (0, 2, or 3) to use for distributed data-parallel training. |
 | --checkpoint_interval | 20 | int | Save the model parameters to disk every this many epochs. |
 | --checkpoint_path | "./out/checkpoint.pt" | string | The path to the checkpoint file on disk. |
 | --dataset_path | "./dataset" | string | The path to the dataset files on disk. |
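The Parameter-efficiency bullet above mentions that the token embeddings and the output layer share weight matrices. The snippet below is a minimal sketch of that weight-tying idea in plain PyTorch; the class and attribute names are illustrative and are not taken from model.py.

```python
import torch
from torch import nn


class TinyTiedLM(nn.Module):
    """Toy decoder-style LM that ties the input embedding and the output projection."""

    def __init__(self, vocab_size: int, embedding_dim: int) -> None:
        super().__init__()

        self.token_embeddings = nn.Embedding(vocab_size, embedding_dim)

        # With no bias, the output layer is just a (vocab_size, embedding_dim) matrix ...
        self.output_layer = nn.Linear(embedding_dim, vocab_size, bias=False)

        # ... so it can reuse the embedding weights instead of allocating its own copy.
        self.output_layer.weight = self.token_embeddings.weight

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        x = self.token_embeddings(tokens)  # (batch, seq_len, embedding_dim)

        return self.output_layer(x)  # (batch, seq_len, vocab_size) logits


model = TinyTiedLM(vocab_size=32, embedding_dim=8)

# The shared matrix is counted only once: 32 * 8 = 256 trainable parameters.
print(sum(parameter.numel() for parameter in model.parameters()))
```

Together with dropping biases and positional embeddings, this sharing is where the "buy-one-get-one-free" parameter savings described in the README come from.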
beam_search.py CHANGED

@@ -20,7 +20,7 @@ def main():

     parser.add_argument("--checkpoint_path", default="./out/checkpoint.pt", type=str)
     parser.add_argument("--lora_path", default=None, type=str)
-    parser.add_argument("--max_tokens", default=…
+    parser.add_argument("--max_tokens", default=500, type=int)
     parser.add_argument("--num_candidates", default=3, type=int)
     parser.add_argument("--beam_width", default=16, type=int)
     parser.add_argument("--device", default="cuda", type=str)
model.py CHANGED

@@ -215,13 +215,12 @@ class GPT(Module):
     log_probability: float
     tokens: Tensor

-    @property
     def priority(self) -> float:
         return self.log_probability

     sort_candidates = partial(
         sorted,
-        key=lambda candidate: candidate.priority,
+        key=lambda candidate: candidate.priority(),
         reverse=True,
     )

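The model.py change removes the @property decorator from priority, so the sort key must now call the method; if the lambda still referenced candidate.priority without the parentheses, sorted would compare bound-method objects and raise a TypeError. Below is a self-contained sketch of the same sorting pattern; the Candidate class here is illustrative rather than the one defined in model.py.

```python
from dataclasses import dataclass, field
from functools import partial


@dataclass
class Candidate:
    """A beam-search hypothesis scored by its cumulative log probability."""

    log_probability: float
    tokens: list[int] = field(default_factory=list)

    def priority(self) -> float:
        # Plain method (no @property), so callers have to invoke it explicitly.
        return self.log_probability


# Highest log probability (most likely sequence) first.
sort_candidates = partial(
    sorted,
    key=lambda candidate: candidate.priority(),
    reverse=True,
)

candidates = [Candidate(-4.2), Candidate(-1.3), Candidate(-2.7)]

for candidate in sort_candidates(candidates):
    print(candidate.log_probability)  # -1.3, then -2.7, then -4.2
```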
pre-train.py CHANGED

@@ -15,8 +15,7 @@ from torch.amp import autocast
 from torch.cuda import set_device, is_available as cuda_is_available, is_bf16_supported
 from torch.nn.utils import clip_grad_norm_
 from torch.distributed import init_process_group, destroy_process_group
-from torch.distributed.…
-from torch.nn.parallel import DistributedDataParallel
+from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy

 from torchmetrics.text import Perplexity

@@ -33,7 +32,7 @@ IS_DDP = WORLD_SIZE > 1

 IS_MASTER = RANK == 0 or not IS_DDP

-DDP_BACKEND = "nccl"
+DDP_BACKEND = "nccl"


 def main():

@@ -51,6 +50,7 @@ def main():
     parser.add_argument("--num_attention_heads", default=16, type=int)
     parser.add_argument("--num_hidden_layers", default=32, type=int)
     parser.add_argument("--activation_checkpointing", action="store_true")
+    parser.add_argument("--ddp_sharding_level", default=2, choices=[0, 2, 3])
     parser.add_argument("--eval_interval", default=10, type=int)
     parser.add_argument("--checkpoint_interval", default=20, type=int)
     parser.add_argument("--checkpoint_path", default="./out/checkpoint.pt", type=str)

@@ -175,19 +175,25 @@ def main():
     model = GPT(**model_args, activation_checkpointing=args.activation_checkpointing)

     if IS_DDP:
-        …
+        match args.ddp_sharding_level:
+            case 0:
+                sharding_strategy = ShardingStrategy.NO_SHARD
+            case 2:
+                sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
+            case 3:
+                sharding_strategy = ShardingStrategy.FULL_SHARD
+
+        model = FullyShardedDataParallel(
+            model,
+            device_id=LOCAL_RANK,
+            sharding_strategy=sharding_strategy,
+            use_orig_params=True,
+        )

     print("Compiling model")
     model = torch.compile(model).to(args.device)

-    if IS_DDP:
-        optimizer = ZeroRedundancyOptimizer(
-            model.parameters(),
-            optimizer_class=Adafactor,
-            lr=args.learning_rate,
-        )
-    else:
-        optimizer = Adafactor(model.parameters(), lr=args.learning_rate)
+    optimizer = Adafactor(model.parameters(), lr=args.learning_rate)

     starting_epoch = 1

@@ -210,7 +216,7 @@ def main():

     perplexity_metric = Perplexity(ignore_index=training.PADDING_INDEX).to(args.device)

-    …
+    register_signal_handlers()

     print("Pre-training ...")

@@ -294,19 +300,28 @@ def main():
     print("Checkpoint saved")

     if IS_DDP:
-        …
+        ddp_cleanup()

     print("Done!")


-def …
-    …
+def register_signal_handlers():
+    signal.signal(signal.SIGINT, shutdown)
+    signal.signal(signal.SIGTERM, shutdown)
+
+
+def shutdown(signum, frame):
+    print("Hold on, attempting to exit gracefully")

     if IS_DDP:
-        …
+        ddp_cleanup()

     sys.exit(0)


+def ddp_cleanup():
+    destroy_process_group()
+
+
 if __name__ == "__main__":
     main()
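For context on the new --ddp_sharding_level option: the three levels map onto PyTorch's FSDP sharding strategies, each trading memory for communication differently. The helper below is a hypothetical sketch that mirrors the level-to-strategy mapping added in pre-train.py, with the trade-offs spelled out in comments; the function name and the error-handling branch are illustrative and not part of the repository.

```python
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy


def wrap_for_fsdp(model: nn.Module, sharding_level: int, local_rank: int) -> FullyShardedDataParallel:
    """Wrap a model in FSDP using the same level-to-strategy mapping as pre-train.py."""

    match sharding_level:
        case 0:
            # No sharding: every rank keeps full parameters, gradients, and optimizer
            # state, much like classic DDP. Fastest, but highest memory per GPU.
            sharding_strategy = ShardingStrategy.NO_SHARD
        case 2:
            # Shard gradients and optimizer state across ranks (ZeRO-2 style); parameters
            # stay unsharded during the forward and backward passes.
            sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
        case 3:
            # Shard parameters as well (ZeRO-3 style): lowest memory per rank, at the
            # cost of extra all-gather communication in every forward and backward pass.
            sharding_strategy = ShardingStrategy.FULL_SHARD
        case _:
            raise ValueError(f"Unsupported sharding level: {sharding_level}")

    return FullyShardedDataParallel(
        model,
        device_id=local_rank,
        sharding_strategy=sharding_strategy,
        use_orig_params=True,  # expose the original parameters, as torch.compile expects
    )
```

The default of 2 (SHARD_GRAD_OP) is a middle ground that already removes most of the gradient and optimizer-state redundancy the old ZeroRedundancyOptimizer path targeted, while 3 (FULL_SHARD) is the option to reach for when the parameters themselves no longer fit on a single GPU.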