winglian committed on
Commit
2df63ef
·
1 Parent(s): b164725

refactor trainer setup to account for deepspeed integration

Browse files
Files changed (1) hide show
  1. scripts/finetune.py +86 -67
scripts/finetune.py CHANGED
@@ -16,7 +16,7 @@ from peft import (
16
  LoraConfig,
17
  get_peft_model,
18
  prepare_model_for_int8_training,
19
- get_peft_model_state_dict, PeftModel,
20
  )
21
  from torch import nn
22
  from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer
@@ -214,6 +214,89 @@ def choose_config(path: Path):
214
  return chosen_file
215
 
216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  def train(
218
  config: Path = Path("configs/"),
219
  **kwargs,
@@ -308,73 +391,8 @@ def train(
308
  tokenizer,
309
  )
310
 
311
- total_num_steps = int(
312
- math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
313
- )
314
- warmup_steps = min(int(0.03 * total_num_steps), 100)
315
- logging_steps = min(int(0.005 * total_num_steps), 10)
316
- save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)
317
-
318
- training_args = transformers.TrainingArguments(
319
- per_device_train_batch_size=cfg.micro_batch_size,
320
- gradient_accumulation_steps=cfg.gradient_accumulation_steps,
321
- warmup_steps=warmup_steps,
322
- num_train_epochs=cfg.num_epochs,
323
- learning_rate=cfg.learning_rate,
324
- bf16=cfg.bf16,
325
- tf32=cfg.tf32,
326
- logging_steps=logging_steps,
327
- evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
328
- save_strategy="steps",
329
- eval_steps=eval_steps if cfg.val_set_size > 0 else None,
330
- save_steps=save_steps,
331
- output_dir=cfg.output_dir,
332
- save_total_limit=3,
333
- load_best_model_at_end=True if cfg.val_set_size > 0 else False,
334
- ddp_find_unused_parameters=False if cfg.ddp else None,
335
- group_by_length=cfg.group_by_length,
336
- report_to="wandb" if cfg.use_wandb else None,
337
- run_name=cfg.wandb_run_name if cfg.use_wandb else None,
338
- )
339
-
340
- decay_parameters = get_parameter_names(model, [nn.LayerNorm])
341
- decay_parameters = [name for name in decay_parameters if "bias" not in name]
342
- optimizer_grouped_parameters = [
343
- {
344
- "params": [p for n, p in model.named_parameters() if n in decay_parameters],
345
- "weight_decay": training_args.weight_decay,
346
- },
347
- {
348
- "params": [
349
- p for n, p in model.named_parameters() if n not in decay_parameters
350
- ],
351
- "weight_decay": 0.0,
352
- },
353
- ]
354
-
355
- adam_bnb_optim = bnb.optim.Adam8bit(
356
- optimizer_grouped_parameters,
357
- betas=(training_args.adam_beta1, training_args.adam_beta2),
358
- eps=training_args.adam_epsilon,
359
- lr=training_args.learning_rate,
360
- )
361
-
362
- lr_scheduler = transformers.get_cosine_schedule_with_warmup(
363
- adam_bnb_optim,
364
- training_args.warmup_steps,
365
- total_num_steps,
366
- )
367
 
368
- trainer = transformers.Trainer(
369
- model=model,
370
- train_dataset=train_dataset,
371
- eval_dataset=eval_dataset,
372
- args=training_args,
373
- optimizers=(adam_bnb_optim, lr_scheduler),
374
- data_collator=transformers.DataCollatorForSeq2Seq(
375
- tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
376
- ),
377
- )
378
  model.config.use_cache = False
379
 
380
  if torch.__version__ >= "2" and sys.platform != "win32":
@@ -391,6 +409,7 @@ def train(
391
 
392
  trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
393
 
 
394
  model.save_pretrained(cfg.output_dir)
395
 
396
 
 
16
  LoraConfig,
17
  get_peft_model,
18
  prepare_model_for_int8_training,
19
+ PeftModel,
20
  )
21
  from torch import nn
22
  from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer
 
214
  return chosen_file
215
 
216
 
217
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
    """Build the ``transformers.Trainer`` used for fine-tuning.

    When ``cfg.deepspeed`` is falsy, this function supplies warmup/logging
    intervals, the ``bf16``/``tf32`` precision flags, and a hand-built 8-bit
    Adam optimizer with a cosine schedule. When ``cfg.deepspeed`` is truthy,
    none of those are passed, so ``TrainingArguments``/``Trainer`` fall back
    to their own defaults (presumably so the DeepSpeed configuration can
    provide them -- confirm against the DeepSpeed config in use).

    Args:
        cfg: Run configuration. Attributes read here: ``num_epochs``,
            ``batch_size``, ``micro_batch_size``,
            ``gradient_accumulation_steps``, ``learning_rate``, ``bf16``,
            ``tf32``, ``val_set_size``, ``output_dir``, ``ddp``,
            ``group_by_length``, ``use_wandb``, ``wandb_run_name`` and
            ``deepspeed``.
        train_dataset: Training dataset; must support ``len()``.
        eval_dataset: Evaluation dataset, handed to the trainer as-is.
        model: Model to train; its named parameters are split into
            weight-decay and no-weight-decay groups for the optimizer.
        tokenizer: Tokenizer given to ``DataCollatorForSeq2Seq`` for padding.

    Returns:
        The configured ``transformers.Trainer``.
    """
    total_num_steps = int(
        math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
    )
    # On short runs the 5% heuristic truncates to 0, which is not a usable
    # step interval; clamp so checkpoints/evaluation still happen.
    save_steps = eval_steps = max(1, min(int(0.05 * total_num_steps), 200))
    has_eval_split = cfg.val_set_size > 0

    training_arguments_kwargs = {}

    if not cfg.deepspeed:
        # NOTE(review): the original indentation was lost in this view; the
        # precision flags are assumed to belong to this branch together with
        # the warmup/logging settings -- confirm against the repository.
        training_arguments_kwargs["warmup_steps"] = min(
            int(0.03 * total_num_steps), 100
        )
        # Clamp to >= 1 for the same reason as save_steps/eval_steps: the
        # 0.5% heuristic is 0 for any run shorter than 200 steps.
        training_arguments_kwargs["logging_steps"] = max(
            1, min(int(0.005 * total_num_steps), 10)
        )
        training_arguments_kwargs["bf16"] = cfg.bf16
        training_arguments_kwargs["tf32"] = cfg.tf32

    training_args = transformers.TrainingArguments(
        per_device_train_batch_size=cfg.micro_batch_size,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        num_train_epochs=cfg.num_epochs,
        learning_rate=cfg.learning_rate,
        evaluation_strategy="steps" if has_eval_split else "no",
        save_strategy="steps",
        eval_steps=eval_steps if has_eval_split else None,
        save_steps=save_steps,
        output_dir=cfg.output_dir,
        save_total_limit=3,
        load_best_model_at_end=has_eval_split,
        ddp_find_unused_parameters=False if cfg.ddp else None,
        group_by_length=cfg.group_by_length,
        # NOTE(review): passing None here may let transformers pick its own
        # default set of reporting integrations rather than disabling
        # reporting -- confirm this is intended when wandb is off.
        report_to="wandb" if cfg.use_wandb else None,
        run_name=cfg.wandb_run_name if cfg.use_wandb else None,
        **training_arguments_kwargs,
    )

    trainer_kwargs = {}

    if not cfg.deepspeed:
        # Weight decay applies to every parameter that is neither inside a
        # LayerNorm nor a bias. A set keeps the per-parameter membership
        # test O(1) instead of scanning a list for every parameter.
        decay_parameters = {
            name
            for name in get_parameter_names(model, [nn.LayerNorm])
            if "bias" not in name
        }
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters() if n in decay_parameters
                ],
                "weight_decay": training_args.weight_decay,
            },
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if n not in decay_parameters
                ],
                "weight_decay": 0.0,
            },
        ]

        adam_bnb_optim = bnb.optim.Adam8bit(
            optimizer_grouped_parameters,
            betas=(training_args.adam_beta1, training_args.adam_beta2),
            eps=training_args.adam_epsilon,
            lr=training_args.learning_rate,
        )

        lr_scheduler = transformers.get_cosine_schedule_with_warmup(
            adam_bnb_optim,
            training_args.warmup_steps,
            total_num_steps,
        )
        trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=training_args,
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
        **trainer_kwargs,
    )

    return trainer
299
+
300
  def train(
301
  config: Path = Path("configs/"),
302
  **kwargs,
 
391
  tokenizer,
392
  )
393
 
394
+ trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
 
 
 
 
 
 
 
 
 
 
 
396
  model.config.use_cache = False
397
 
398
  if torch.__version__ >= "2" and sys.platform != "win32":
 
409
 
410
  trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
411
 
412
+ # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
413
  model.save_pretrained(cfg.output_dir)
414
 
415