# -*- coding: utf-8 -*-
# Transformers installation
# ! pip install transformers datasets
# To install from source instead of the last release, comment the command above and uncomment the following one.
# ! pip install git+https://github.com/huggingface/transformers.git
from datasets import load_dataset

# Load the custom Falcon CSV files. Each file is expected to contain a "Text"
# column, which the preprocessing step below tokenizes.
Falcon = load_dataset('csv', data_files={"train": 'FalconData_train2.csv', "validation": 'FalconData_validation2.csv'})
print('Dataset Loaded!')
# Alternative: load a single CSV and carve out a held-out split.
# Falcon = load_dataset("csv", data_files="FalconData.csv")
# Falcon = Falcon.train_test_split(test_size=0.10)
"""Then take a look at an example:"""
Falcon['train'][0]
Falcon['validation'][0]
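"""Printing the `DatasetDict` itself is an optional quick check of the split names and row counts:"""
print(Falcon)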
"""The next step is to load a DistilGPT2 tokenizer to process the `text` subfield:"""
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
# GPT-2 style tokenizers have no dedicated padding token, so reuse the end-of-sequence token.
tokenizer.pad_token = tokenizer.eos_token
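"""As an optional sanity check (the sample sentence below is only illustrative), tokenize and decode a string to confirm the tokenizer round-trips text:"""
sample = tokenizer("Hello from the Falcon fine-tuning script.")
print(sample["input_ids"])
print(tokenizer.decode(sample["input_ids"]))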
# `flatten` only expands nested columns; for a flat CSV schema it is a no-op.
Falcon = Falcon.flatten()
print(Falcon["train"][0])

def preprocess_function(examples):
    # In batched mode `examples["Text"]` is already a list of strings,
    # so it can be passed straight to the tokenizer.
    return tokenizer(examples["Text"])
tokenized_Falcon = Falcon.map(
    preprocess_function,
    batched=True,
    num_proc=4,
    remove_columns=Falcon["train"].column_names,
)
# Use the tokenizer's maximum context length as the block size (1024 tokens for DistilGPT2).
block_size = tokenizer.model_max_length
# block_size = 2048
def group_texts(examples):
    # Concatenate all texts in the batch.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the small remainder; we could add padding instead if the model supported it.
    # Customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split into chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    # For causal LM, the labels are the inputs themselves (the shift happens inside the model).
    result["labels"] = result["input_ids"].copy()
    return result
"""Apply the `group_texts` function over the entire dataset:"""
lm_dataset = tokenized_Falcon.map(group_texts, batched=True, num_proc=4)
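"""An optional check that grouping produced fixed-length blocks; for typical batches this should print `block_size`:"""
print(len(lm_dataset["train"][0]["input_ids"]))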
from transformers import DataCollatorForLanguageModeling

# With mlm=False this acts as a causal-LM collator: it pads each batch and builds
# `labels` from `input_ids`, setting padded positions to -100 so they are ignored by the loss.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
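"""To see what a training batch looks like (optional), apply the collator to a couple of grouped examples; each tensor should have shape (2, block_size):"""
demo_batch = data_collator([lm_dataset["train"][i] for i in range(2)])
print({k: v.shape for k, v in demo_batch.items()})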
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
import torch
model = AutoModelForCausalLM.from_pretrained("rwh/tiny8", torch_dtype=torch.bfloat16)
print('Model Loaded!')
model.to('cuda')
OutputDir = "C1ReadyModel"
training_args = TrainingArguments(
    output_dir=OutputDir,
    overwrite_output_dir=True,
    bf16=True,
    # evaluation_strategy="epoch",
    evaluation_strategy="steps",
    # learning_rate=3.25e-06,
    # learning_rate=2e-5,
    learning_rate=1e-5,
    weight_decay=0.01,
    # weight_decay=0.001,
    num_train_epochs=6,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # lr_scheduler_type="cosine",
    lr_scheduler_type="linear",
    push_to_hub=False,
    save_total_limit=2,
    # `load_best_model_at_end=True` requires the evaluation and save strategies to match.
    save_strategy="steps",
    load_best_model_at_end=True,
    save_safetensors=True,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_dataset["train"],
    eval_dataset=lm_dataset["validation"],
    data_collator=data_collator,
)
print('Started Training!')
trainer.train()
trainer.save_model(OutputDir)
print('Saved Model Path:', OutputDir)
import math

# Perplexity is the exponential of the evaluation cross-entropy loss.
eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")