various bugfixes (#856)
* various bugfixes
use latest tinyllama release
check if val_set_size is empty first
update sdp and xformers llama patches for updated upstream transformers
fix system prompt when no input
calculate total and total supervised tokens even when not sample packing
* add fix for when eval size is estimated to be too small
* should be len 1 for dataset length
* add catchall kwargs
- examples/llama-2/tiny-llama.yml +1 -1
- src/axolotl/core/trainer_builder.py +4 -4
- src/axolotl/monkeypatch/llama_attn_hijack_sdp.py +2 -0
- src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +2 -0
- src/axolotl/prompters.py +1 -1
- src/axolotl/utils/samplers/multipack.py +12 -9
- src/axolotl/utils/trainer.py +23 -22
examples/llama-2/tiny-llama.yml
CHANGED
@@ -1,4 +1,4 @@
-base_model: PY007/TinyLlama-1.1B-step-
+base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T
 
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
src/axolotl/core/trainer_builder.py
CHANGED
@@ -543,16 +543,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 "dataloader_prefetch_factor"
             ] = self.cfg.dataloader_prefetch_factor
 
-        if self.cfg.eval_steps:
+        if self.cfg.val_set_size == 0:
+            # no eval set, so don't eval
+            training_arguments_kwargs["evaluation_strategy"] = "no"
+        elif self.cfg.eval_steps:
             training_arguments_kwargs["evaluation_strategy"] = "steps"
             training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
         elif self.cfg.evaluation_strategy:
            training_arguments_kwargs[
                 "evaluation_strategy"
             ] = self.cfg.evaluation_strategy
-        elif self.cfg.val_set_size == 0:
-            # no eval set, so don't eval
-            training_arguments_kwargs["evaluation_strategy"] = "no"
         else:
             # we have an eval set, but no steps defined, default to use epoch
             training_arguments_kwargs["evaluation_strategy"] = "epoch"
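With the branches reordered, an empty validation split now takes precedence over a configured eval_steps. A minimal standalone sketch of the new precedence (hypothetical helper, not part of this PR; names are illustrative):

# Sketch of the evaluation_strategy resolution order after this change.
def resolve_evaluation_strategy(val_set_size, eval_steps=None, evaluation_strategy=None):
    if val_set_size == 0:
        return "no"  # no eval set, so don't eval
    if eval_steps:
        return "steps"
    if evaluation_strategy:
        return evaluation_strategy
    return "epoch"  # eval set present, but no steps defined


print(resolve_evaluation_strategy(0, eval_steps=100))      # "no" (previously "steps")
print(resolve_evaluation_strategy(0.05, eval_steps=100))   # "steps"
print(resolve_evaluation_strategy(0.05))                   # "epoch"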
src/axolotl/monkeypatch/llama_attn_hijack_sdp.py
CHANGED
@@ -25,6 +25,8 @@ def sdp_attention_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
+    **kwargs,  # pylint: disable=unused-argument
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
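A minimal sketch of why the catch-all matters (hypothetical function names, not the actual patch): updated upstream transformers passes extra keyword arguments such as padding_mask into the attention forward, so a patched forward without **kwargs raises a TypeError.

def patched_forward_old(hidden_states, use_cache=False):
    return hidden_states


def patched_forward_new(hidden_states, use_cache=False, padding_mask=None, **kwargs):
    # padding_mask and any future keyword arguments are accepted and ignored
    return hidden_states


try:
    patched_forward_old("hs", use_cache=False, padding_mask=None)
except TypeError as err:
    print(f"old patch breaks: {err}")

print(patched_forward_new("hs", use_cache=False, padding_mask=None))  # "hs"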
src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
CHANGED
@@ -29,6 +29,8 @@ def xformers_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
+    **kwargs,  # pylint: disable=unused-argument
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
src/axolotl/prompters.py
CHANGED
@@ -75,7 +75,7 @@ class AlpacaPrompter(Prompter):
         else:
             res = (
                 self.system_format.format(system=self.system_no_input_prompt)
-                if self.
+                if self.system_no_input_prompt
                 else ""
             ) + self.turn_no_input_format.format(instruction=instruction)
             if output:
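A standalone illustration of the corrected guard (the format strings are made up, not axolotl's): the system block for a no-input turn is only emitted when system_no_input_prompt is actually set.

SYSTEM_FORMAT = "### System:\n{system}\n\n"
TURN_NO_INPUT_FORMAT = "### Instruction:\n{instruction}\n\n### Response:\n"


def build_no_input_prompt(instruction, system_no_input_prompt=""):
    return (
        SYSTEM_FORMAT.format(system=system_no_input_prompt)
        if system_no_input_prompt
        else ""
    ) + TURN_NO_INPUT_FORMAT.format(instruction=instruction)


print(build_no_input_prompt("Say hi"))                          # no system block
print(build_no_input_prompt("Say hi", "You are an assistant"))  # system block included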
src/axolotl/utils/samplers/multipack.py
CHANGED
@@ -181,13 +181,16 @@ class MultipackBatchSampler(BatchSampler):
         )
 
         # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
-        return (
-            world_size
-            * math.floor(
-                0.99
-                * lengths_sum_per_device
-                / self.packing_efficiency_estimate
-                // self.batch_max_len
-            )
-            - 1
+        return min(
+            1,
+            (
+                world_size
+                * math.floor(
+                    0.99
+                    * lengths_sum_per_device
+                    / self.packing_efficiency_estimate
+                    // self.batch_max_len
+                )
+                - 1
+            ),
         )
src/axolotl/utils/trainer.py
CHANGED
@@ -142,31 +142,32 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
 
 
 def calculate_total_num_steps(cfg, train_dataset):
+    if not cfg.total_num_tokens:
+        total_num_tokens = np.sum(
+            train_dataset.data.column("input_ids")
+            .to_pandas()
+            .apply(lambda x: len(x))  # pylint: disable=unnecessary-lambda
+            .values
+        )
+        LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
+        cfg.total_num_tokens = total_num_tokens
+
+    if not cfg.total_supervised_tokens:
+        total_supervised_tokens = (
+            train_dataset.data.column("labels")
+            .to_pandas()
+            .apply(lambda x: np.sum(np.array(x) != -100))
+            .sum()
+        )
+        LOG.debug(
+            f"`total_supervised_tokens: {total_supervised_tokens}`",
+            main_process_only=True,
+        )
+        cfg.total_supervised_tokens = total_supervised_tokens
+
     if cfg.sample_packing:
         # we have to drop anything longer then sequence len otherwise
         # flash attention with position ids fails
-        if not cfg.total_num_tokens:
-            total_num_tokens = np.sum(
-                train_dataset.data.column("input_ids")
-                .to_pandas()
-                .apply(lambda x: len(x))  # pylint: disable=unnecessary-lambda
-                .values
-            )
-            LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
-            cfg.total_num_tokens = total_num_tokens
-
-        if not cfg.total_supervised_tokens:
-            total_supervised_tokens = (
-                train_dataset.data.column("labels")
-                .to_pandas()
-                .apply(lambda x: np.sum(np.array(x) != -100))
-                .sum()
-            )
-            LOG.debug(
-                f"`total_supervised_tokens: {total_supervised_tokens}`",
-                main_process_only=True,
-            )
-            cfg.total_supervised_tokens = total_supervised_tokens
 
         if cfg.sample_packing_eff_est:
             total_num_steps = (
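A toy illustration (made-up data) of the two counts that now run regardless of sample_packing: total tokens from input_ids lengths, and supervised tokens as the label positions not masked with -100.

import numpy as np

input_ids = [[1, 2, 3, 4, 5], [1, 2, 3]]
labels = [[-100, -100, 42, 43, 44], [-100, 7, 8]]

total_num_tokens = np.sum([len(x) for x in input_ids])
total_supervised_tokens = sum(np.sum(np.array(x) != -100) for x in labels)

print(total_num_tokens)         # 8
print(total_supervised_tokens)  # 5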