winglian committed on
Commit
a12fb0a
·
unverified ·
1 Parent(s): a4329b1

Jeopardy bot! (#17)

Browse files

* support for jeopardy dataset

* commit the final config for jeopardy bot

configs/llama_7B_jeopardy.yml ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_model: huggyllama/llama-7b
2
+ base_model_config: huggyllama/llama-7b
3
+ model_type: LlamaForCausalLM
4
+ tokenizer_type: LlamaTokenizer
5
+ load_in_8bit: false
6
+ datasets:
7
+ - path: openaccess-ai-collective/jeopardy
8
+ type: jeopardy
9
+ dataset_prepared_path: last_run_prepared
10
+ val_set_size: 0.01
11
+ adapter:
12
+ lora_model_dir:
13
+ sequence_len: 2048
14
+ max_packed_sequence_len: 2048
15
+ lora_r: 8
16
+ lora_alpha: 16
17
+ lora_dropout: 0.05
18
+ lora_target_modules:
19
+ - q_proj
20
+ - v_proj
21
+ lora_fan_in_fan_out: false
22
+ wandb_project: jeopardy-bot-7b
23
+ wandb_watch:
24
+ wandb_run_id:
25
+ wandb_log_model: checkpoint
26
+ output_dir: ./jeopardy-bot-7b
27
+ batch_size: 4
28
+ micro_batch_size: 1
29
+ num_epochs: 2
30
+ optimizer: adamw_bnb_8bit
31
+ torchdistx_path:
32
+ lr_scheduler: cosine
33
+ learning_rate: 0.0000002
34
+ train_on_inputs: false
35
+ group_by_length: false
36
+ bf16: true
37
+ tf32: true
38
+ early_stopping_patience:
39
+ resume_from_checkpoint:
40
+ local_rank:
41
+ logging_steps: 5
42
+ xformers_attention: true
43
+ flash_attention:
44
+ gptq_groupsize:
45
+ gptq_model_v1:
46
+ warmup_steps: 20
47
+ eval_steps: 110
48
+ save_steps: 660
49
+ debug:
50
+ deepspeed:
51
+ weight_decay: 0.0001
52
+ fsdp:
53
+ fsdp_config:
54
+ special_tokens:
55
+ pad_token: "[PAD]"
56
+ bos_token: "<s>"
57
+ eos_token: "</s>"
58
+ unk_token: "<unk>"
src/axolotl/prompt_tokenizers.py CHANGED
@@ -89,6 +89,15 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
89
  )
90
 
91
 
 
 
 
 
 
 
 
 
 
92
  class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
93
  def parse_instruction_fields(self, prompt) -> (str, str, str):
94
  return (
 
89
  )
90
 
91
 
92
class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """Tokenizing strategy for Jeopardy-style records.

    Reads the ``question``, ``category`` and ``answer`` keys of a record.
    NOTE(review): the returned triple is presumably consumed by the base class
    as (instruction, input, response) -- confirm against
    ``InstructionPromptTokenizingStrategy``.
    """

    def parse_instruction_fields(self, prompt) -> (str, str, str):
        """Return ``(question, category, "what is " + answer)`` for *prompt*.

        Raises ``KeyError`` if any of the three keys is missing from *prompt*.
        """
        clue = prompt["question"]
        category = prompt["category"]
        # The answer is phrased the way a Jeopardy contestant would respond.
        response = "what is " + prompt["answer"]
        return clue, category, response
99
+
100
+
101
  class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
102
  def parse_instruction_fields(self, prompt) -> (str, str, str):
103
  return (
src/axolotl/prompters.py CHANGED
@@ -31,6 +31,10 @@ class AlpacaPrompter:
31
  return output.split(self.response_split)[1].strip()
32
 
33
 
 
 
 
 
34
  class GPTeacherPrompter(AlpacaPrompter):
35
  ...
36
 
 
31
  return output.split(self.response_split)[1].strip()
32
 
33
 
34
class JeopardyPrompter(AlpacaPrompter):
    """Alpaca-style prompter whose input template is worded for Jeopardy clues.

    Only ``prompt_input`` is overridden here; all other behavior is inherited
    from ``AlpacaPrompter``.
    """

    # Template with ``{instruction}`` (the clue) and ``{input}`` (the category)
    # placeholders. The wording previously read "answers tbe clue"; the typo is
    # fixed so it no longer ends up in every generated prompt.
    prompt_input = (
        "Below is a Jeopardy clue paired with input providing the category of the clue. "
        "Write a concise response that best answers the clue given the category.\n\n"
        "### Instruction:\n{instruction}\n\n"
        "### Input:\n{input}\n\n"
        "### Response:\n"
    )
36
+
37
+
38
  class GPTeacherPrompter(AlpacaPrompter):
39
  ...
40
 
src/axolotl/utils/data.py CHANGED
@@ -11,13 +11,13 @@ from axolotl.prompt_tokenizers import (
11
  GPTeacherPromptTokenizingStrategy,
12
  OpenAssistantPromptTokenizingStrategy,
13
  AlpacaReflectionPTStrategy,
14
- ShareGPTPromptTokenizingStrategy,
15
  )
16
  from axolotl.prompters import (
17
  AlpacaPrompter,
18
  GPTeacherPrompter,
19
  ReflectAlpacaPrompter,
20
- ShareGPTPrompter,
21
  )
22
 
23
 
@@ -82,6 +82,12 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
82
  )
83
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
84
  datasets.append(ds_wrapper)
 
 
 
 
 
 
85
  elif d.type == "oasst":
86
  ds_strategy = OpenAssistantPromptTokenizingStrategy(
87
  AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
 
11
  GPTeacherPromptTokenizingStrategy,
12
  OpenAssistantPromptTokenizingStrategy,
13
  AlpacaReflectionPTStrategy,
14
+ ShareGPTPromptTokenizingStrategy, JeopardyPromptTokenizingStrategy,
15
  )
16
  from axolotl.prompters import (
17
  AlpacaPrompter,
18
  GPTeacherPrompter,
19
  ReflectAlpacaPrompter,
20
+ ShareGPTPrompter, JeopardyPrompter,
21
  )
22
 
23
 
 
82
  )
83
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
84
  datasets.append(ds_wrapper)
85
+ if d.type == "jeopardy":
86
+ ds_strategy = JeopardyPromptTokenizingStrategy(
87
+ JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
88
+ )
89
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
90
+ datasets.append(ds_wrapper)
91
  elif d.type == "oasst":
92
  ds_strategy = OpenAssistantPromptTokenizingStrategy(
93
  AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len