Better performance on short context lengths

#1
by NilanE - opened

Nice seeing the work you're doing with my dataset!
I recommend chunking the chapters to fit within a reasonable context size (2048-4096 tokens). In my testing, using RoPE to extend the context gave poor results and also took more VRAM.

from transformers import AutoTokenizer
import jsonlines
import os

tokenizer = AutoTokenizer.from_pretrained("NilanE/tinyllama-relora-merge")

max_seq_len = 2048 # max context length

prompt = "Translate this from Japanese to English:\n### JAPANESE:\n\n### ENGLISH:\n</s>" # insert SFT prompt to add to token count

input_file_path = "dataset.jsonl"

root, ext = os.path.splitext(input_file_path)  # handles paths with extra dots or no extension
output_file_path = root + "-chunked" + ext
promptTokens = len(tokenizer.tokenize(prompt))

# tolerance: keep a small safety margin below the context limit
max_seq_len -= 10

skippedDocs = 0
longLines = 0  # counted across the whole dataset, not reset per document


if os.path.exists(output_file_path):
    os.remove(output_file_path)

with jsonlines.open(input_file_path) as reader, jsonlines.open(output_file_path, 'a') as writer:
    for entry in reader:
        src_lines = entry['src'].strip().split('\n')
        trg_lines = entry['trg'].strip().split('\n')

        out_src = []
        out_trg = []
        tokenCount = 0
        lastTokenCount = 0

        try:
            for x in range(len(src_lines)):
                out_src.append(src_lines[x])
                out_trg.append(trg_lines[x])
                out_src_string = "\n".join(out_src)
                trg_src_string = "\n".join(out_trg)
                tokenCount = len(tokenizer.tokenize(out_src_string.strip() + trg_src_string.strip())) + promptTokens
                if tokenCount-lastTokenCount < max_seq_len-1: # avoid lines > max line length    
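                    if tokenCount > max_seq_len-1:
                        # chunk is full: write it out and carry the newest line over to the next chunk
                        src_end = out_src.pop()
                        trg_end = out_trg.pop()
                        out_src_string = "\n".join(out_src)
                        trg_src_string = "\n".join(out_trg)
                        data = {
                            'src' : out_src_string.strip(),
                            'trg' : trg_src_string.strip()
                        }
                        writer.write(data)
                        out_src = [src_end]
                        out_trg = [trg_end]
                        # recount so the next long-line check compares against the new chunk
                        tokenCount = len(tokenizer.tokenize(src_end.strip() + trg_end.strip())) + promptTokens
                        if x+1 == len(src_lines):
                            # last line of the document: write the carried-over line so it is not dropped
                            writer.write({'src' : src_end.strip(), 'trg' : trg_end.strip()})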
                    elif x+1 == len(src_lines): #and len(out_src) > 2:
                        data = {
                            'src' : out_src_string.strip(),
                            'trg' : trg_src_string.strip()
                        }
                        writer.write(data)
                else:
                    # remove offending line > max_seq_len
                    out_src.pop()
                    out_trg.pop()
                    out_src_string = "\n".join(out_src)
                    trg_src_string = "\n".join(out_trg)
                    tokenCount = len(tokenizer.tokenize(out_src_string.strip() + trg_src_string.strip())) + promptTokens
                    longLines += 1

                lastTokenCount = tokenCount
        except Exception:  # e.g. IndexError when src and trg have different line counts
            skippedDocs += 1

print(f"LINES LONGER THAN MAX SEQUENCE LENGTH: {longLines}")
print(f"SKIPPED DOCS: {skippedDocs}")

Here's the script I use for chunking my dataset. I actually tested a 7B model (augmxnt/shisa-gamma-7b-v1) on a subset of the dataset a while ago and had amazing results, even without much training.

I also filtered out the partial chapter titles in some entries, and I'm currently pushing the fixed dataset to the hub!
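
After chunking, it can be worth checking that every record really fits the context budget. Below is a minimal sketch of such a check, not part of the script above: `find_oversized_chunks`, `read_jsonl`, and the `count_tokens` callable are hypothetical names introduced for illustration, and the demo uses a whitespace word count as a stand-in for a real tokenizer.

```python
import json
from typing import Callable, Dict, Iterable, List


def find_oversized_chunks(
    records: Iterable[Dict[str, str]],
    count_tokens: Callable[[str], int],
    max_seq_len: int = 2048,
    prompt_tokens: int = 0,
) -> List[int]:
    """Return indices of records whose src + trg (plus prompt) exceed max_seq_len."""
    oversized = []
    for i, rec in enumerate(records):
        total = count_tokens(rec["src"]) + count_tokens(rec["trg"]) + prompt_tokens
        if total > max_seq_len:
            oversized.append(i)
    return oversized


def read_jsonl(path: str) -> List[Dict[str, str]]:
    """Load a JSON Lines file into a list of dicts, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


if __name__ == "__main__":
    # Stand-in token counter (assumption): whitespace-separated words.
    def whitespace_count(text: str) -> int:
        return len(text.split())

    demo = [
        {"src": "a b c", "trg": "d e"},      # 5 "tokens", fits in 8
        {"src": "a " * 10, "trg": "b"},      # 11 "tokens", too long
    ]
    print(find_oversized_chunks(demo, whitespace_count, max_seq_len=8))  # prints [1]
```

To run it against the real output, something like `find_oversized_chunks(read_jsonl("dataset-chunked.jsonl"), lambda s: len(tokenizer.tokenize(s)), 2048, promptTokens)` should work with the `tokenizer` and `promptTokens` from the script. Counting `src` and `trg` separately can differ by a token or two from tokenizing the joined string, so treat borderline hits as approximate.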

Yeah, I'll try that script with the newer version of your dataset. I did train it first with 4k context as well, and that seemed to work better than training it with a longer context. I might try other Japanese models, but I chose this base model because it had pretty good scores in benchmarks.
