winglian committed
Commit 4e705ed · unverified · 2 parents: 2624bc2 4a17a4c

Merge pull request #9 from winglian/dev

FAQS.md ADDED
@@ -0,0 +1,4 @@
+# FAQs
+
+- Can you train StableLM with this? Yes, but only with a single GPU at the moment. Multi-GPU support is coming soon! Just waiting on this [PR](https://github.com/huggingface/transformers/pull/22874)
+- Will this work with DeepSpeed? That's still a WIP, but setting `export ACCELERATE_USE_DEEPSPEED=true` should work in some cases
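As a rough sketch of what that toggle does, the trainer change later in this commit only wires DeepSpeed in when the Accelerate variable is set and more than one GPU is visible, defaulting to the bundled ds_config.json (a hedged sketch of that gate, not the full setup_trainer code):

import os
import torch

# Hedged sketch of the gate added in src/axolotl/utils/trainer.py: DeepSpeed
# is only enabled when the Accelerate toggle is set and multiple GPUs exist,
# with the repo's ds_config.json used as the default config path.
training_arguments_kwargs = {}
if os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true" and torch.cuda.device_count() > 1:
    training_arguments_kwargs["deepspeed"] = "./ds_config.json"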
configs/galactica_1_3B.yml ADDED
@@ -0,0 +1,41 @@
+base_model: facebook/galactica-1.3b
+model_type: AutoModelForCausalLM
+tokenizer_type: AutoTokenizer
+load_in_8bit: false
+datasets:
+  - path: tatsu-lab/alpaca
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.1
+adapter:
+lora_model_dir:
+sequence_len: 1024
+max_packed_sequence_len: 1024
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - q_proj
+  - v_proj
+lora_fan_in_fan_out: false
+wandb_project:
+wandb_watch:
+wandb_run_id:
+wandb_log_model: checkpoint
+output_dir: ./lora-llama-alpaca
+batch_size: 32
+micro_batch_size: 16
+num_epochs: 3
+learning_rate: 0.00003
+train_on_inputs: false
+group_by_length: false
+bf16: false
+tf32: false
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+special_tokens:
+  pad_token: "[PAD]"
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
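The lora_* keys above map onto a PEFT LoRA adapter; here is a minimal sketch of the equivalent peft configuration (the task_type and bias values are assumptions, since the YAML does not spell them out):

from peft import LoraConfig

# Hedged sketch: how the lora_* keys in this YAML would translate to peft.
# task_type and bias are assumptions, not values taken from the config.
lora_config = LoraConfig(
    r=8,                                   # lora_r
    lora_alpha=16,                         # lora_alpha
    lora_dropout=0.05,                     # lora_dropout
    target_modules=["q_proj", "v_proj"],   # lora_target_modules
    fan_in_fan_out=False,                  # lora_fan_in_fan_out
    bias="none",
    task_type="CAUSAL_LM",
)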
configs/llama_13B_alpaca.yml ADDED
@@ -0,0 +1,39 @@
+base_model: huggyllama/llama-13b
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+load_in_8bit: true
+datasets:
+  - path: anon8231489123/ShareGPT_Vicuna_unfiltered
+    data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
+    type: sharegpt
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.002
+adapter:
+lora_model_dir:
+sequence_len: 2048
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - q_proj
+  - v_proj
+lora_fan_in_fan_out: false
+wandb_project:
+wandb_watch:
+wandb_run_id:
+wandb_log_model: checkpoint
+output_dir: ./llama-13b-sharegpt
+batch_size: 64
+micro_batch_size: 2
+warmup_steps: 1000
+save_steps:
+eval_steps:
+num_epochs: 5
+learning_rate: 0.00003
+train_on_inputs: false
+group_by_length: false
+bf16: true
+tf32: true
+early_stopping_patience: 5
+resume_from_checkpoint:
+local_rank:
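For intuition on the batch settings above: scripts/finetune.py derives cfg.gradient_accumulation_steps from these two values (the exact expression is truncated in the finetune.py hunk below); the usual relationship, assumed here, is:

# Assumed relationship (the exact formula is truncated in the finetune.py
# diff below): effective batch size = micro batch size * accumulation steps.
batch_size = 64        # from this config
micro_batch_size = 2   # from this config
gradient_accumulation_steps = batch_size // micro_batch_size
assert gradient_accumulation_steps == 32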
configs/llama_65B_alpaca.yml CHANGED
@@ -5,7 +5,8 @@ load_in_8bit: true
 datasets:
   - path: data/alpaca_data_gpt4.jsonl
     type: alpaca
-  - path: data/vicuna_cleaned.jsonl
+  - path: anon8231489123/ShareGPT_Vicuna_unfiltered
+    data_files: ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json
     type: sharegpt
   - path: data/gpt4-instruct-similarity-0.6-dataset.jsonl
     type: gpteacher
@@ -30,6 +31,8 @@ wandb_log_model: checkpoint
 output_dir: ./lora-llama-alpaca
 batch_size: 128
 micro_batch_size: 16
+warmup_steps: 1000
+save_steps:
 num_epochs: 5
 learning_rate: 0.00003
 train_on_inputs: false
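The swapped-in dataset entry relies on the loader being able to pull a single file out of a hub dataset repo, which the src/axolotl/utils/data.py change later in this commit adds; roughly:

from datasets import load_dataset

# Sketch of how the new datasets entry resolves: stream one JSON file from
# the hub repo instead of reading a local data/*.jsonl path.
ds = load_dataset(
    "anon8231489123/ShareGPT_Vicuna_unfiltered",
    data_files="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
    streaming=True,
)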
configs/stability_3b.yml ADDED
@@ -0,0 +1,33 @@
+base_model: stabilityai/stablelm-base-alpha-3b
+load_in_8bit: true
+datasets:
+  - path: vicgalle/alpaca-gpt4
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.04
+adapter:
+lora_model_dir:
+sequence_len: 4096
+lora_r: 8
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules:
+  - q_proj
+  - v_proj
+lora_fan_in_fan_out: false
+wandb_project: stable-llama-3b
+wandb_watch:
+wandb_run_id:
+wandb_log_model: checkpoint
+output_dir: ./stable-llama-3b
+batch_size: 128
+micro_batch_size: 16
+num_epochs: 1
+learning_rate: 0.00003
+train_on_inputs: false
+group_by_length: false
+bf16: true
+tf32: true
+early_stopping_patience: 3
+resume_from_checkpoint:
+local_rank:
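For reference, load_in_8bit: true with this base model boils down to the kind of call src/axolotl/utils/models.py makes; a hedged sketch (the device_map and dtype values are illustrative, not taken from the YAML, and bitsandbytes must be installed):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hedged sketch of what load_in_8bit: true means for this config; device_map
# and torch_dtype here are illustrative placeholders.
model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stablelm-base-alpha-3b",
    load_in_8bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("stabilityai/stablelm-base-alpha-3b")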
ds_config.json CHANGED
@@ -11,11 +11,10 @@
     "min_loss_scale": 1
   },
   "scheduler": {
-    "type": "WarmupLR",
+    "type": "OneCycle",
     "params": {
-      "warmup_min_lr": "auto",
-      "warmup_max_lr": "auto",
-      "warmup_num_steps": "auto"
+      "cycle_min_lr": 1e-7,
+      "cycle_max_lr": 1e-4
     }
   },
   "zero_optimization": {
@@ -25,7 +24,8 @@
     "allgather_bucket_size": 5e8,
     "contiguous_gradients": true,
     "reduce_bucket_size": "auto",
-    "reduce_scatter": true
+    "reduce_scatter": true,
+    "stage3_gather_16bit_weights_on_model_save": true
   },
   "gradient_accumulation_steps": "auto",
   "gradient_clipping": "auto",
scripts/finetune.py CHANGED
@@ -159,7 +159,7 @@ def train(
     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
     choose_device(cfg)
-    cfg.ddp = cfg.world_size != 1
+    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
     if cfg.ddp:
        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
     cfg.gradient_accumulation_steps = (
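In effect, an explicit ddp setting in the config now wins over the WORLD_SIZE heuristic; a small worked example with hypothetical values:

import os

# Hypothetical values: the user set ddp: false in the YAML while launching
# across 4 processes; before this change ddp would have been forced on.
os.environ["WORLD_SIZE"] = "4"
cfg_ddp = False
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = cfg_ddp if cfg_ddp is not None else world_size != 1
assert ddp is False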
src/axolotl/datasets.py CHANGED
@@ -1,3 +1,4 @@
+import logging
 from typing import List
 
 import torch
@@ -92,11 +93,14 @@ class ConstantLengthDataset(IterableDataset):
                    : self.seq_length
                ]
                labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
-               yield {
-                   "input_ids": input_ids,
-                   "labels": labels,
-                   "attention_mask": attention_mask,
-               }
+               if labels.size() == input_ids.size() and attention_mask.size() == input_ids.size():
+                   yield {
+                       "input_ids": input_ids,
+                       "labels": labels,
+                       "attention_mask": attention_mask,
+                   }
+               else:
+                   logging.warning("dropping batch due to tensor size mismatch")
                buffer = {"input_ids": [], "attention_mask": [], "labels": []}
                buffer_len = 0
 
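The intent of the new guard is easier to see outside the iterator; a minimal sketch of the pack-and-validate step (a standalone function, not the actual ConstantLengthDataset method):

import logging
import torch

def pack(buffer, seq_length):
    # Concatenate buffered examples, truncate to seq_length, and drop the
    # pack when the three tensors disagree in size, mirroring the guard above.
    input_ids = torch.cat(buffer["input_ids"], dim=-1)[:seq_length]
    attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[:seq_length]
    labels = torch.cat(buffer["labels"], dim=-1)[:seq_length]
    if labels.size() == input_ids.size() and attention_mask.size() == input_ids.size():
        return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}
    logging.warning("dropping batch due to tensor size mismatch")
    return None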
 
src/axolotl/prompters.py CHANGED
@@ -128,6 +128,10 @@ conv_vicuna_v1_1 = Conversation(
 
 class ShareGPTPrompter:
     def build_prompt(self, source, tokenizer):
+        # ignore the system prompt if provided
+        if source[0]["from"] == "system":
+            source.pop(0)
+
         if len(source) < 2:
             # If there isn't a back and forth conversation, ignore it
             # also happens on the data splitting leaving empty conversations
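A quick illustration with a hypothetical ShareGPT-style conversation shows why the pop happens before the length check:

# Hypothetical conversation: without the pop, the leading "system" turn
# would count toward the back-and-forth check that follows it.
source = [
    {"from": "system", "value": "You are a helpful assistant."},
    {"from": "human", "value": "Hi there!"},
    {"from": "gpt", "value": "Hello! How can I help?"},
]
if source[0]["from"] == "system":
    source.pop(0)
assert len(source) >= 2  # still a usable human/gpt exchange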
src/axolotl/utils/data.py CHANGED
@@ -2,7 +2,8 @@ import logging
 from hashlib import md5
 from pathlib import Path
 
-from datasets import load_from_disk, load_dataset, IterableDataset, Dataset
+from datasets import load_from_disk, load_dataset, IterableDataset, Dataset, concatenate_datasets
+from huggingface_hub import hf_hub_download
 
 from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
 from axolotl.prompt_tokenizers import (
@@ -30,7 +31,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
     ds_hash = str(
         md5(
             (
-                str(max_packed_sequence_len)
+                str(cfg.sequence_len)
                 + "@"
                 + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
             ).encode("utf-8")
@@ -43,13 +44,15 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
     )
 
     if any(prepared_ds_path.glob("*")):
-        logging.info("Loading prepared dataset from disk...")
+        logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
         dataset = load_from_disk(str(prepared_ds_path))
         logging.info("Prepared dataset loaded from disk...")
     else:
+        logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
         logging.info("Loading raw datasets...")
         datasets = []
        for d in cfg.datasets:
+            ds = None
             ds_from_hub = False
             try:
                 load_dataset(d.path, streaming=True)
@@ -63,8 +66,14 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
                     "json", data_files=d.path, streaming=True, split=None
                 )
             elif ds_from_hub:
-                ds = load_dataset(d.path, streaming=True)
+                if d.data_files:
+                    ds = load_dataset(d.path, streaming=True, data_files=d.data_files)
+                else:
+                    ds = load_dataset(d.path, streaming=True)
             else:
+                fp = hf_hub_download(repo_id=d.path, repo_type="dataset", filename=d.data_files)
+                ds = load_dataset("json", data_files=fp, streaming=True, split=None)
+            if not ds:
                 raise Exception("unhandled dataset load")
 
             if d.type == "alpaca":
@@ -105,20 +114,32 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
                 datasets.append(ds_wrapper)
             else:
                 logging.error(f"unhandled prompt tokenization strategy: {d.type}")
+        logging.info("tokenizing, merging, and shuffling master dataset")
+
+        samples = []
+        for d in datasets:
+            samples = samples + [i for i in d]
+        dataset = Dataset.from_list(samples).shuffle(seed=42)
+        if cfg.local_rank == 0:
+            logging.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
+            dataset.save_to_disk(prepared_ds_path)
+
+    if cfg.max_packed_sequence_len is not None:
         constant_len_dataset = ConstantLengthDataset(
             tokenizer,
-            datasets,
+            [dataset],
             seq_length=max_packed_sequence_len,
         )
-        logging.info("merging, packing, shuffling, and splitting master dataset")
-        dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split(
-            test_size=cfg.val_set_size, shuffle=True, seed=42
-        )
+        logging.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
+        dataset = Dataset.from_list([_ for _ in constant_len_dataset])
 
-        if cfg.local_rank == 0:
-            logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}")
-            dataset.save_to_disk(prepared_ds_path)
+    if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
+        logging.info(f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards")
+        dataset = dataset.shard(num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx)
 
+    dataset = dataset.train_test_split(
+        test_size=cfg.val_set_size, shuffle=False
+    )
     train_dataset = dataset["train"]
     eval_dataset = dataset["test"]
 
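The new fallback branch is worth spelling out on its own: when the repo cannot be streamed directly, the named file is fetched from the dataset repo and loaded with the generic JSON builder. A hedged sketch using the same dataset referenced in the configs above:

from datasets import load_dataset
from huggingface_hub import hf_hub_download

# Sketch of the fallback path: download one file from a hub dataset repo,
# then load it as JSON, as the else branch above does.
fp = hf_hub_download(
    repo_id="anon8231489123/ShareGPT_Vicuna_unfiltered",
    repo_type="dataset",
    filename="ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json",
)
ds = load_dataset("json", data_files=fp, streaming=True, split=None)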
 
src/axolotl/utils/models.py CHANGED
@@ -7,11 +7,16 @@ import torch
 import transformers
 from transformers import (
     AutoModelForCausalLM,
-    LlamaForCausalLM,
-    LlamaTokenizer,
     AutoTokenizer,
     PreTrainedModel,
 )
+try:
+    from transformers import (
+        LlamaForCausalLM,
+        LlamaTokenizer,
+    )
+except:
+    logging.warning("This version of transformers does not support Llama. Consider upgrading.")
 
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
@@ -70,7 +75,7 @@ def load_model(
         snapshot_download_kwargs = {}
         if cfg.base_model_ignore_patterns:
             snapshot_download_kwargs["ignore_patterns"] = cfg.base_model_ignore_patterns
-        cache_model_path = Path(snapshot_download(base_model, ** snapshot_download_kwargs))
+        cache_model_path = Path(snapshot_download(base_model, **snapshot_download_kwargs))
         files = (
             list(cache_model_path.glob("*.pt"))
             + list(cache_model_path.glob("*.safetensors"))
@@ -95,15 +100,29 @@ def load_model(
            else True,
        )
        load_in_8bit = False
-    elif is_llama_derived_model:
-        model = LlamaForCausalLM.from_pretrained(
+    elif is_llama_derived_model and "LlamaForCausalLM" in globals():
+        if not cfg.load_in_8bit:
+            model = LlamaForCausalLM.from_pretrained(
+                base_model,
+                device_map=cfg.device_map,
+            )
+        else:
+            model = LlamaForCausalLM.from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit,
+                torch_dtype=torch_dtype,
+                device_map=cfg.device_map,
+            )
+
+    elif model_type:
+        model = getattr(transformers, model_type).from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit,
             torch_dtype=torch_dtype,
             device_map=cfg.device_map,
         )
     else:
-        model = getattr(transformers, model_type).from_pretrained(
+        model = AutoModelForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=cfg.load_in_8bit,
             torch_dtype=torch_dtype,
@@ -123,7 +142,7 @@ def load_model(
 
     if not tokenizer:
         try:
-            if is_llama_derived_model:
+            if is_llama_derived_model and "LlamaTokenizer" in globals():
                 tokenizer = LlamaTokenizer.from_pretrained(model)
             else:
                 tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
@@ -142,13 +161,17 @@ def load_model(
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
+    if cfg.special_tokens:
+        for k, v in cfg.special_tokens.items():
+            setattr(tokenizer, k, v)
+
     if load_in_8bit and not cfg.load_4bit:
         logging.info("converting model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
 
     model, lora_config = load_adapter(model, cfg, adapter)
 
-    if cfg.ddp:
+    if cfg.ddp and not load_in_8bit:
         model.to(f"cuda:{cfg.local_rank}")
 
     if cfg.load_4bit:
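The pattern used here, an optional import plus a globals() check before any Llama-specific branch, looks like this in isolation (a sketch; ImportError is the assumed failure mode, whereas the commit itself uses a bare except):

import logging

# Guarded import: older transformers releases do not ship the Llama classes,
# so fall back gracefully and check globals() before using them.
try:
    from transformers import LlamaForCausalLM, LlamaTokenizer
except ImportError:
    logging.warning("This version of transformers does not support Llama. Consider upgrading.")

llama_available = "LlamaForCausalLM" in globals() and "LlamaTokenizer" in globals()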
src/axolotl/utils/trainer.py CHANGED
@@ -1,5 +1,9 @@
 import math
+import os
+from pathlib import Path
+
 import bitsandbytes as bnb
+import torch.cuda
 import transformers
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
@@ -12,7 +16,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
     warmup_steps = cfg.warmup_steps if cfg.warmup_steps else min(int(0.03 * total_num_steps), 100)
-    logging_steps = max(min(int(0.005 * total_num_steps), 10), 1)
+    logging_steps = cfg.logging_steps if cfg.logging_steps else max(min(int(0.005 * total_num_steps), 10), 1)
     save_steps = eval_steps = cfg.save_steps if cfg.save_steps else min(int(0.05 * total_num_steps), 200)
 
     training_arguments_kwargs = {}
@@ -26,6 +30,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.gradient_checkpointing is not None:
         training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
 
+    # deepspeed
+    if os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true" and torch.cuda.device_count() > 1:
+        if cfg.deepspeed:
+            training_arguments_kwargs["deepspeed"] = cfg.deepspeed
+        else:
+            # make a guess here
+            # TODO search Path("./") for one
+            training_arguments_kwargs["deepspeed"] = "./ds_config.json"
+
     training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
         gradient_accumulation_steps=cfg.gradient_accumulation_steps,
@@ -37,7 +50,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         save_steps=save_steps,
         output_dir=cfg.output_dir,
         save_total_limit=3,
-        load_best_model_at_end=True if cfg.val_set_size > 0 else False,
+        load_best_model_at_end=True if cfg.val_set_size > 0 and save_steps % eval_steps == 0 else False,
         ddp_find_unused_parameters=False if cfg.ddp else None,
         group_by_length=cfg.group_by_length,
         report_to="wandb" if cfg.use_wandb else None,
@@ -47,7 +60,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
 
     trainer_kwargs = {}
 
-    if cfg.load_in_8bit and not cfg.load_4bit:
+    if cfg.optimizer == "adam8bit" and not cfg.load_4bit and not "deepspeed" in training_arguments_kwargs:
         decay_parameters = get_parameter_names(model, [nn.LayerNorm])
         decay_parameters = [name for name in decay_parameters if "bias" not in name]
         optimizer_grouped_parameters = [
@@ -94,13 +107,22 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         )
         trainer_kwargs["callbacks"] = [early_stop_cb]
 
+    data_collator_kwargs = {
+        "padding": True,
+    }
+    if cfg.collator_pad_to_longest:
+        data_collator_kwargs["padding"] = "longest"
+    else:
+        data_collator_kwargs["pad_to_multiple_of"] = 8
     trainer = transformers.Trainer(
         model=model,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         args=training_args,
         data_collator=transformers.DataCollatorForSeq2Seq(
-            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
+            tokenizer,
+            return_tensors="pt",
+            **data_collator_kwargs,
         ),
         **trainer_kwargs,
     )
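Pulled out of the Trainer construction, the new collator wiring amounts to the following; the tokenizer and flag are stand-ins for what setup_trainer actually receives:

import transformers

# Stand-ins for cfg.collator_pad_to_longest and the tokenizer passed in.
collator_pad_to_longest = False
tokenizer = transformers.AutoTokenizer.from_pretrained("huggyllama/llama-13b")

data_collator_kwargs = {"padding": True}
if collator_pad_to_longest:
    data_collator_kwargs["padding"] = "longest"    # pad each batch to its longest sample
else:
    data_collator_kwargs["pad_to_multiple_of"] = 8  # previous fixed behaviour

collator = transformers.DataCollatorForSeq2Seq(
    tokenizer, return_tensors="pt", **data_collator_kwargs
)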