{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"trusted": true},
   "outputs": [],
   "source": [
    "# For any HF basic activities like loading models\n",
    "# and tokenizers for running inference\n",
    "# upgrade is a must for the newest Gemma model\n",
    "%pip install -q --upgrade datasets\n",
    "%pip install -q --upgrade transformers\n",
    "\n",
    "# For doing efficient stuff - PEFT\n",
    "%pip install -q --upgrade peft\n",
    "%pip install -q --upgrade trl\n",
    "%pip install -q bitsandbytes\n",
    "%pip install -q accelerate\n",
    "\n",
    "# for logging and visualizing training progress\n",
    "%pip install -q tensorboard\n",
    "# If creating a new dataset, useful for creating *.jsonl files\n",
    "%pip install -q jsonlines"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"trusted": true},
   "outputs": [],
   "source": [
    "!conda install -y gdown"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {"trusted": true},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import os\n",
    "import time\n",
    "import warnings\n",
    "from peft import LoraConfig, get_peft_model\n",
    "from transformers import BertForSequenceClassification, AutoTokenizer, TrainingArguments, Trainer\n",
    "from matplotlib import pyplot as plt\n",
    "from datasets import load_dataset\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, matthews_corrcoef, roc_auc_score\n",
    "import huggingface_hub\n",
    "\n",
    "# Read the Hugging Face token from the environment (e.g. exported from\n",
    "# Kaggle secrets) instead of an undefined `hf_token` name, which raised\n",
    "# NameError on a fresh kernel; skip login when no token is configured.\n",
    "hf_token = os.environ.get('HF_TOKEN')\n",
    "if hf_token:\n",
    "    huggingface_hub.login(token=hf_token)\n",
    "\n",
    "# Suppress warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# LoRA target modules: attention projections plus feed-forward dense\n",
    "# layers of BERT encoder layers 0 and 1.\n",
    "attention_plus_feed_forward = [\n",
    "    \"bert.encoder.layer.0.attention.self.query\",\n",
    "    \"bert.encoder.layer.0.attention.self.key\",\n",
    "    \"bert.encoder.layer.0.attention.self.value\",\n",
    "    \"bert.encoder.layer.0.attention.output.dense\",\n",
    "    \"bert.encoder.layer.0.intermediate.dense\",\n",
    "    \"bert.encoder.layer.0.output.dense\",\n",
    "    \"bert.encoder.layer.1.attention.self.query\",\n",
    "    \"bert.encoder.layer.1.attention.self.key\",\n",
    "    \"bert.encoder.layer.1.attention.self.value\",\n",
    "    \"bert.encoder.layer.1.attention.output.dense\",\n",
    "    \"bert.encoder.layer.1.intermediate.dense\",\n",
    "    \"bert.encoder.layer.1.output.dense\"\n",
    "]\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained('zhihan1996/DNA_bert_6')\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    \"\"\"Tokenize a batch; the sequence column is 'sequence' or 'Sequence' depending on the CSV.\"\"\"\n",
    "    try:\n",
    "        return tokenizer(\n",
    "            examples['sequence'],\n",
    "            padding='max_length',\n",
    "            truncation=True,\n",
    "            max_length=512\n",
    "        )\n",
    "    except KeyError:\n",
    "        return tokenizer(\n",
    "            examples['Sequence'],\n",
    "            padding='max_length',\n",
    "            truncation=True,\n",
    "            max_length=512\n",
    "        )\n",
    "\n",
    "\n",
    "def add_labels(examples):\n",
    "    \"\"\"Copy the label column ('label' or 'Label') into the 'labels' key Trainer expects.\"\"\"\n",
    "    try:\n",
    "        examples['labels'] = examples['label']\n",
    "    except KeyError:\n",
    "        examples['labels'] = examples['Label']\n",
    "    return examples\n",
    "\n",
    "\n",
    "# One CSV per task; every task uses the same split boundaries.\n",
    "TASK_DATA_FILES = {\n",
    "    'tfbs': '/kaggle/working/tfbs.csv',\n",
    "    'dnasplice': '/kaggle/working/dnasplice.csv',\n",
    "    'dnaprom': '/kaggle/working/dnaprom.csv',\n",
    "}\n",
    "\n",
    "\n",
    "def create_task_dataset(task_name):\n",
    "    \"\"\"Return (train, test) splits for `task_name`; raises ValueError for unknown tasks.\"\"\"\n",
    "    if task_name not in TASK_DATA_FILES:\n",
    "        raise ValueError(f\"Unknown task: {task_name}\")\n",
    "    data_files = TASK_DATA_FILES[task_name]\n",
    "    # Test starts at 10000 (not 10001) so row 10000 is not silently dropped.\n",
    "    train = load_dataset('csv', data_files=data_files, split='train[0:10000]')\n",
    "    test = load_dataset('csv', data_files=data_files, split='train[10000:13122]')\n",
    "    return train, test\n",
    "\n",
    "\n",
    "def create_dataset_maps(train_dataset, test_dataset):\n",
    "    \"\"\"Tokenize both splits and attach the 'labels' column.\"\"\"\n",
    "    train_dataset = train_dataset.map(preprocess_function, batched=True)\n",
    "    train_dataset = train_dataset.map(add_labels)\n",
    "    test_dataset = test_dataset.map(preprocess_function, batched=True)\n",
    "    test_dataset = test_dataset.map(add_labels)\n",
    "    return train_dataset, test_dataset\n",
    "\n",
    "\n",
    "def train_model(train_dataset, test_dataset, model, task, model_name, config_name):\n",
    "    \"\"\"Fine-tune `model` on `train_dataset` and return (wall_clock_seconds, eval_metrics).\"\"\"\n",
    "    def specificity_score(y_true, y_pred):\n",
    "        # specificity = TN / (TN + FP); eps guards against division by zero\n",
    "        true_negatives = np.sum((y_pred == 0) & (y_true == 0))\n",
    "        false_positives = np.sum((y_pred == 1) & (y_true == 0))\n",
    "        return true_negatives / (true_negatives + false_positives + np.finfo(float).eps)\n",
    "\n",
    "    def compute_metrics(eval_pred):\n",
    "        logits, labels = eval_pred\n",
    "        predictions = np.argmax(logits, axis=-1)\n",
    "        # Positive-class logit is used as the ranking score for ROC-AUC.\n",
    "        y_pred = logits[:, 1]\n",
    "\n",
    "        return {\n",
    "            'accuracy': accuracy_score(labels, predictions),\n",
    "            'recall': recall_score(labels, predictions),\n",
    "            'specificity': specificity_score(labels, predictions),\n",
    "            'mcc': matthews_corrcoef(labels, predictions),\n",
    "            'roc_auc': roc_auc_score(labels, y_pred),\n",
    "            'precision': precision_score(labels, predictions),\n",
    "            'f1': f1_score(labels, predictions),\n",
    "            'true_pos': np.sum((predictions == 1) & (labels == 1)),\n",
    "            'true_neg': np.sum((predictions == 0) & (labels == 0)),\n",
    "            'false_pos': np.sum((predictions == 1) & (labels == 0)),\n",
    "            'false_neg': np.sum((predictions == 0) & (labels == 1))\n",
    "        }\n",
    "\n",
    "    # Training arguments (identical across tasks and configurations).\n",
    "    training_arguments = TrainingArguments(\n",
    "        output_dir=f\"outputs/{task}/{model_name}_{config_name}\",\n",
    "        num_train_epochs=25,\n",
    "        fp16=False,\n",
    "        bf16=False,\n",
    "        per_device_train_batch_size=20,\n",
    "        per_device_eval_batch_size=10,\n",
    "        gradient_accumulation_steps=2,\n",
    "        gradient_checkpointing=True,\n",
    "        max_grad_norm=0.3,\n",
    "        learning_rate=4e-4,\n",
    "        weight_decay=0.01,\n",
    "        optim=\"paged_adamw_32bit\",\n",
    "        lr_scheduler_type=\"linear\",\n",
    "        max_steps=-1,\n",
    "        warmup_ratio=0.03,\n",
    "        group_by_length=True,\n",
    "        save_steps=1000,\n",
    "        logging_steps=25,\n",
    "        dataloader_pin_memory=False,\n",
    "        report_to='tensorboard',\n",
    "        gradient_checkpointing_kwargs={'use_reentrant': False}\n",
    "    )\n",
    "\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=training_arguments,\n",
    "        train_dataset=train_dataset,\n",
    "        eval_dataset=test_dataset,\n",
    "        tokenizer=tokenizer,\n",
    "        compute_metrics=compute_metrics,\n",
    "    )\n",
    "\n",
    "    start_time = time.time()\n",
    "    trainer.train()\n",
    "    total_time = time.time() - start_time\n",
    "\n",
    "    metrics = trainer.evaluate()\n",
    "    return total_time, metrics\n",
    "\n",
    "\n",
    "def log_result(log_file, task, model_name, config_name, training_time, metrics):\n",
    "    \"\"\"Append a one-line summary of a finished run to `log_file`.\"\"\"\n",
    "    with open(log_file, \"a\") as log:\n",
    "        log.write(f\"Task: {task}, Model: {model_name}, Config: {config_name}, Training Time: {training_time}, Metrics: {metrics}\\n\")\n",
    "\n",
    "\n",
    "# Task loop: for each task train the plain base model, then a LoRA variant.\n",
    "task_list = ['dnasplice', 'tfbs', 'dnaprom']\n",
    "log_file = \"training_log.txt\"\n",
    "model_name = 'fabihamakhdoomi/TinyDNABERT'\n",
    "for task in task_list:\n",
    "    print(f\"Running TASK : {task}\")\n",
    "    train_dataset, test_dataset = create_task_dataset(task)\n",
    "    train_dataset, test_dataset = create_dataset_maps(train_dataset, test_dataset)\n",
    "    train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])\n",
    "    test_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])\n",
    "\n",
    "    # Train the base model first\n",
    "    base_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)\n",
    "    config_name = \"base_model\"\n",
    "    print(f\"Training MODEL : {config_name} for task : {task}\")\n",
    "    training_time, metrics = train_model(train_dataset, test_dataset, base_model, task, model_name, config_name)\n",
    "    log_result(log_file, task, model_name, config_name, training_time, metrics)\n",
    "\n",
    "    # Train the LoRA model on a fresh copy of the base weights\n",
    "    config_name = \"attention_plus_feed_forward\"\n",
    "    base_model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)\n",
    "    # dnasplice uses a smaller adapter rank than the other tasks.\n",
    "    if task == 'dnasplice':\n",
    "        r_value = 4\n",
    "        print('Setting r value to 4 for dnasplice')\n",
    "    else:\n",
    "        r_value = 8\n",
    "    peft_config = LoraConfig(\n",
    "        lora_alpha=16,\n",
    "        lora_dropout=0.2,\n",
    "        r=r_value,\n",
    "        bias=\"none\",\n",
    "        task_type=\"SEQ_CLS\",\n",
    "        target_modules=attention_plus_feed_forward\n",
    "    )\n",
    "    model = get_peft_model(base_model, peft_config)\n",
    "    print(f\"Training MODEL : {config_name} for task : {task}\")\n",
    "    training_time, metrics = train_model(train_dataset, test_dataset, model, task, model_name, config_name)\n",
    "    log_result(log_file, task, model_name, config_name, training_time, metrics)"
   ]
  }
 ],
 "metadata": {
  "kaggle": {
   "accelerator": "gpu",
   "dataSources": [],
   "dockerImageVersionId": 30747,
   "isGpuEnabled": true,
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}