Spaces:
Sleeping
Sleeping
"""
Download, preprocess and serve the TinyStories dataset as a DataLoader.
"""
import argparse
import glob
import json
import os
import random
import shutil
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List

import numpy as np
import requests
import torch
import torch.distributed as dist
from tqdm import tqdm

from tokenizer import Tokenizer
# Root directory (a relative path, so resolved against the current working
# directory) that holds the downloaded archive, the unpacked JSON shards and
# the pretokenized .bin files.
DATA_CACHE_DIR = "data"
def download_file(url: str, fname: str, chunk_size=1024, timeout=60.0):
    """Download ``url`` to the local path ``fname`` with a tqdm progress bar.

    The response body is streamed in ``chunk_size``-byte pieces into
    ``fname + ".part"``. That file is renamed to ``fname`` only after the
    transfer finished, so an interrupted download never leaves a truncated
    file under the final name.

    Args:
        url: Remote address to fetch.
        fname: Destination path on disk.
        chunk_size: Number of bytes requested per ``iter_content`` step.
        timeout: Seconds to wait for the server to connect/send data before
            giving up (a per-read limit, not a limit on the whole transfer).

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    tmp_fname = fname + ".part"
    with requests.get(url, stream=True, timeout=timeout) as resp:
        # never save an HTTP error page as if it were the requested file
        resp.raise_for_status()
        total = int(resp.headers.get("content-length", 0))
        with open(tmp_fname, "wb") as file, tqdm(
            desc=fname,
            total=total,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in resp.iter_content(chunk_size=chunk_size):
                size = file.write(data)
                bar.update(size)
    # publish under the final name only once every chunk is on disk
    os.replace(tmp_fname, fname)
def download():
    """Download and unpack the TinyStories dataset into ``DATA_CACHE_DIR``.

    The download and unpack steps are each skipped when their output already
    exists on disk. Afterwards, the number of JSON shards and the first entry
    of the first shard are printed as a sanity check.

    Raises:
        OSError: if the ``tar`` executable cannot be started.
        subprocess.CalledProcessError: if ``tar`` fails to unpack the archive.
        FileNotFoundError: if no ``*.json`` shards exist after unpacking.
    """
    os.makedirs(DATA_CACHE_DIR, exist_ok=True)

    # download the TinyStories dataset, unless it's already downloaded
    data_url = "https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz"
    data_filename = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data.tar.gz")
    if not os.path.exists(data_filename):
        print(f"Downloading {data_url} to {data_filename}...")
        download_file(data_url, data_filename)
    else:
        print(f"{data_filename} already exists, skipping download...")

    # unpack the tar.gz file into all the data shards (json files)
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    if not os.path.exists(data_dir):
        os.makedirs(data_dir, exist_ok=True)
        print(f"Unpacking {data_filename}...")
        try:
            # an argument list (no shell) avoids quoting problems with paths,
            # and check=True turns a failed extraction into an exception
            subprocess.run(["tar", "-xzf", data_filename, "-C", data_dir], check=True)
        except (OSError, subprocess.CalledProcessError):
            # drop the partially filled directory; otherwise the next run would
            # see that data_dir exists and skip unpacking for good
            shutil.rmtree(data_dir, ignore_errors=True)
            raise
    else:
        print(f"{data_dir} already exists, skipping unpacking...")

    # print a single example just for debugging and such
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
    if not shard_filenames:
        raise FileNotFoundError(f"no *.json shards found in {data_dir}")
    with open(shard_filenames[0], "r") as f:
        data = json.load(f)
    print("Download done.")
    print(f"Number of shards: {len(shard_filenames)}")
    print(f"Example story:\n{data[0]}")
def pretokenize():
    """Tokenize every ``*.json`` TinyStories shard into a sibling ``*.bin`` file.

    Each story's ``"story"`` text is stripped of surrounding whitespace and
    encoded with a BOS token (no EOS). The tokens are appended to one flat
    stream per shard, which is written to disk as raw ``uint16`` values.
    Shards are processed concurrently by a pool of 8 threads.

    Raises:
        FileNotFoundError: if no ``*.json`` shards are present (run the
            download stage first).
        ValueError: if a token id does not fit in ``uint16``.
        Any other exception raised while processing a shard is re-raised.
    """
    enc = Tokenizer()
    uint16_max = np.iinfo(np.uint16).max

    def process_shard(shard):
        # Encode one JSON shard and write its token stream next to it as .bin.
        with open(shard, "r") as f:
            data = json.load(f)
        all_tokens = []
        for example in tqdm(data):
            text = example["story"]
            text = text.strip()  # get rid of leading/trailing whitespace
            tokens = enc.encode(text, bos=True, eos=False)  # encode the text, use BOS
            all_tokens.extend(tokens)
        # the .bin format is uint16; refuse ids that would be silently truncated
        if all_tokens and (min(all_tokens) < 0 or max(all_tokens) > uint16_max):
            raise ValueError(f"{shard}: token ids do not fit in uint16")
        all_tokens = np.array(all_tokens, dtype=np.uint16)
        # swap only the file extension; str.replace would also rewrite a
        # ".json" that happens to appear in a directory name
        tokenized_filename = os.path.splitext(shard)[0] + ".bin"
        with open(tokenized_filename, "wb") as f:
            f.write(all_tokens.tobytes())
        print(f"Saved {tokenized_filename}")

    # iterate the shards and tokenize all of them
    data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
    shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.json")))
    if not shard_filenames:
        raise FileNotFoundError(f"no *.json shards found in {data_dir}")
    # process all the shards in a threadpool; consuming the map() iterator is
    # what re-raises a worker's exception instead of silently dropping it
    with ThreadPoolExecutor(max_workers=8) as executor:
        list(executor.map(process_shard, shard_filenames))
    print("Done.")
class PretokDataset(torch.utils.data.IterableDataset):
    """Loads pretokenized examples from disk and yields them as PyTorch tensors.

    Iterating yields an endless stream of ``(x, y)`` pairs. Both are int64
    tensors of length ``max_seq_len``, and ``y`` is ``x`` shifted left by one
    token. Shards, and the chunks inside each shard, are visited in random
    order. The RNG is seeded from the DataLoader worker id and the DDP rank.

    Args:
        split: ``"train"`` uses every shard except the first (in sorted
            order); any other value uses only the first shard.
        max_seq_len: Number of tokens in each yielded ``x`` / ``y``.
        data_dir: Directory holding the ``*.bin`` shards. ``None`` (the
            default) means ``<DATA_CACHE_DIR>/TinyStories_all_data``.

    Raises (on iteration):
        FileNotFoundError: if the selected split has no ``*.bin`` shards.
        ValueError: if a shard is too small to provide a single example.
    """

    def __init__(self, split, max_seq_len, data_dir=None):
        super().__init__()
        self.split = split
        self.max_seq_len = max_seq_len
        self.data_dir = data_dir

    def __iter__(self):
        # get worker info within a DataLoader
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else 0
        # get DDP rank info
        rank = dist.get_rank() if dist.is_initialized() else 0
        # combine the worker_id and worker_rank to create a unique seed for rng
        seed = 42 + worker_id + 1337 * rank
        rng = random.Random(seed)
        print(f"Created a PretokDataset with rng seed {seed}")
        data_dir = self.data_dir
        if data_dir is None:
            data_dir = os.path.join(DATA_CACHE_DIR, "TinyStories_all_data")
        shard_filenames = sorted(glob.glob(os.path.join(data_dir, "*.bin")))
        # train/test split. let's use only shard 0 for test split, rest train
        shard_filenames = shard_filenames[1:] if self.split == "train" else shard_filenames[:1]
        if not shard_filenames:
            # without shards the endless loop below would spin forever
            # without ever yielding
            raise FileNotFoundError(
                f"no .bin shards for split {self.split!r} in {data_dir}; run pretokenize first"
            )
        while True:
            rng.shuffle(shard_filenames)
            for shard in shard_filenames:
                # open the dataset for reading but keep it on disk with memmap
                m = np.memmap(shard, dtype=np.uint16, mode="r")
                num_batches = len(m) // self.max_seq_len
                num_batches -= 1  # drop the last partial batch
                # explicit check (not assert) so it still fires under python -O
                if num_batches <= 0:
                    raise ValueError(f"{shard}: this shard is way too small? investigate.")
                ixs = list(range(num_batches))
                rng.shuffle(ixs)
                for ix in ixs:
                    start = ix * self.max_seq_len
                    # one extra token so that y can be x shifted by one
                    end = start + self.max_seq_len + 1
                    # calling .astype will copy the data into a new numpy array, now in RAM
                    chunk = torch.from_numpy((m[start:end]).astype(np.int64))
                    x = chunk[:-1]
                    y = chunk[1:]
                    yield x, y
class Task:
    """Entry point for producing batched token streams for training/eval."""

    @staticmethod
    def iter_batches(split, batch_size, max_seq_len, device, num_workers=0):
        """Yield ``(x, y)`` token batches for ``split``, moved to ``device``.

        Batches come from a ``PretokDataset`` wrapped in a ``DataLoader`` with
        pinned host memory. They are copied to ``device`` with
        ``non_blocking=True``.

        Args:
            split: Dataset split passed to ``PretokDataset``.
            batch_size: Number of sequences per batch.
            max_seq_len: Tokens per sequence.
            device: Target device for the yielded tensors.
            num_workers: DataLoader worker processes (0 = load in-process).
        """
        ds = PretokDataset(split, max_seq_len)
        dl = torch.utils.data.DataLoader(
            ds, batch_size=batch_size, pin_memory=True, num_workers=num_workers
        )
        for x, y in dl:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            yield x, y
if __name__ == "__main__":
    # map each CLI stage to the function implementing it; the argparse choices
    # are derived from this table so every accepted stage has a handler
    fun = {
        "download": download,
        "pretokenize": pretokenize,
    }
    parser = argparse.ArgumentParser()
    parser.add_argument("stage", type=str, choices=list(fun))
    args = parser.parse_args()
    # depending on the stage call the appropriate function
    fun[args.stage]()