# -*- coding: utf-8 -*-
"""Fine-tune a causal language model on a custom CSV dataset (notebook export)."""

# Transformers installation
# ! pip install transformers datasets
# To install from source instead of the latest release, comment the command above and uncomment the following one.
# ! pip install git+https://github.com/huggingface/transformers.git

# from huggingface_hub import notebook_login
# notebook_login()

from datasets import load_dataset

# Load the train and validation splits from local CSV files.
# Falcon = load_dataset("csv", data_files="FalconData.csv")
Falcon = load_dataset(
    "csv",
    data_files={"train": "FalconData_train2.csv", "validation": "FalconData_validation2.csv"},
)
print('Dataset Loaded!')

# Falcon = Falcon.train_test_split(test_size=0.10)

"""Then take a look at an example from each split:"""

print(Falcon['train'][0])
print(Falcon['validation'][0])

"""The next step is to load a DistilGPT2 tokenizer to process the `Text` column:"""

from transformers import AutoTokenizer, GPT2TokenizerFast

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
# tokenizer = GPT2TokenizerFast.from_pretrained("Xenova/gpt-4")  # , cache_dir=cache_dir
# GPT-2 has no dedicated padding token, so reuse the end-of-sequence token for padding.
tokenizer.pad_token = tokenizer.eos_token

# Flatten any nested columns so every field is a top-level column.
Falcon = Falcon.flatten()
print(Falcon["train"][0])

def preprocess_function(examples):
    # Note: this assumes each entry in the "Text" column is a list of strings;
    # if "Text" holds plain strings, pass them to the tokenizer directly instead of joining.
    return tokenizer([" ".join(x) for x in examples["Text"]])

tokenized_Falcon = Falcon.map(
    preprocess_function,
    batched=True,
    num_proc=4,
    remove_columns=Falcon["train"].column_names,
)

# Use the tokenizer's maximum context length (1024 for DistilGPT2) as the block size.
block_size = tokenizer.model_max_length
# block_size = 2048

def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder; we could add padding instead if the model supported it.
    # You can customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split by chunks of block_size.
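    # Each feature key (input_ids, attention_mask) is chunked identically so the
    # columns stay aligned; every resulting example is exactly block_size tokens long.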
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    # For causal language modeling the labels are the inputs themselves;
    # the model shifts them internally when computing the loss.
    result["labels"] = result["input_ids"].copy()
    return result

"""Apply the `group_texts` function over the entire dataset:"""

lm_dataset = tokenized_Falcon.map(group_texts, batched=True, num_proc=4)

from transformers import DataCollatorForLanguageModeling

# mlm=False selects causal (next-token) language modeling rather than masked LM.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
import torch

model = AutoModelForCausalLM.from_pretrained("rwh/tiny8", torch_dtype=torch.bfloat16)
print('Model Loaded!')

# torch.cuda.empty_cache()  # uncomment to free cached GPU memory between runs

model.to('cuda')

OutputDir = "C1ReadyModel"

training_args = TrainingArguments(
    output_dir=OutputDir,
    overwrite_output_dir=True,
    bf16=True,
    # evaluation_strategy="epoch",
    evaluation_strategy="steps",
    # learning_rate=3.25e-06,
    # learning_rate=2e-5,
    learning_rate=1e-5,
    weight_decay=0.01,
    # weight_decay=0.001,
    num_train_epochs=6,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # lr_scheduler_type='cosine',
    lr_scheduler_type='linear',
    push_to_hub=False,
    save_total_limit=2,
    save_strategy="steps",
    load_best_model_at_end=True,
    save_safetensors=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_dataset["train"],
    eval_dataset=lm_dataset["validation"],
    # eval_dataset=lm_dataset["test"],
    data_collator=data_collator,
)

print('Started Training!')
trainer.train()

trainer.save_model(OutputDir)
print('Saved Model Path:', OutputDir)

"""Evaluate the fine-tuned model and report perplexity, i.e. the exponential of the evaluation loss:"""

import math

eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
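"""As an optional sanity check (not part of the recipe above), the saved checkpoint can be reloaded and prompted for a short completion. This is only an illustrative sketch: the prompt string and generation settings are placeholders to adapt to your own data."""

from transformers import pipeline

generator = pipeline(
    "text-generation",
    model=OutputDir,      # directory the fine-tuned model was saved to above
    tokenizer=tokenizer,  # reuse the DistilGPT2 tokenizer configured earlier
    device=0,             # first CUDA device; use device=-1 to run on CPU
)
sample = generator("Example prompt:", max_new_tokens=50, do_sample=True, top_p=0.95)
print(sample[0]["generated_text"])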