experimental expansion of ctx len
Browse files- scripts/finetune.py +26 -18
- src/axolotl/utils/data.py +31 -1
scripts/finetune.py
CHANGED
@@ -6,22 +6,20 @@ import os
|
|
6 |
import random
|
7 |
import signal
|
8 |
import sys
|
9 |
-
from functools import partial
|
10 |
from pathlib import Path
|
11 |
from typing import Any, Dict, List, Optional, Union
|
12 |
|
13 |
import fire
|
14 |
import torch
|
15 |
import yaml
|
16 |
-
from transformers import GenerationConfig, TextStreamer
|
17 |
-
|
18 |
-
from axolotl.utils.data import load_prepare_datasets
|
19 |
-
from axolotl.utils.dict import DictDefault
|
20 |
-
from axolotl.utils.models import load_model, load_tokenizer
|
21 |
|
22 |
# add src to the pythonpath so we don't need to pip install this
|
23 |
from optimum.bettertransformer import BetterTransformer
|
|
|
24 |
|
|
|
|
|
|
|
25 |
from axolotl.utils.tokenization import check_dataset_labels
|
26 |
from axolotl.utils.trainer import setup_trainer
|
27 |
from axolotl.utils.validation import validate_config
|
@@ -204,9 +202,19 @@ def train(
|
|
204 |
if check_not_in(
|
205 |
["inference", "shard", "merge_lora"], kwargs
|
206 |
): # don't need to load dataset for these
|
207 |
-
|
208 |
-
|
209 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
|
211 |
if cfg.debug or "debug" in kwargs:
|
212 |
logging.info("check_dataset_labels...")
|
@@ -256,7 +264,7 @@ def train(
|
|
256 |
logging.info("check_dataset_labels...")
|
257 |
check_dataset_labels(
|
258 |
train_dataset.select(
|
259 |
-
[random.randrange(0, len(train_dataset) - 1) for i in range(5)]
|
260 |
),
|
261 |
tokenizer,
|
262 |
)
|
@@ -265,10 +273,7 @@ def train(
|
|
265 |
logging.info("Finished preparing dataset. Exiting...")
|
266 |
return
|
267 |
|
268 |
-
|
269 |
-
model.train()
|
270 |
-
except:
|
271 |
-
pass
|
272 |
|
273 |
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
|
274 |
|
@@ -285,14 +290,15 @@ def train(
|
|
285 |
|
286 |
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
287 |
if cfg.local_rank == 0:
|
288 |
-
|
|
|
289 |
if cfg.flash_optimum:
|
290 |
model = BetterTransformer.reverse(model)
|
291 |
model.save_pretrained(cfg.output_dir)
|
292 |
sys.exit(0)
|
|
|
293 |
signal.signal(
|
294 |
-
signal.SIGINT,
|
295 |
-
lambda signum, frame: terminate_handler(signum, frame, model)
|
296 |
)
|
297 |
|
298 |
logging.info("Starting trainer...")
|
@@ -316,7 +322,9 @@ def train(
|
|
316 |
if not Path(cfg.output_dir).is_dir():
|
317 |
os.makedirs(cfg.output_dir, exist_ok=True)
|
318 |
if cfg.flash_optimum:
|
319 |
-
with torch.backends.cuda.sdp_kernel(
|
|
|
|
|
320 |
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
321 |
else:
|
322 |
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
|
|
6 |
import random
|
7 |
import signal
|
8 |
import sys
|
|
|
9 |
from pathlib import Path
|
10 |
from typing import Any, Dict, List, Optional, Union
|
11 |
|
12 |
import fire
|
13 |
import torch
|
14 |
import yaml
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
# add src to the pythonpath so we don't need to pip install this
|
17 |
from optimum.bettertransformer import BetterTransformer
|
18 |
+
from transformers import GenerationConfig, TextStreamer
|
19 |
|
20 |
+
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
|
21 |
+
from axolotl.utils.dict import DictDefault
|
22 |
+
from axolotl.utils.models import load_model, load_tokenizer
|
23 |
from axolotl.utils.tokenization import check_dataset_labels
|
24 |
from axolotl.utils.trainer import setup_trainer
|
25 |
from axolotl.utils.validation import validate_config
|
|
|
202 |
if check_not_in(
|
203 |
["inference", "shard", "merge_lora"], kwargs
|
204 |
): # don't need to load dataset for these
|
205 |
+
if not cfg.pretraining_dataset:
|
206 |
+
train_dataset, eval_dataset = load_prepare_datasets(
|
207 |
+
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
208 |
+
)
|
209 |
+
else:
|
210 |
+
if cfg.pretraining_dataset is True:
|
211 |
+
pretraining_dataset = "togethercomputer/RedPajama-Data-1T"
|
212 |
+
else:
|
213 |
+
pretraining_dataset = cfg.pretraining_dataset
|
214 |
+
train_dataset = load_pretraining_dataset(
|
215 |
+
pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
|
216 |
+
)
|
217 |
+
eval_dataset = None
|
218 |
|
219 |
if cfg.debug or "debug" in kwargs:
|
220 |
logging.info("check_dataset_labels...")
|
|
|
264 |
logging.info("check_dataset_labels...")
|
265 |
check_dataset_labels(
|
266 |
train_dataset.select(
|
267 |
+
[random.randrange(0, len(train_dataset) - 1) for i in range(5)] # nosec
|
268 |
),
|
269 |
tokenizer,
|
270 |
)
|
|
|
273 |
logging.info("Finished preparing dataset. Exiting...")
|
274 |
return
|
275 |
|
276 |
+
model.train()
|
|
|
|
|
|
|
277 |
|
278 |
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
|
279 |
|
|
|
290 |
|
291 |
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
292 |
if cfg.local_rank == 0:
|
293 |
+
|
294 |
+
def terminate_handler(_, __, model):
|
295 |
if cfg.flash_optimum:
|
296 |
model = BetterTransformer.reverse(model)
|
297 |
model.save_pretrained(cfg.output_dir)
|
298 |
sys.exit(0)
|
299 |
+
|
300 |
signal.signal(
|
301 |
+
signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
|
|
|
302 |
)
|
303 |
|
304 |
logging.info("Starting trainer...")
|
|
|
322 |
if not Path(cfg.output_dir).is_dir():
|
323 |
os.makedirs(cfg.output_dir, exist_ok=True)
|
324 |
if cfg.flash_optimum:
|
325 |
+
with torch.backends.cuda.sdp_kernel(
|
326 |
+
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
327 |
+
):
|
328 |
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
329 |
else:
|
330 |
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
src/axolotl/utils/data.py
CHANGED
@@ -5,7 +5,8 @@ from hashlib import md5
|
|
5 |
from pathlib import Path
|
6 |
from typing import List, Tuple, Union
|
7 |
|
8 |
-
|
|
|
9 |
from huggingface_hub import hf_hub_download
|
10 |
from transformers import PreTrainedTokenizerBase
|
11 |
|
@@ -392,3 +393,32 @@ def load_prepare_datasets(
|
|
392 |
eval_dataset = dataset["test"]
|
393 |
|
394 |
return train_dataset, eval_dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
from pathlib import Path
|
6 |
from typing import List, Tuple, Union
|
7 |
|
8 |
+
import torch
|
9 |
+
from datasets import Dataset, DatasetDict, IterableDataset, load_dataset, load_from_disk
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
from transformers import PreTrainedTokenizerBase
|
12 |
|
|
|
393 |
eval_dataset = dataset["test"]
|
394 |
|
395 |
return train_dataset, eval_dataset
|
396 |
+
|
397 |
+
|
398 |
+
class PretrainingDatasetWrapper(IterableDataset):
|
399 |
+
"""
|
400 |
+
Wrapper for pretraining dataset that avoids loading the dataset into memory
|
401 |
+
"""
|
402 |
+
|
403 |
+
def __init__(self, tokenizer, dataset_path, max_tokens=2048):
|
404 |
+
self.tokenizer = tokenizer
|
405 |
+
self.dataset_path = dataset_path
|
406 |
+
self.max_tokens = max_tokens
|
407 |
+
|
408 |
+
def __iter__(self):
|
409 |
+
buffer = []
|
410 |
+
for sample in load_dataset(
|
411 |
+
self.dataset_path,
|
412 |
+
name="all",
|
413 |
+
split="train",
|
414 |
+
streaming=True,
|
415 |
+
).shuffle(buffer_size=10000):
|
416 |
+
buffer += self.tokenizer(sample["text"])["input_ids"]
|
417 |
+
buffer += [self.tokenizer.eos_token_id]
|
418 |
+
while len(buffer) > self.max_tokens:
|
419 |
+
yield torch.tensor(buffer[: self.max_tokens])
|
420 |
+
buffer = buffer[self.max_tokens :]
|
421 |
+
|
422 |
+
|
423 |
+
def load_pretraining_dataset(path, tokenizer, max_tokens=2048):
|
424 |
+
return PretrainingDatasetWrapper(tokenizer, path, max_tokens=max_tokens)
|