Nanobit committed
Commit 37293dc · 1 Parent(s): 96e8378

Apply isort then black
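The change itself is mechanical: isort runs first so the import blocks are sorted and grouped, then black re-wraps the result to its line-length rules. As a hedged sketch only (the exact commands, flags, and config used for this commit are not recorded in the diff), the same two-step pass can be reproduced through the tools' Python APIs:

# A minimal sketch of the "isort then black" pass on a source string.
# Assumes isort>=5 and black are installed; since the commit's actual
# invocation is unknown, both tools run with default settings here.
import black
import isort

SRC = "from typing import Optional, Union\nimport sys\nimport os\n"

sorted_src = isort.code(SRC)  # step 1: sort and group imports
formatted = black.format_str(sorted_src, mode=black.Mode())  # step 2: reformat
print(formatted)

When the two tools are paired like this, isort is commonly configured with profile = "black" so its wrapping agrees with black's; nothing in this diff shows which configuration the repository uses.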
scripts/alpaca_json_to_jsonl.py CHANGED
@@ -2,23 +2,20 @@
 
 import os
 import sys
-
-from typing import Optional, Union
 from pathlib import Path
+from typing import Optional, Union
 
 import fire
 
-
 from axolotl.convert import (
     FileReader,
-    StdoutWriter,
     FileWriter,
     JsonlSerializer,
     JsonParser,
     JsonToJsonlConverter,
+    StdoutWriter,
 )
 
-
 # add src to the pythonpath so we don't need to pip install this
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
scripts/finetune.py CHANGED
@@ -7,20 +7,20 @@ import random
 import signal
 import sys
 from pathlib import Path
-from typing import Optional, List, Dict, Any, Union
+from typing import Any, Dict, List, Optional, Union
 
 import fire
 import torch
 import yaml
 
-# add src to the pythonpath so we don't need to pip install this
-from axolotl.utils.tokenization import check_dataset_labels
-from axolotl.utils.validation import validate_config
-from axolotl.utils.dict import DictDefault
-
 from axolotl.utils.data import load_prepare_datasets
+from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
+
+# add src to the pythonpath so we don't need to pip install this
+from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import setup_trainer
+from axolotl.utils.validation import validate_config
 from axolotl.utils.wandb import setup_wandb_env_vars
 
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -242,7 +242,10 @@ def train(
     if cfg.local_rank == 0:
         signal.signal(
             signal.SIGINT,
-            lambda signal, frame: (model.save_pretrained(cfg.output_dir), sys.exit(0)),
+            lambda signal, frame: (
+                model.save_pretrained(cfg.output_dir),
+                sys.exit(0),
+            ),
         )
 
     logging.info("Starting trainer...")
@@ -255,7 +258,8 @@
         ]
         if len(possible_checkpoints) > 0:
             sorted_paths = sorted(
-                possible_checkpoints, key=lambda path: int(path.split("-")[-1])
+                possible_checkpoints,
+                key=lambda path: int(path.split("-")[-1]),
             )
             resume_from_checkpoint = sorted_paths[-1]
             logging.info(
setup.py CHANGED
@@ -1,6 +1,6 @@
 """setup.py for axolotl"""
 
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
 
 install_requires = []
 with open("./requirements.txt", encoding="utf-8") as requirements_file:
src/axolotl/datasets.py CHANGED
@@ -5,8 +5,8 @@ from typing import List
 
 import torch
 from datasets import IterableDataset
-from .prompt_tokenizers import PromptTokenizingStrategy, InvalidDataException
 
+from .prompt_tokenizers import InvalidDataException, PromptTokenizingStrategy
 
 # We want this to be a wrapper for an existing dataset that we have loaded
 # lets use the concept of middlewares to wrap each dataset, for example
@@ -114,7 +114,11 @@ class ConstantLengthDataset(IterableDataset):
                 logging.warning(
                     f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
                 )
-            buffer = {"input_ids": [], "attention_mask": [], "labels": []}
+            buffer = {
+                "input_ids": [],
+                "attention_mask": [],
+                "labels": [],
+            }
             buffer_len = 0
 
         if example:
src/axolotl/flash_attn.py CHANGED
@@ -5,14 +5,11 @@
 from typing import Optional, Tuple
 
 import torch
-
 import transformers
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
-
 from einops import rearrange
-
+from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-from flash_attn.bert_padding import unpad_input, pad_input
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
 
 def forward(
@@ -75,7 +72,11 @@ def forward(
         qkv = rearrange(qkv, "b s ... -> (b s) ...")
         max_s = q_len
         cu_q_lens = torch.arange(
-            0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+            0,
+            (bsz + 1) * q_len,
+            step=q_len,
+            dtype=torch.int32,
+            device=qkv.device,
         )
         output = flash_attn_unpadded_qkvpacked_func(
             qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
@@ -88,25 +89,44 @@ def forward(
         x = rearrange(qkv, "b s three h d -> b s (three h d)")
         x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
         x_unpad = rearrange(
-            x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
+            x_unpad,
+            "nnz (three h d) -> nnz three h d",
+            three=3,
+            h=nheads,
         )
         output_unpad = flash_attn_unpadded_qkvpacked_func(
-            x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+            x_unpad,
+            cu_q_lens,
+            max_s,
+            0.0,
+            softmax_scale=None,
+            causal=True,
        )
         output = rearrange(
             pad_input(
-                rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
+                rearrange(output_unpad, "nnz h d -> nnz (h d)"),
+                indices,
+                bsz,
+                q_len,
             ),
             "b s (h d) -> b s h d",
             h=nheads,
         )
-    return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
+    return (
+        self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
+        None,
+        None,
+    )
 
 
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
 # requires the attention mask to be the same as the key_padding_mask
 def _prepare_decoder_attention_mask(
-    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+    self,
+    attention_mask,
+    input_shape,
+    inputs_embeds,
+    past_key_values_length,
 ):  # pylint: disable=unused-argument
     # [bsz, seq_len]
     return attention_mask
src/axolotl/prompt_strategies/alpaca_chat.py CHANGED
@@ -1,6 +1,7 @@
 """Module containing the AlpacaQAPromptTokenizingStrategy class"""
 
 from typing import Tuple
+
 from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
     InstructionPromptTokenizingStrategy,
src/axolotl/prompt_strategies/creative_acr.py CHANGED
@@ -1,8 +1,9 @@
 """Module loading the CreativePromptTokenizingStrategy and similar classes"""
 
-from typing import Tuple, Union, Generator
+from typing import Generator, Tuple, Union
 
 import yaml
+
 from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
 
 
@@ -61,10 +62,14 @@ Answer: {answer}
 
     def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
         scores = yaml.dump(
-            prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
+            prompt["scores"],
+            default_flow_style=False,
+            Dumper=yaml.Dumper,
         )
         critiques = yaml.dump(
-            prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
+            prompt["critiques"],
+            default_flow_style=False,
+            Dumper=yaml.Dumper,
         )
         evaluation = scores + critiques
         question = prompt["instruction"]
@@ -97,10 +102,14 @@ Evaluation:
 
     def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
         scores = yaml.dump(
-            prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper
+            prompt["scores"],
+            default_flow_style=False,
+            Dumper=yaml.Dumper,
         )
         critiques = yaml.dump(
-            prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper
+            prompt["critiques"],
+            default_flow_style=False,
+            Dumper=yaml.Dumper,
         )
         evaluation = scores + critiques
         question = prompt["instruction"]
@@ -165,17 +174,26 @@ class CreativeRevisePrompter(CreativePrompterBase):
 
 def load_answer(tokenizer, cfg):
     return CreativeAnsweringPromptTokenizingStrategy(
-        CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+        CreativeAnswerPrompter(),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
    )
 
 
 def load_critique(tokenizer, cfg):
     return CreativeCritiquePromptTokenizingStrategy(
-        CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+        CreativeCritiquePrompter(),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
     )
 
 
 def load_revise(tokenizer, cfg):
     return CreativeRevisePromptTokenizingStrategy(
-        CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
+        CreativeRevisePrompter(),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
     )
src/axolotl/prompt_tokenizers.py CHANGED
@@ -347,7 +347,9 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                 part = part[0] + part[1] if not user_token else part[1]
                 # this is still the user query, we should
                 res = self._tokenize(
-                    part.strip(), add_eos_token=False, strip_bos_token=True
+                    part.strip(),
+                    add_eos_token=False,
+                    strip_bos_token=True,
                 )
                 if user_token:
                     res["input_ids"] = [user_token, *res["input_ids"]]
@@ -358,10 +360,15 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                 part = part[0] + part[1] if not assistant_token else part[1]
                 # this should be the assistent response, should end with an eos token
                 res = self._tokenize(
-                    part.strip(), add_eos_token=True, strip_bos_token=True
+                    part.strip(),
+                    add_eos_token=True,
+                    strip_bos_token=True,
                 )
                 if assistant_token:
-                    res["input_ids"] = [assistant_token, *res["input_ids"]]
+                    res["input_ids"] = [
+                        assistant_token,
+                        *res["input_ids"],
+                    ]
                 # not masked out from labels
                 labels = copy.deepcopy(res["input_ids"])
             else:
src/axolotl/prompters.py CHANGED
@@ -2,8 +2,8 @@
 
 import dataclasses
 import logging
-from enum import auto, Enum
-from typing import List, Optional, Union, Generator
+from enum import Enum, auto
+from typing import Generator, List, Optional, Union
 
 IGNORE_TOKEN_ID = -100
 
@@ -203,7 +203,9 @@ class ReflectAlpacaPrompter:
             res = self.prompt_no_input.format(instruction=instruction)
         if output and reflection and corrected:
             label = self.agent_label.format(
-                output=output, reflection=reflection, corrected=corrected
+                output=output,
+                reflection=reflection,
+                corrected=corrected,
             )
             res = f"{res}{label}"
         yield res
src/axolotl/utils/callbacks.py CHANGED
@@ -4,9 +4,9 @@ import os
 
 from transformers import (
     TrainerCallback,
-    TrainingArguments,
-    TrainerState,
     TrainerControl,
+    TrainerState,
+    TrainingArguments,
 )
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 
@@ -22,7 +22,8 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
         **kwargs,
     ):
         checkpoint_folder = os.path.join(
-            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
+            args.output_dir,
+            f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
         )
 
         peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
src/axolotl/utils/data.py CHANGED
@@ -5,38 +5,33 @@ from hashlib import md5
 from pathlib import Path
 from typing import List, Tuple, Union
 
-from datasets import (
-    load_from_disk,
-    load_dataset,
-    Dataset,
-    DatasetDict,
-)
+from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase
 
-from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
+from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
 from axolotl.prompt_strategies import load
 from axolotl.prompt_tokenizers import (
+    AlpacaMultipleChoicePromptTokenizingStrategy,
     AlpacaPromptTokenizingStrategy,
+    AlpacaReflectionPTStrategy,
+    CompletionPromptTokenizingStrategy,
     GPTeacherPromptTokenizingStrategy,
+    JeopardyPromptTokenizingStrategy,
     OpenAssistantPromptTokenizingStrategy,
-    AlpacaReflectionPTStrategy,
     ShareGPTPromptTokenizingStrategy,
-    JeopardyPromptTokenizingStrategy,
-    CompletionPromptTokenizingStrategy,
-    AlpacaMultipleChoicePromptTokenizingStrategy,
     SummarizeTLDRPromptTokenizingStrategy,
 )
 from axolotl.prompters import (
     AlpacaPrompter,
+    CompletionPrompter,
     GPTeacherPrompter,
-    ReflectAlpacaPrompter,
-    ShareGPTPrompter,
     JeopardyPrompter,
-    CompletionPrompter,
+    MultipleChoiceConcisePrompter,
     MultipleChoiceExplainPrompter,
+    ReflectAlpacaPrompter,
+    ShareGPTPrompter,
     SummarizeTLDRPrompter,
-    MultipleChoiceConcisePrompter,
 )
 
 
@@ -67,7 +62,8 @@ def load_tokenized_prepared_datasets(
     try:
         if cfg.push_dataset_to_hub:
             dataset = load_dataset(
-                f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=use_auth_token
+                f"{cfg.push_dataset_to_hub}/{ds_hash}",
+                use_auth_token=use_auth_token,
            )
             dataset = dataset["train"]
     except Exception:  # pylint: disable=broad-except
@@ -88,7 +84,11 @@ def load_tokenized_prepared_datasets(
         ds: Union[Dataset, DatasetDict] = None
         ds_from_hub = False
         try:
-            load_dataset(d.path, streaming=True, use_auth_token=use_auth_token)
+            load_dataset(
+                d.path,
+                streaming=True,
+                use_auth_token=use_auth_token,
+            )
             ds_from_hub = True
         except FileNotFoundError:
             pass
@@ -96,7 +96,10 @@ def load_tokenized_prepared_datasets(
         # prefer local dataset, even if hub exists
         if Path(d.path).exists():
             ds = load_dataset(
-                "json", data_files=d.path, streaming=False, split=None
+                "json",
+                data_files=d.path,
+                streaming=False,
+                split=None,
             )
         elif ds_from_hub:
             if d.data_files:
@@ -108,11 +111,15 @@ def load_tokenized_prepared_datasets(
                 )
             else:
                 ds = load_dataset(
-                    d.path, streaming=False, use_auth_token=use_auth_token
+                    d.path,
+                    streaming=False,
+                    use_auth_token=use_auth_token,
                 )
         else:
             fp = hf_hub_download(
-                repo_id=d.path, repo_type="dataset", filename=d.data_files
+                repo_id=d.path,
+                repo_type="dataset",
+                filename=d.data_files,
             )
             ds = load_dataset("json", data_files=fp, streaming=False, split=None)
         if not ds:
@@ -249,7 +256,9 @@ def load_tokenized_prepared_datasets(
 
 
 def load_prepare_datasets(
-    tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
+    tokenizer: PreTrainedTokenizerBase,
+    cfg,
+    default_dataset_prepared_path,
 ) -> Tuple[Dataset, Dataset]:
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
@@ -353,7 +362,8 @@ def load_prepare_datasets(
                     f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
                 )
                 dataset.push_to_hub(
-                    f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
+                    f"{cfg.push_dataset_to_hub}/{ds_hash}",
+                    private=True,
                 )
     else:
         dataset = load_tokenized_prepared_datasets(
@@ -365,7 +375,8 @@ def load_prepare_datasets(
             f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
         )
         dataset = dataset.shard(
-            num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx
+            num_shards=cfg.dataset_shard_num,
+            index=cfg.dataset_shard_idx,
         )
 
     dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
src/axolotl/utils/models.py CHANGED
@@ -5,23 +5,17 @@ import logging
 import math
 import os
 from pathlib import Path
-from typing import Optional, Tuple, TYPE_CHECKING  # noqa: F401
+from typing import TYPE_CHECKING, Optional, Tuple  # noqa: F401
 
 import bitsandbytes as bnb
 import torch
 import transformers
-from transformers import (  # noqa: F401
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    PreTrainedModel,
-    AutoConfig,
-    BitsAndBytesConfig,
-)
+from transformers import AutoModelForCausalLM  # noqa: F401
+from transformers import PreTrainedModel  # noqa: F401
+from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
 
 try:
-    from transformers import (
-        LlamaForCausalLM,
-    )
+    from transformers import LlamaForCausalLM
 except ImportError:
     logging.warning(
         "This version of transformers does not support Llama. Consider upgrading."
@@ -31,9 +25,10 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
 if TYPE_CHECKING:
     from peft import PeftConfig  # noqa: F401
-    from axolotl.utils.dict import DictDefault  # noqa: F401
     from transformers import PreTrainedTokenizer  # noqa: F401
 
+    from axolotl.utils.dict import DictDefault  # noqa: F401
+
 
 def load_tokenizer(
     base_model_config,
@@ -56,7 +51,10 @@ def load_tokenizer(
     logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
     logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
 
-    if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
+    if tokenizer.__class__.__name__ in [
+        "LlamaTokenizer",
+        "LlamaTokenizerFast",
+    ]:
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
 
     if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
@@ -312,11 +310,7 @@ def load_adapter(model, cfg, adapter):
 
 def load_llama_adapter(model, cfg):
     # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
-    from peft import (
-        AdaptionPromptConfig,
-        get_peft_model,
-        PeftModel,
-    )
+    from peft import AdaptionPromptConfig, PeftModel, get_peft_model
 
     peft_config = AdaptionPromptConfig(
         adapter_layers=cfg.peft_adapter.layers,  # layers (L)
@@ -361,11 +355,7 @@ def find_all_linear_names(bits, model):
 def load_lora(model, cfg):
     # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
 
-    from peft import (
-        LoraConfig,
-        get_peft_model,
-        PeftModel,
-    )
+    from peft import LoraConfig, PeftModel, get_peft_model
 
     lora_target_modules = list(cfg.lora_target_modules or [])
 
src/axolotl/utils/tokenization.py CHANGED
@@ -2,6 +2,7 @@
 
 
 import logging
+
 from termcolor import colored
 
 
src/axolotl/utils/trainer.py CHANGED
@@ -15,8 +15,8 @@ from torch.optim.lr_scheduler import OneCycleLR
 from transformers import EarlyStoppingCallback, Trainer
 from transformers.trainer_pt_utils import get_parameter_names
 
-from axolotl.utils.schedulers import InterpolatingLogScheduler
 from axolotl.utils.callbacks import SavePeftModelCallback
+from axolotl.utils.schedulers import InterpolatingLogScheduler
 
 
 class OneCycleLRSchedulerTrainer(Trainer):
@@ -29,7 +29,9 @@ class OneCycleLRSchedulerTrainer(Trainer):
         self.lr_scheduler = None
 
     def create_scheduler(
-        self, num_training_steps: int, optimizer: Optional[torch.optim.Optimizer] = None
+        self,
+        num_training_steps: int,
+        optimizer: Optional[torch.optim.Optimizer] = None,
     ):
         optimizer = self.optimizer if optimizer is None else optimizer
         num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
@@ -216,7 +218,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         )
         callbacks.append(early_stop_cb)
 
-    if cfg.local_rank == 0 and cfg.adapter in ["lora", "qlora"]:  # only save in rank 0
+    if cfg.local_rank == 0 and cfg.adapter in [
+        "lora",
+        "qlora",
+    ]:  # only save in rank 0
         callbacks.append(SavePeftModelCallback)
 
     data_collator_kwargs = {
tests/test_validation.py CHANGED
@@ -4,8 +4,8 @@ import unittest
 
 import pytest
 
-from axolotl.utils.validation import validate_config
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.validation import validate_config
 
 
 class ValidationTest(unittest.TestCase):