|
"""Module containing data utilities""" |
|
|
|
import logging |
|
from hashlib import md5 |
|
from pathlib import Path |
|
from typing import List, Tuple, Union |
|
|
|
import torch |
|
from datasets import Dataset, DatasetDict, IterableDataset, load_dataset, load_from_disk |
|
from huggingface_hub import hf_hub_download |
|
from transformers import PreTrainedTokenizerBase |
|
|
|
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset |
|
from axolotl.prompt_strategies import load |
|
from axolotl.prompt_tokenizers import ( |
|
AlpacaMultipleChoicePromptTokenizingStrategy, |
|
AlpacaPromptTokenizingStrategy, |
|
AlpacaReflectionPTStrategy, |
|
CompletionPromptTokenizingStrategy, |
|
GPTeacherPromptTokenizingStrategy, |
|
JeopardyPromptTokenizingStrategy, |
|
OpenAssistantPromptTokenizingStrategy, |
|
ShareGPTPromptTokenizingStrategy, |
|
SummarizeTLDRPromptTokenizingStrategy, |
|
) |
|
from axolotl.prompters import ( |
|
AlpacaPrompter, |
|
CompletionPrompter, |
|
GPTeacherPrompter, |
|
JeopardyPrompter, |
|
MultipleChoiceConcisePrompter, |
|
MultipleChoiceExplainPrompter, |
|
ReflectAlpacaPrompter, |
|
ShareGPTPrompter, |
|
SummarizeTLDRPrompter, |
|
) |
|
|
|
|
|
def load_tokenized_prepared_datasets(
    tokenizer, cfg, default_dataset_prepared_path
) -> DatasetDict:
    """Load, tokenize, merge and shuffle every dataset listed in ``cfg.datasets``.

    The merged result is cached under an md5 key built from ``cfg.sequence_len``,
    each dataset's ``path:type:shards`` and the tokenizer class name. The cache is
    looked up in this order:

    1. the hub dataset ``{cfg.push_dataset_to_hub}/{ds_hash}`` (only when
       ``cfg.push_dataset_to_hub`` is set);
    2. the local directory ``<prepared path>/{ds_hash}``, where the prepared path is
       ``cfg.dataset_prepared_path`` or, if unset, ``default_dataset_prepared_path``.

    On a cache miss every raw dataset is loaded (``_load_raw_dataset``), wrapped in a
    ``TokenizedPromptDataset`` with the strategy named by its ``type``
    (``_build_prompt_tokenizing_strategy``), and all samples are merged into one
    ``Dataset`` shuffled with ``cfg.seed`` (42 when no seed is configured). Local
    rank 0 then saves it to disk and, when ``cfg.push_dataset_to_hub`` is set,
    pushes it to the hub as a private dataset.

    NOTE(review): the cache key ignores ``data_files``, the seed and
    ``train_on_inputs``; changing only those silently reuses a stale cache. Left
    unchanged because a new key would invalidate every existing cache.

    Args:
        tokenizer: tokenizer handed to the prompt tokenizing strategies.
        cfg: run configuration read via attribute access.
        default_dataset_prepared_path: cache root used when
            ``cfg.dataset_prepared_path`` is not set.

    Returns:
        The tokenized dataset. NOTE(review): despite the ``DatasetDict``
        annotation, the hub path selects the ``"train"`` split and the build path
        creates a single ``Dataset``; confirm and tighten the annotation.

    Raises:
        ValueError: if a raw dataset load yields nothing, or a dataset ``type``
            names no known prompt tokenization strategy.
    """
    tokenizer_name = tokenizer.__class__.__name__
    ds_hash = md5(
        (
            str(cfg.sequence_len)
            + "@"
            + "|".join(
                sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
            )
            + "|"
            + tokenizer_name
        ).encode("utf-8")
    ).hexdigest()
    prepared_ds_path = (
        Path(cfg.dataset_prepared_path or default_dataset_prepared_path) / ds_hash
    )
    use_auth_token = cfg.hf_use_auth_token

    dataset = None
    if cfg.push_dataset_to_hub:
        try:
            # Index inside the try so a hub dataset without a "train" split is
            # treated as a miss instead of leaking a DatasetDict to the caller.
            dataset = load_dataset(
                f"{cfg.push_dataset_to_hub}/{ds_hash}",
                use_auth_token=use_auth_token,
            )["train"]
        except Exception as err:  # pylint: disable=broad-except
            # Best effort: a missing/unreachable hub copy only means we fall back
            # to the disk cache or a rebuild, but record why instead of hiding it.
            logging.info(f"Prepared dataset not loaded from hub: {err}")

    if dataset:
        return dataset

    if any(prepared_ds_path.glob("*")):
        logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
        dataset = load_from_disk(str(prepared_ds_path))
        logging.info("Prepared dataset loaded from disk...")
        return dataset

    logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
    logging.info("Loading raw datasets...")

    # `is not None` so an explicit seed of 0 is honoured instead of replaced by 42.
    if cfg.seed is not None:
        seed = cfg.seed
    else:
        logging.info("No seed provided, using default seed of 42")
        seed = 42

    datasets = []
    for d in cfg.datasets:  # pylint: disable=invalid-name
        ds = _load_raw_dataset(d, seed, use_auth_token)
        ds_strategy = _build_prompt_tokenizing_strategy(d.type, tokenizer, cfg)
        datasets.append(TokenizedPromptDataset(ds_strategy, ds))

    logging.info("tokenizing, merging, and shuffling master dataset")

    samples: List[dict] = []
    for ds_wrapper in datasets:
        # extend() appends in place; `samples = samples + list(...)` re-copied the
        # whole accumulated list for every dataset.
        samples.extend(ds_wrapper)
    dataset = Dataset.from_list(samples).shuffle(seed=seed)

    if cfg.local_rank == 0:
        logging.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
        dataset.save_to_disk(prepared_ds_path)
        if cfg.push_dataset_to_hub:
            logging.info(
                f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}"
            )
            dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True)

    return dataset


def _load_raw_dataset(d, seed, use_auth_token):
    """Load one ``cfg.datasets`` entry ``d`` and return its training data.

    Resolution order:

    1. a local file/directory at ``d.path``, read with the ``json`` builder;
    2. ``d.path`` loaded by name via ``load_dataset`` (optionally restricted to
       ``d.data_files``), if a streaming probe of that name succeeded;
    3. ``d.data_files`` downloaded from the hub dataset repo ``d.path`` and read
       with the ``json`` builder.

    The ``"train"`` split is selected when present. When ``d.shards`` is set, the
    data is shuffled with ``seed`` and only shard 0 of ``d.shards`` is kept.

    Raises:
        ValueError: if loading produced an empty result.
    """
    ds_from_hub = False
    try:
        # Probe only: if a streaming load of `d.path` succeeds, the name is
        # treated as loadable directly by `load_dataset`.
        load_dataset(d.path, streaming=True, use_auth_token=use_auth_token)
        ds_from_hub = True
    except FileNotFoundError:
        pass

    # prefer local dataset, even if hub exists
    if Path(d.path).exists():
        ds = load_dataset("json", data_files=d.path, streaming=False, split=None)
    elif ds_from_hub:
        data_files_kwargs = {"data_files": d.data_files} if d.data_files else {}
        ds = load_dataset(
            d.path,
            streaming=False,
            use_auth_token=use_auth_token,
            **data_files_kwargs,
        )
    else:
        fp = hf_hub_download(
            repo_id=d.path,
            repo_type="dataset",
            filename=d.data_files,
        )
        ds = load_dataset("json", data_files=fp, streaming=False, split=None)
    if not ds:
        raise ValueError("unhandled dataset load")

    # isinstance check first: on a plain Dataset, `"train" in ds` has no
    # __contains__ and falls back to iterating (and comparing) every row.
    if isinstance(ds, DatasetDict) and "train" in ds:
        ds = ds["train"]

    # support for using a subset of the data
    if d.shards:
        ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
    return ds


def _build_prompt_tokenizing_strategy(d_type, tokenizer, cfg):
    """Return the prompt tokenizing strategy for a dataset ``type`` string.

    Strategies resolvable by ``axolotl.prompt_strategies.load`` take precedence.
    Otherwise ``d_type`` is read as ``base[:prompt_style]`` and ``base`` is looked
    up in the built-in table below.

    Raises:
        ValueError: if neither lookup knows ``d_type``.
    """
    d_type_split = d_type.split(":")
    d_base_type = d_type_split[0]
    d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None

    if ds_strategy := load(d_type, tokenizer, cfg):
        return ds_strategy

    # base type -> (tokenizing strategy class, prompter factory taking the style)
    builtin_strategies = {
        "alpaca": (AlpacaPromptTokenizingStrategy, AlpacaPrompter),
        "explainchoice": (
            AlpacaMultipleChoicePromptTokenizingStrategy,
            MultipleChoiceExplainPrompter,
        ),
        "concisechoice": (
            AlpacaMultipleChoicePromptTokenizingStrategy,
            MultipleChoiceConcisePrompter,
        ),
        "summarizetldr": (SummarizeTLDRPromptTokenizingStrategy, SummarizeTLDRPrompter),
        "jeopardy": (JeopardyPromptTokenizingStrategy, JeopardyPrompter),
        "oasst": (OpenAssistantPromptTokenizingStrategy, AlpacaPrompter),
        "gpteacher": (GPTeacherPromptTokenizingStrategy, GPTeacherPrompter),
        "reflection": (AlpacaReflectionPTStrategy, ReflectAlpacaPrompter),
        "sharegpt": (ShareGPTPromptTokenizingStrategy, ShareGPTPrompter),
        # CompletionPrompter takes no style; any ":style" suffix is ignored.
        "completion": (
            CompletionPromptTokenizingStrategy,
            lambda _prompt_style: CompletionPrompter(),
        ),
    }
    if d_base_type not in builtin_strategies:
        logging.error(f"unhandled prompt tokenization strategy: {d_type}")
        raise ValueError(f"unhandled prompt tokenization strategy: {d_type}")

    strategy_cls, make_prompter = builtin_strategies[d_base_type]
    return strategy_cls(
        make_prompter(d_prompt_style),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
|
|
|
|
|
def load_prepare_datasets(
    tokenizer: PreTrainedTokenizerBase,
    cfg,
    default_dataset_prepared_path,
) -> Tuple[Dataset, Union[Dataset, None]]:
    """Return ``(train_dataset, eval_dataset)`` ready for training.

    When ``cfg.max_packed_sequence_len`` is not None, the tokenized dataset is
    packed into samples of at most ``min(cfg.max_packed_sequence_len,
    cfg.sequence_len)`` tokens (a falsy value such as 0 falls back to
    ``cfg.sequence_len``), reusing a hub/disk cache when one exists. Otherwise the
    unpacked result of ``load_tokenized_prepared_datasets`` is used.

    If ``cfg.dataset_shard_num`` and ``cfg.dataset_shard_idx`` are set, only that
    shard is kept. The data is then split without shuffling by
    ``train_test_split(test_size=cfg.val_set_size)``.

    Args:
        tokenizer: tokenizer used for tokenization and packing.
        cfg: run configuration read via attribute access.
        default_dataset_prepared_path: cache root used when
            ``cfg.dataset_prepared_path`` is not set.

    Returns:
        ``(train_dataset, eval_dataset)``. ``eval_dataset`` is ``None`` when
        ``cfg.val_set_size == 0`` (``train_test_split`` rejects a zero test size,
        so no split is made).
    """
    max_packed_sequence_len = min(
        cfg.max_packed_sequence_len or cfg.sequence_len, cfg.sequence_len
    )

    if cfg.max_packed_sequence_len is not None:
        dataset = _load_or_pack_dataset(
            tokenizer, cfg, default_dataset_prepared_path, max_packed_sequence_len
        )
    else:
        dataset = load_tokenized_prepared_datasets(
            tokenizer, cfg, default_dataset_prepared_path
        )

    if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
        logging.info(
            f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards"
        )
        dataset = dataset.shard(
            num_shards=cfg.dataset_shard_num,
            index=cfg.dataset_shard_idx,
        )

    if cfg.val_set_size == 0:
        # No validation requested: use everything for training.
        return dataset, None

    split = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
    return split["train"], split["test"]


def _load_or_pack_dataset(
    tokenizer, cfg, default_dataset_prepared_path, max_packed_sequence_len
):
    """Return the packed dataset from the hub/disk cache, or build and cache it.

    The cache key covers ``cfg.sequence_len``, the effective packed length,
    ``cfg.seed`` (when truthy), each dataset's ``path:type:shards`` and the
    tokenizer class name. A disk-cache hit is pushed to the hub when
    ``cfg.push_dataset_to_hub`` is set; a fresh build is saved to disk (and pushed)
    by local rank 0.
    """
    tokenizer_name = tokenizer.__class__.__name__
    seed_suffix = f"@{str(cfg.seed)}" if cfg.seed else ""
    ds_hash = md5(
        (
            str(cfg.sequence_len)
            + "@"
            + str(max_packed_sequence_len)
            + seed_suffix
            + "|".join(
                sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
            )
            + "|"
            + tokenizer_name
        ).encode("utf-8")
    ).hexdigest()
    prepared_ds_path = (
        Path(cfg.dataset_prepared_path or default_dataset_prepared_path) / ds_hash
    )
    hub_ds_name = f"{cfg.push_dataset_to_hub}/{ds_hash}"

    dataset = None
    if cfg.push_dataset_to_hub:
        logging.info(f"Checking for packed prepared dataset from hub... {hub_ds_name}")
        try:
            dataset = load_dataset(
                hub_ds_name,
                use_auth_token=cfg.hf_use_auth_token,
            )["train"]
        except Exception as err:  # pylint: disable=broad-except
            # Best effort: fall back to the disk cache or a rebuild, but log why.
            logging.info(f"Packed prepared dataset not loaded from hub: {err}")

    if dataset:
        return dataset

    if any(prepared_ds_path.glob("*")):
        logging.info(
            f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
        )
        dataset = load_from_disk(str(prepared_ds_path))
        logging.info("Prepared packed dataset loaded from disk...")
        # Only rank 0 uploads, matching the build path below; otherwise every
        # process would push the same dataset.
        if cfg.push_dataset_to_hub and cfg.local_rank == 0:
            logging.info(
                f"Saving packed prepared dataset with push_to_hub... {hub_ds_name}"
            )
            dataset.push_to_hub(hub_ds_name, private=True)
        return dataset

    dataset = load_tokenized_prepared_datasets(
        tokenizer, cfg, default_dataset_prepared_path
    )
    if cfg.seed:
        dataset = dataset.shuffle(seed=cfg.seed)

    constant_len_dataset = ConstantLengthDataset(
        tokenizer,
        [dataset],
        seq_length=max_packed_sequence_len,
    )
    logging.info(f"packing master dataset to len: {max_packed_sequence_len}")

    # Filter while packing instead of materializing every packed sample into one
    # Dataset and then copying the survivors into a second one.
    dataset = Dataset.from_list(
        [
            sample
            for sample in constant_len_dataset
            if _is_valid_packed_sample(sample, cfg.sequence_len)
        ]
    )

    if cfg.local_rank == 0:
        logging.info(f"Saving packed prepared dataset to disk... {prepared_ds_path}")
        dataset.save_to_disk(prepared_ds_path)
        if cfg.push_dataset_to_hub:
            logging.info(
                f"Saving packed prepared dataset with push_to_hub... {hub_ds_name}"
            )
            dataset.push_to_hub(hub_ds_name, private=True)

    return dataset


def _is_valid_packed_sample(sample, sequence_len):
    """Return True if a packed sample is non-empty, fits in ``sequence_len``, and
    its ``attention_mask`` and ``labels`` match ``input_ids`` in length."""
    n_tokens = len(sample["input_ids"])
    # `<=`: a sample that exactly fills the context is valid; `<` dropped it.
    return (
        0 < n_tokens <= sequence_len
        and len(sample["attention_mask"]) == n_tokens
        and len(sample["labels"]) == n_tokens
    )
|
|
|
|
|
class PretrainingDatasetWrapper(IterableDataset):
    """
    Wrapper for a pretraining dataset that tokenizes and chunks it lazily.

    Each iteration loads the ``train`` split of ``dataset_path`` with
    ``load_dataset`` and shuffles it. It then tokenizes every sample's ``text``
    field, appending the tokenizer's ``eos_token_id`` after each sample, and yields
    consecutive chunks of exactly ``max_tokens`` token ids. Trailing ids that do
    not fill a whole chunk are dropped.

    NOTE(review): this subclasses ``datasets.IterableDataset`` without calling its
    ``__init__``; only ``__iter__`` is overridden, so other ``IterableDataset``
    methods may not work on instances. Confirm whether a
    ``torch.utils.data.IterableDataset`` base was intended.
    """

    def __init__(self, tokenizer, dataset_path, max_tokens=2048, seed=None):
        """
        Args:
            tokenizer: callable returning a mapping with ``input_ids``; its
                ``eos_token_id`` is appended after every sample.
            dataset_path: dataset name or path passed to ``load_dataset``.
            max_tokens: number of token ids in each yielded chunk.
            seed: seed for shuffling the split. ``None`` (default) keeps the
                unseeded shuffle.
        """
        self.tokenizer = tokenizer
        self.dataset_path = dataset_path
        self.max_tokens = max_tokens
        self.seed = seed

    def __iter__(self):
        buffer = []
        for sample in load_dataset(
            self.dataset_path,
        )["train"].shuffle(seed=self.seed):
            buffer += self.tokenizer(sample["text"])["input_ids"]
            buffer.append(self.tokenizer.eos_token_id)
            # `>=` so a chunk is emitted as soon as max_tokens ids are buffered;
            # with `>`, a buffer holding exactly max_tokens ids at the end of the
            # data was silently dropped.
            while len(buffer) >= self.max_tokens:
                input_ids = torch.tensor(buffer[: self.max_tokens])
                yield {
                    "input_ids": input_ids,
                    # every chunk is full, so no position is padding
                    "attention_mask": torch.ones(input_ids.size()),
                    # separate storage so in-place edits to labels (e.g. masking)
                    # cannot corrupt input_ids
                    "labels": input_ids.clone(),
                }
                buffer = buffer[self.max_tokens :]
|
|
|
|
|
def load_pretraining_dataset(path, tokenizer, max_tokens=2048):
    """Create a lazily tokenized pretraining dataset for ``path``.

    Thin factory around ``PretrainingDatasetWrapper``. Note that the argument
    order differs from the wrapper's constructor: ``path`` comes first here.

    Args:
        path: dataset name or path, forwarded as the wrapper's ``dataset_path``.
        tokenizer: tokenizer used to encode each sample's ``text`` field.
        max_tokens: number of token ids per yielded chunk.

    Returns:
        A ``PretrainingDatasetWrapper`` over ``path``.
    """
    wrapper = PretrainingDatasetWrapper(
        tokenizer=tokenizer,
        dataset_path=path,
        max_tokens=max_tokens,
    )
    return wrapper
|
|