winglian commited on
Commit
a5d739b
·
1 Parent(s): 951facb

fixes w/ example for super basic lora starter

Browse files
examples/lora-alpaca-7b/config.yml ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_model: huggyllama/llama-7b
2
+ base_model_config: huggyllama/llama-7b
3
+ model_type: LlamaForCausalLM
4
+ tokenizer_type: LlamaTokenizer
5
+ load_in_8bit: true
6
+ load_in_4bit: false
7
+ strict: false
8
+ push_dataset_to_hub:
9
+ datasets:
10
+ - path: teknium/GPT4-LLM-Cleaned
11
+ type: alpaca
12
+ dataset_prepared_path: last_run_prepared
13
+ val_set_size: 0.02
14
+ adapter: lora
15
+ lora_model_dir:
16
+ sequence_len: 512
17
+ max_packed_sequence_len:
18
+ lora_r: 8
19
+ lora_alpha: 16
20
+ lora_dropout: 0.0
21
+ lora_target_modules:
22
+ - gate_proj
23
+ - down_proj
24
+ - up_proj
25
+ - q_proj
26
+ - v_proj
27
+ - k_proj
28
+ - o_proj
29
+ lora_fan_in_fan_out:
30
+ wandb_project:
31
+ wandb_watch:
32
+ wandb_run_id:
33
+ wandb_log_model:
34
+ output_dir: ./lora-out
35
+ batch_size: 4
36
+ micro_batch_size: 1
37
+ num_epochs: 4
38
+ optimizer: adamw_bnb_8bit
39
+ torchdistx_path:
40
+ lr_scheduler: cosine
41
+ learning_rate: 0.0002
42
+ train_on_inputs: false
43
+ group_by_length: false
44
+ bf16: false
45
+ fp16: true
46
+ tf32: true
47
+ gradient_checkpointing: true
48
+ early_stopping_patience:
49
+ resume_from_checkpoint:
50
+ local_rank:
51
+ logging_steps: 1
52
+ xformers_attention: true
53
+ flash_attention:
54
+ gptq_groupsize:
55
+ gptq_model_v1:
56
+ warmup_steps: 10
57
+ eval_steps: 50
58
+ save_steps:
59
+ debug:
60
+ deepspeed:
61
+ weight_decay: 0.0
62
+ fsdp:
63
+ fsdp_config:
64
+ special_tokens:
65
+ bos_token: "<s>"
66
+ eos_token: "</s>"
67
+ unk_token: "<unk>"
src/axolotl/prompters.py CHANGED
@@ -18,7 +18,7 @@ class AlpacaPrompter:
18
  prompt_style = None
19
 
20
  def __init__(self, prompt_style="instruct"):
21
- self.prompt_style = prompt_style
22
  self.match_prompt_style()
23
 
24
  def match_prompt_style(self):
 
18
  prompt_style = None
19
 
20
  def __init__(self, prompt_style="instruct"):
21
+ self.prompt_style = prompt_style if prompt_style else PromptStyle.instruct.value
22
  self.match_prompt_style()
23
 
24
  def match_prompt_style(self):
src/axolotl/utils/data.py CHANGED
@@ -60,10 +60,12 @@ def load_tokenized_prepared_datasets(
60
  else Path(default_dataset_prepared_path) / ds_hash
61
  )
62
  dataset = None
 
63
  try:
64
  if cfg.push_dataset_to_hub:
 
65
  dataset = load_dataset(
66
- f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True
67
  )
68
  dataset = dataset["train"]
69
  except:
@@ -83,7 +85,7 @@ def load_tokenized_prepared_datasets(
83
  ds = None
84
  ds_from_hub = False
85
  try:
86
- load_dataset(d.path, streaming=True, use_auth_token=True)
87
  ds_from_hub = True
88
  except FileNotFoundError:
89
  pass
@@ -99,10 +101,10 @@ def load_tokenized_prepared_datasets(
99
  d.path,
100
  streaming=False,
101
  data_files=d.data_files,
102
- use_auth_token=True,
103
  )
104
  else:
105
- ds = load_dataset(d.path, streaming=False, use_auth_token=True)
106
  else:
107
  fp = hf_hub_download(
108
  repo_id=d.path, repo_type="dataset", filename=d.data_files
 
60
  else Path(default_dataset_prepared_path) / ds_hash
61
  )
62
  dataset = None
63
+ use_auth_token = False
64
  try:
65
  if cfg.push_dataset_to_hub:
66
+ use_auth_token = True
67
  dataset = load_dataset(
68
+ f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
69
  )
70
  dataset = dataset["train"]
71
  except:
 
85
  ds = None
86
  ds_from_hub = False
87
  try:
88
+ load_dataset(d.path, streaming=True, use_auth_token=use_auth_token)
89
  ds_from_hub = True
90
  except FileNotFoundError:
91
  pass
 
101
  d.path,
102
  streaming=False,
103
  data_files=d.data_files,
104
+ use_auth_token=use_auth_token,
105
  )
106
  else:
107
+ ds = load_dataset(d.path, streaming=False, use_auth_token=use_auth_token)
108
  else:
109
  fp = hf_hub_download(
110
  repo_id=d.path, repo_type="dataset", filename=d.data_files