# This script lets you train your own distilgpt2 checkpoint or fine-tune the one in checkpoint-4000.
import os
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, TrainerCallback  # TrainerCallback is used by ProgressCallback below
from datasets import load_dataset
from datetime import datetime
# Data preparation
data_dir = r"https://github.com/zrebarchak/Codexchan.exe-Archive"
#
"""replace this with folder of txt files
the github link This is the base dataset. it includes all of codexchan's videos where they spoke.
theres nothing wrong with the errored folder, you should combine it-
and train on them both fom . note that this dataset doesnt include the faq
(https://etherpad.mit.edu/p/r.46c0a7842e569d53dc22b44afed6bc40)
or this https://www.onlinegdb.com/fork/IrQRJkyX0
also note checkpoint-4000 was not trained on these either, just this base dataset. have fun!"""
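# A minimal way to get the base dataset locally (assumes you have git installed):
#   git clone https://github.com/zrebarchak/Codexchan.exe-Archive
# then set data_dir above to the cloned folder (or to any other folder of .txt files).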
#
dataset = load_dataset("text", data_files=os.path.join(data_dir, "*.txt"))
# Model and tokenizer setup
model_name = "distilgpt2"
base_output_dir = "./distilgpt2-fine-tuned"
# Generate a unique name for this training run
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = os.path.join(base_output_dir, f"distilgpt2_continuous_{current_time}")
# Function to find the most recent model directory
def find_most_recent_model(base_dir):
    if not os.path.exists(base_dir):
        return None
    subdirs = [os.path.join(base_dir, d) for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))]
    valid_dirs = [d for d in subdirs if os.path.exists(os.path.join(d, "config.json"))]
    return max(valid_dirs, key=os.path.getmtime) if valid_dirs else None
most_recent_dir = find_most_recent_model(base_output_dir)
if most_recent_dir:
    print(f"Loading most recent saved model from: {most_recent_dir}")
    try:
        model = GPT2LMHeadModel.from_pretrained(most_recent_dir)
        tokenizer = GPT2Tokenizer.from_pretrained(most_recent_dir)
    except Exception as e:
        print(f"Error loading saved model: {e}")
        print("Starting with fresh model instead.")
        model = GPT2LMHeadModel.from_pretrained(model_name)
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
else:
    print("No valid saved model found. Starting with fresh model...")
    model = GPT2LMHeadModel.from_pretrained(model_name)
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
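# Optional (a sketch; the path below is a placeholder): to fine-tune the released checkpoint-4000
# instead, either copy that folder into base_output_dir so find_most_recent_model() picks it up,
# or load it explicitly, e.g.:
#   model = GPT2LMHeadModel.from_pretrained("./checkpoint-4000")
#   tokenizer = GPT2Tokenizer.from_pretrained("./checkpoint-4000")  # fall back to "distilgpt2" if the checkpoint has no tokenizer files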
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
# Tokenize the dataset
def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=128)
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
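# Optional sanity check (illustrative, not required by the training run):
print(f"Loaded {len(tokenized_dataset['train'])} training examples from {data_dir}")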
# Training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    save_steps=1000,
    save_total_limit=5,
    fp16=True,  # requires a CUDA GPU; set to False to train on CPU
    gradient_checkpointing=True,
    learning_rate=1e-4,
    warmup_steps=100,
    logging_steps=10,  # log frequently
    max_steps=-1,  # -1 disables the step cap, so num_train_epochs below controls the run length
    num_train_epochs=215,  # intentionally very high so the run keeps going until you stop it
)
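# Note: with per_device_train_batch_size=1 and gradient_accumulation_steps=4, each optimizer
# step sees an effective batch of 4 sequences of up to 128 tokens.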
# Custom callback to print progress
class ProgressCallback(TrainerCallback):
    def __init__(self, total_steps=1000000):  # a large number used only for the progress display
        self.total_steps = total_steps

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs and state.global_step % 10 == 0:  # print every 10 steps
            loss = logs.get("loss")
            loss_str = f"{loss:.4f}" if isinstance(loss, (int, float)) else "N/A"
            print(f"Step: {state.global_step}/{self.total_steps} - Loss: {loss_str}")
# Trainer setup
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    callbacks=[ProgressCallback()],
)
# Enable gradient checkpointing (already requested via gradient_checkpointing=True above; harmless to call again)
model.gradient_checkpointing_enable()
# Start training
print(f"Starting long-running training. Models will be saved to {output_dir}")
print("Press Ctrl+C to stop...")
try:
    trainer.train()
except KeyboardInterrupt:
    print("\nTraining interrupted. Saving model...")
# Save the model (and the tokenizer, so resuming from this folder works) whether training finished or was interrupted
trainer.save_model()
tokenizer.save_pretrained(output_dir)
print(f"Model saved to {output_dir}. You can resume training later by running this script again.")
print("Training completed or interrupted. Final model saved.")