Andrew DalPino committed
Commit f28a628
1 Parent(s): 160e81f
Files changed (4)
  1. README.md +8 -6
  2. beam_search.py +1 -1
  3. model.py +1 -2
  4. pre-train.py +32 -17
README.md CHANGED
@@ -9,19 +9,20 @@ metrics:
 - perplexity
 pipeline_tag: text-generation
 tags:
-- GPT
+- LightGPT
+- Open-source
 ---
 # LightGPT

-LightGPT is a lightweight generative pre-trained Transformer (GPT) model for the people! Built using pure PyTorch, LightGPT can generate text, answer questions, summarize documents, and more, all using consumer hardware. A unique feature of LightGPT is that you can trade off compute for additional memory-efficiency as needed - allowing you to train larger models on smaller hardware. It also supports memory-efficient pre-training over multiple GPUs or clusters of GPUs using PyTorch's Distributed Data Parallel (DDP) protocol with ZeRO Redundancy sharding.
+LightGPT is a lightweight generative pre-trained Transformer (GPT) model for the people! Built using pure PyTorch, LightGPT can generate text, answer questions, summarize documents, and more. A unique feature of LightGPT is that it allows you to train larger models on smaller hardware by taking advantage of memory optimizations wherever possible.

-## What makes LightGPT different?
+## Features

-- **Parameter-efficiency**: LightGPT aims to be a more parsimonious model by only training parameters that are absolutely necessary. As such, biases and positional embeddings have been completely removed from the neural network architecture. In addition, the token embeddings and output layer share weight matrices resulting in a further reduction in trainable parameters.
+- **Parameter-efficiency**: LightGPT aims to be a more parsimonious model by only training parameters that are absolutely necessary. As such, biases and positional embeddings have been completely removed from the architecture. In addition, the token embeddings and output layer share weight matrices resulting in a buy-one-get-one-free deal on trainable parameters.

-- **Low VRAM Utilization**: LightGPT's Adafactor optimizer reduces the number of training-time buffers over Adam from O(n*m) to O(n+m) for every trainable weight matrix with minimal effect on runtime and minima quality. In addition, with activation check-pointing enabled, buffers needed to compute gradients during training are reduced by a factor of 10X or more.
+- **Low Memory Utilization**: LightGPT employs a number of training-time optimizations that conserve precious VRAM. With zero-redundancy distributed pre-training using fully-sharded data-parallel (FSDP), activation checkpointing, and automatic mixed precision, you'll be able to train larger models by accepting a relatively small amount of communication and computational overhead.

-- **Fully open-source**: Unlike closed-source LLMs, LightGPT provides both the model weights *and* the source code to train, fine-tune, and generate text from the model using your own hardware. With the help of the open-source software community, we aim to democratize AI and continually improve the models.
+- **Fully Open-source**: Unlike closed-source LLMs, LightGPT provides both the model weights *and* the source code to train, fine-tune, and generate text from the model using your own hardware. With the help of the open-source software community, we aim to democratize AI and continually improve the models.

 ## Install Project Dependencies

@@ -100,6 +101,7 @@ Soon ...
 | --num_hidden_layers | 24 | int | The number of attention/MLP blocks within the hidden layer of the network. |
 | --dropout | 0.1 | float | The proportion of signals to send to zero during training as regularization. |
 | --activation_checkpointing | False | bool | Should we use activation checkpointing? |
+| --ddp_sharding_level | 2 | (0, 2, 3) | int | The level of sharding to use for DDP training. |
 | --checkpoint_interval | 20 | int | Save the model parameters to disk every this many epochs. |
 | --checkpoint_path | "./out/checkpoint.pt" | string | The path to the checkpoint file on disk. |
 | --dataset_path | "./dataset" | string | The path to the dataset files on disk. |
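The weight tying called out in the Parameter-efficiency bullet above can be sketched in a few lines of PyTorch. This is a minimal illustration only; the class and argument names below are invented for the example and are not the actual code in model.py.

from torch.nn import Embedding, Linear, Module


class TiedEmbeddingSketch(Module):
    """Illustrative sketch: the token embeddings and the output projection share one weight matrix."""

    def __init__(self, vocabulary_size: int, embedding_dimensions: int):
        super().__init__()

        self.token_embeddings = Embedding(vocabulary_size, embedding_dimensions)

        # No bias term, in line with the README's note that biases are removed.
        self.output_layer = Linear(embedding_dimensions, vocabulary_size, bias=False)

        # Point both modules at the same Parameter so only one
        # vocabulary_size x embedding_dimensions matrix is trained.
        self.output_layer.weight = self.token_embeddings.weight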
beam_search.py CHANGED
@@ -20,7 +20,7 @@ def main():

     parser.add_argument("--checkpoint_path", default="./out/checkpoint.pt", type=str)
     parser.add_argument("--lora_path", default=None, type=str)
-    parser.add_argument("--max_tokens", default=200, type=int)
+    parser.add_argument("--max_tokens", default=500, type=int)
     parser.add_argument("--num_candidates", default=3, type=int)
     parser.add_argument("--beam_width", default=16, type=int)
     parser.add_argument("--device", default="cuda", type=str)
model.py CHANGED
@@ -215,13 +215,12 @@ class GPT(Module):
             log_probability: float
             tokens: Tensor

-            @property
             def priority(self) -> float:
                 return self.log_probability

         sort_candidates = partial(
             sorted,
-            key=lambda candidate: candidate.priority,
+            key=lambda candidate: candidate.priority(),
             reverse=True,
         )
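The two model.py changes belong together: with the @property decorator removed, priority becomes an ordinary method, so the sort key must call it. A standalone sketch of the resulting pattern, with the Candidate fields simplified for illustration:

from dataclasses import dataclass
from functools import partial


@dataclass
class Candidate:
    log_probability: float

    def priority(self) -> float:
        return self.log_probability


# Without the decorator, candidate.priority is a bound method; using it directly
# as the sort key would try to order method objects and raise a TypeError, so the
# key calls priority() to compare the underlying floats instead.
sort_candidates = partial(
    sorted,
    key=lambda candidate: candidate.priority(),
    reverse=True,
)

best_first = sort_candidates([Candidate(-1.2), Candidate(-0.3)])  # highest log probability first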
 
pre-train.py CHANGED
@@ -15,8 +15,7 @@ from torch.amp import autocast
 from torch.cuda import set_device, is_available as cuda_is_available, is_bf16_supported
 from torch.nn.utils import clip_grad_norm_
 from torch.distributed import init_process_group, destroy_process_group
-from torch.distributed.optim import ZeroRedundancyOptimizer
-from torch.nn.parallel import DistributedDataParallel
+from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy

 from torchmetrics.text import Perplexity

@@ -33,7 +32,7 @@ IS_DDP = WORLD_SIZE > 1

 IS_MASTER = RANK == 0 or not IS_DDP

-DDP_BACKEND = "nccl" # nccl, gloo, etc.
+DDP_BACKEND = "nccl"


 def main():
@@ -51,6 +50,7 @@ def main():
     parser.add_argument("--num_attention_heads", default=16, type=int)
     parser.add_argument("--num_hidden_layers", default=32, type=int)
     parser.add_argument("--activation_checkpointing", action="store_true")
+    parser.add_argument("--ddp_sharding_level", default=2, choices=[0, 2, 3])
     parser.add_argument("--eval_interval", default=10, type=int)
     parser.add_argument("--checkpoint_interval", default=20, type=int)
     parser.add_argument("--checkpoint_path", default="./out/checkpoint.pt", type=str)
@@ -175,19 +175,25 @@ def main():
     model = GPT(**model_args, activation_checkpointing=args.activation_checkpointing)

     if IS_DDP:
-        model = DistributedDataParallel(model, device_ids=[LOCAL_RANK])
+        match args.ddp_sharding_level:
+            case 0:
+                sharding_strategy = ShardingStrategy.NO_SHARD
+            case 2:
+                sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
+            case 3:
+                sharding_strategy = ShardingStrategy.FULL_SHARD
+
+        model = FullyShardedDataParallel(
+            model,
+            device_id=LOCAL_RANK,
+            sharding_strategy=sharding_strategy,
+            use_orig_params=True,
+        )

     print("Compiling model")
     model = torch.compile(model).to(args.device)

-    if IS_DDP:
-        optimizer = ZeroRedundancyOptimizer(
-            model.parameters(),
-            optimizer_class=Adafactor,
-            lr=args.learning_rate,
-        )
-    else:
-        optimizer = Adafactor(model.parameters(), lr=args.learning_rate)
+    optimizer = Adafactor(model.parameters(), lr=args.learning_rate)

     starting_epoch = 1

@@ -210,7 +216,7 @@ def main():

     perplexity_metric = Perplexity(ignore_index=training.PADDING_INDEX).to(args.device)

-    signal.signal(signal.SIGTERM, on_sigterm)
+    register_signal_handlers()

     print("Pre-training ...")

@@ -294,19 +300,28 @@ def main():
     print("Checkpoint saved")

     if IS_DDP:
-        destroy_process_group()
+        ddp_cleanup()

     print("Done!")


-def on_sigterm(signum, frame):
-    print("Hold on, attempting to exit gracefully.")
+def register_signal_handlers():
+    signal.signal(signal.SIGINT, shutdown)
+    signal.signal(signal.SIGTERM, shutdown)
+
+
+def shutdown(signum, frame):
+    print("Hold on, attempting to exit gracefully")

     if IS_DDP:
-        destroy_process_group()
+        ddp_cleanup()

     sys.exit(0)


+def ddp_cleanup():
+    destroy_process_group()
+
+
 if __name__ == "__main__":
     main()
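A closing note on the pre-train.py changes: the FSDP branch only activates when WORLD_SIZE > 1, which is the usual torchrun scenario, since torchrun sets the WORLD_SIZE, RANK, and LOCAL_RANK environment variables that the script's module-level constants appear to read. An illustrative single-node launch (the GPU count and flag values are examples, not project defaults):

torchrun --standalone --nproc_per_node=4 pre-train.py --ddp_sharding_level 2 --activation_checkpointing

With --ddp_sharding_level 2 the match statement selects ShardingStrategy.SHARD_GRAD_OP, which shards gradients and optimizer state across ranks while keeping a full copy of the parameters on every GPU; level 3 maps to FULL_SHARD, which shards the parameters as well, and level 0 maps to NO_SHARD, which behaves like conventional data parallelism.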