ehartford winglian Nanobit committed on
Commit
e0f1895
·
unverified ·
1 Parent(s): 8984bf1

add starcoder2 (#1349)

Browse files

* add starcoder2

* Apply suggestions from code review

Co-authored-by: NanoCode012 <[email protected]>

* chore: lint

* Apply suggestions from code review

Co-authored-by: NanoCode012 <[email protected]>

---------

Co-authored-by: Wing Lian <[email protected]>
Co-authored-by: NanoCode012 <[email protected]>

examples/starcoder2/qlora.yml ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_model: bigcode/starcoder2-3b
2
+
3
+ load_in_8bit: false
4
+ load_in_4bit: true
5
+ strict: false
6
+
7
+ datasets:
8
+ - path: mhenrichsen/alpaca_2k_test
9
+ type: alpaca
10
+
11
+
12
+ dataset_prepared_path:
13
+ val_set_size: 0.2
14
+ output_dir: ./qlora
15
+
16
+ adapter: qlora
17
+ lora_model_dir:
18
+
19
+ sequence_len: 8192
20
+ sample_packing: true
21
+ pad_to_sequence_len: true
22
+
23
+ lora_r: 32
24
+ lora_alpha: 16
25
+ lora_dropout: 0.05
26
+ lora_target_modules:
27
+ lora_target_linear: true
28
+ lora_fan_in_fan_out:
29
+
30
+ wandb_project:
31
+ wandb_entity:
32
+ wandb_watch:
33
+ wandb_run_id:
34
+ wandb_log_model:
35
+
36
+ gradient_accumulation_steps: 8
37
+ micro_batch_size: 2
38
+ num_epochs: 3
39
+ optimizer: adamw_bnb_8bit
40
+ lr_scheduler: cosine
41
+ learning_rate: 2e-5
42
+
43
+ train_on_inputs: false
44
+ group_by_length: false
45
+ bf16: auto
46
+ fp16: false
47
+ tf32: false
48
+
49
+ gradient_checkpointing: true
50
+ early_stopping_patience:
51
+ resume_from_checkpoint:
52
+ local_rank:
53
+ logging_steps: 1
54
+ xformers_attention:
55
+ flash_attention: true
56
+
57
+ warmup_steps: 20
58
+ evals_per_epoch: 4
59
+ eval_steps:
60
+ eval_table_size:
61
+ saves_per_epoch: 4
62
+ save_steps:
63
+ save_total_limit: 2
64
+ debug:
65
+ deepspeed:
66
+ weight_decay:
67
+ fsdp:
68
+ fsdp_config:
69
+ special_tokens:
src/axolotl/monkeypatch/multipack.py CHANGED
@@ -6,7 +6,14 @@ from transformers.integrations import is_deepspeed_zero3_enabled
6
  from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
7
  from axolotl.monkeypatch.utils import get_unpad_data
8
 
9
- SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "gemma"]
 
 
 
 
 
 
 
10
 
11
 
12
  def patch_for_multipack(model_type):
@@ -32,3 +39,7 @@ def patch_for_multipack(model_type):
32
  transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
33
  get_unpad_data
34
  )
 
 
 
 
 
6
  from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
7
  from axolotl.monkeypatch.utils import get_unpad_data
8
 
9
+ SUPPORTED_MULTIPACK_MODEL_TYPES = [
10
+ "mixtral",
11
+ "qwen2",
12
+ "falcon",
13
+ "phi",
14
+ "gemma",
15
+ "starcoder2",
16
+ ]
17
 
18
 
19
  def patch_for_multipack(model_type):
 
39
  transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
40
  get_unpad_data
41
  )
42
+ elif model_type == "starcoder2":
43
+ transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
44
+ get_unpad_data
45
+ )