AmelieSchreiber commited on
Commit
6ce9268
·
1 Parent(s): 2331ba4

Upload 3 files

Browse files
Files changed (3) hide show
  1. ensemble (1).py +113 -0
  2. metrics (1).py +95 -0
  3. train (1).py +193 -0
ensemble (1).py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import numpy as np
4
+ from scipy import stats
5
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
6
+ from transformers import AutoModelForTokenClassification, Trainer, AutoTokenizer, DataCollatorForTokenClassification
7
+ from datasets import Dataset, concatenate_datasets
8
+ from accelerate import Accelerator
9
+ from peft import PeftModel
10
+ import gc
11
+
12
+ # Step 1: Load train/test data and labels from pickle files
13
+ with open("/kaggle/input/550k-dataset/train_sequences_chunked_by_family.pkl", "rb") as f:
14
+ train_sequences = pickle.load(f)
15
+ with open("/kaggle/input/550k-dataset/test_sequences_chunked_by_family.pkl", "rb") as f:
16
+ test_sequences = pickle.load(f)
17
+ with open("/kaggle/input/550k-dataset/train_labels_chunked_by_family.pkl", "rb") as f:
18
+ train_labels = pickle.load(f)
19
+ with open("/kaggle/input/550k-dataset/test_labels_chunked_by_family.pkl", "rb") as f:
20
+ test_labels = pickle.load(f)
21
+
22
+ # Step 2: Define the Tokenizer
23
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
24
+ max_sequence_length = tokenizer.model_max_length
25
+
26
+ # Step 3: Define a `compute_metrics_for_batch` function.
27
+ def compute_metrics_for_batch(sequences_batch, labels_batch, models, voting='hard'):
28
+ # Tokenize batch
29
+ batch_tokenized = tokenizer(sequences_batch, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
30
+ # print("Shape of tokenized sequences:", batch_tokenized["input_ids"].shape) # Debug print
31
+
32
+ batch_dataset = Dataset.from_dict({k: v for k, v in batch_tokenized.items()})
33
+ batch_dataset = batch_dataset.add_column("labels", labels_batch[:len(batch_dataset)])
34
+
35
+ # Convert labels to numpy array of shape (1000, 1002)
36
+ labels_array = np.array([np.pad(label, (0, 1002 - len(label)), constant_values=-100) for label in batch_dataset["labels"]])
37
+
38
+ # Initialize a trainer for each model
39
+ data_collator = DataCollatorForTokenClassification(tokenizer)
40
+ trainers = [Trainer(model=model, data_collator=data_collator) for model in models]
41
+
42
+ # Get the predictions from each model
43
+ all_predictions = [trainer.predict(test_dataset=batch_dataset)[0] for trainer in trainers]
44
+
45
+ if voting == 'hard':
46
+ # Hard voting
47
+ hard_predictions = [np.argmax(predictions, axis=2) for predictions in all_predictions]
48
+ ensemble_predictions = stats.mode(hard_predictions, axis=0)[0][0]
49
+ elif voting == 'soft':
50
+ # Soft voting
51
+ avg_predictions = np.mean(all_predictions, axis=0)
52
+ ensemble_predictions = np.argmax(avg_predictions, axis=2)
53
+ else:
54
+ raise ValueError("Voting must be either 'hard' or 'soft'")
55
+
56
+ # Use broadcasting to create 2D mask
57
+ mask_2d = labels_array != -100
58
+
59
+ # Filter true labels and predictions using the mask
60
+ true_labels_list = [label[mask_2d[idx]] for idx, label in enumerate(labels_array)]
61
+ true_labels = np.concatenate(true_labels_list)
62
+ flat_predictions_list = [ensemble_predictions[idx][mask_2d[idx]] for idx in range(ensemble_predictions.shape[0])]
63
+ flat_predictions = np.concatenate(flat_predictions_list).tolist()
64
+
65
+ # Compute the metrics
66
+ accuracy = accuracy_score(true_labels, flat_predictions)
67
+ precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')
68
+ auc = roc_auc_score(true_labels, flat_predictions)
69
+ mcc = matthews_corrcoef(true_labels, flat_predictions) # Compute MCC
70
+
71
+ return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}
72
+
73
+ # Step 4: Evaluate in batches
74
+ def evaluate_in_batches(sequences, labels, models, dataset_name, voting, batch_size=1000, print_first_n=5):
75
+ num_batches = len(sequences) // batch_size + int(len(sequences) % batch_size != 0)
76
+ metrics_list = []
77
+
78
+ for i in range(num_batches):
79
+ start_idx = i * batch_size
80
+ end_idx = start_idx + batch_size
81
+ batch_metrics = compute_metrics_for_batch(sequences[start_idx:end_idx], labels[start_idx:end_idx], models, voting)
82
+
83
+ # Print metrics for the first few batches for both train and test datasets
84
+ if i < print_first_n:
85
+ print(f"{dataset_name} - Batch {i+1}/{num_batches} metrics: {batch_metrics}")
86
+
87
+ metrics_list.append(batch_metrics)
88
+
89
+ # Average metrics over all batches
90
+ avg_metrics = {key: np.mean([metrics[key] for metrics in metrics_list]) for key in metrics_list[0]}
91
+ return avg_metrics
92
+
93
+ # Step 5: Load pre-trained base model and fine-tuned LoRA models
94
+ accelerator = Accelerator()
95
+ base_model_path = "facebook/esm2_t12_35M_UR50D"
96
+ base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
97
+ lora_model_paths = [
98
+ "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1",
99
+ "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp1",
100
+ ]
101
+ models = [PeftModel.from_pretrained(base_model, path) for path in lora_model_paths]
102
+ models = [accelerator.prepare(model) for model in models]
103
+
104
+ # Step 6: Compute and print the metrics
105
+ test_metrics_soft = evaluate_in_batches(test_sequences, test_labels, models, "test", voting='soft')
106
+ train_metrics_soft = evaluate_in_batches(train_sequences, train_labels, models, "train", voting='soft')
107
+ test_metrics_hard = evaluate_in_batches(test_sequences, test_labels, models, "test", voting='hard')
108
+ train_metrics_hard = evaluate_in_batches(train_sequences, train_labels, models, "train", voting='hard')
109
+
110
+ print("Test metrics (soft voting):", test_metrics_soft)
111
+ print("Train metrics (soft voting):", train_metrics_soft)
112
+ print("Test metrics (hard voting):", test_metrics_hard)
113
+ print("Train metrics (hard voting):", train_metrics_hard)
metrics (1).py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import wandb
3
+ import numpy as np
4
+ import pickle
5
+ import torch
6
+ import torch.nn as nn
7
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
8
+ from transformers import AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification, Trainer
9
+ from datasets import Dataset
10
+ from accelerate import Accelerator
11
+ from peft import PeftModel
12
+
13
+ # Helper functions and data preparation
14
+ def truncate_labels(labels, max_length):
15
+ """Truncate labels to the specified max_length."""
16
+ return [label[:max_length] for label in labels]
17
+
18
+ def compute_metrics(p):
19
+ """Compute metrics for evaluation."""
20
+ predictions, labels = p
21
+ predictions = np.argmax(predictions, axis=2)
22
+
23
+ # Remove padding (-100 labels)
24
+ predictions = predictions[labels != -100].flatten()
25
+ labels = labels[labels != -100].flatten()
26
+
27
+ # Compute accuracy
28
+ accuracy = accuracy_score(labels, predictions)
29
+
30
+ # Compute precision, recall, F1 score, and AUC
31
+ precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
32
+ auc = roc_auc_score(labels, predictions)
33
+
34
+ # Compute MCC
35
+ mcc = matthews_corrcoef(labels, predictions)
36
+
37
+ return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
38
+
39
+ class WeightedTrainer(Trainer):
40
+ def compute_loss(self, model, inputs, return_outputs=False):
41
+ """Custom compute_loss function."""
42
+ outputs = model(**inputs)
43
+ loss_fct = nn.CrossEntropyLoss()
44
+ active_loss = inputs["attention_mask"].view(-1) == 1
45
+ active_logits = outputs.logits.view(-1, model.config.num_labels)
46
+ active_labels = torch.where(
47
+ active_loss, inputs["labels"].view(-1), torch.tensor(loss_fct.ignore_index).type_as(inputs["labels"])
48
+ )
49
+ loss = loss_fct(active_logits, active_labels)
50
+ return (loss, outputs) if return_outputs else loss
51
+
52
+ if __name__ == "__main__":
53
+ # Environment setup
54
+ accelerator = Accelerator()
55
+ wandb.init(project='binding_site_prediction')
56
+
57
+ # Load data and labels
58
+ with open("600K_data/train_sequences_chunked_by_family.pkl", "rb") as f:
59
+ train_sequences = pickle.load(f)
60
+ with open("600K_data/test_sequences_chunked_by_family.pkl", "rb") as f:
61
+ test_sequences = pickle.load(f)
62
+ with open("600K_data/train_labels_chunked_by_family.pkl", "rb") as f:
63
+ train_labels = pickle.load(f)
64
+ with open("600K_data/test_labels_chunked_by_family.pkl", "rb") as f:
65
+ test_labels = pickle.load(f)
66
+
67
+ # Tokenization and dataset creation
68
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
69
+ max_sequence_length = tokenizer.model_max_length
70
+ train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
71
+ test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
72
+ train_labels = truncate_labels(train_labels, max_sequence_length)
73
+ test_labels = truncate_labels(test_labels, max_sequence_length)
74
+ train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
75
+ test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
76
+
77
+ # Load the pre-trained LoRA model
78
+ base_model_path = "facebook/esm2_t12_35M_UR50D"
79
+ lora_model_path = "esm2_t12_35M_lora_binding_sites_2023-09-21_17-50-58/checkpoint-84029" # Replace with the correct path to your LoRA model
80
+ base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)
81
+ model = PeftModel.from_pretrained(base_model, lora_model_path)
82
+ model = accelerator.prepare(model)
83
+
84
+ # Define a function to compute metrics and get the train/test metrics
85
+ data_collator = DataCollatorForTokenClassification(tokenizer)
86
+ trainer = Trainer(model=model, data_collator=data_collator, compute_metrics=compute_metrics)
87
+ train_metrics = trainer.evaluate(train_dataset)
88
+ test_metrics = trainer.evaluate(test_dataset)
89
+
90
+ # Print the metrics
91
+ print(f"Train metrics: {train_metrics}")
92
+ print(f"Test metrics: {test_metrics}")
93
+
94
+ # Log metrics to W&B
95
+ wandb.log({"Train metrics": train_metrics, "Test metrics": test_metrics})
train (1).py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import wandb
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from datetime import datetime
7
+ from sklearn.model_selection import train_test_split
8
+ from sklearn.utils.class_weight import compute_class_weight
9
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
10
+ from transformers import AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification, TrainingArguments, Trainer
11
+ from datasets import Dataset
12
+ from accelerate import Accelerator
13
+ from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
14
+ import pickle
15
+
16
+ # Initialize accelerator and Weights & Biases
17
+ accelerator = Accelerator()
18
+ os.environ["WANDB_NOTEBOOK_NAME"] = 'train.py'
19
+ wandb.init(project='binding_site_prediction')
20
+
21
+ # Helper Functions and Data Preparation
22
+ def save_config_to_txt(config, filename):
23
+ """Save the configuration dictionary to a text file."""
24
+ with open(filename, 'w') as f:
25
+ for key, value in config.items():
26
+ f.write(f"{key}: {value}\n")
27
+
28
+ def truncate_labels(labels, max_length):
29
+ return [label[:max_length] for label in labels]
30
+
31
+ def compute_metrics(p):
32
+ predictions, labels = p
33
+ predictions = np.argmax(predictions, axis=2)
34
+ predictions = predictions[labels != -100].flatten()
35
+ labels = labels[labels != -100].flatten()
36
+ accuracy = accuracy_score(labels, predictions)
37
+ precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average='binary')
38
+ auc = roc_auc_score(labels, predictions)
39
+ mcc = matthews_corrcoef(labels, predictions)
40
+ return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc, 'mcc': mcc}
41
+
42
+ def compute_loss(model, inputs):
43
+ logits = model(**inputs).logits
44
+ labels = inputs["labels"]
45
+ loss_fct = nn.CrossEntropyLoss(weight=class_weights)
46
+ active_loss = inputs["attention_mask"].view(-1) == 1
47
+ active_logits = logits.view(-1, model.config.num_labels)
48
+ active_labels = torch.where(
49
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
50
+ )
51
+ loss = loss_fct(active_logits, active_labels)
52
+ return loss
53
+
54
+ # Load data from pickle files
55
+ with open("600K_data/train_sequences_chunked_by_family.pkl", "rb") as f:
56
+ train_sequences = pickle.load(f)
57
+
58
+ with open("600K_data/test_sequences_chunked_by_family.pkl", "rb") as f:
59
+ test_sequences = pickle.load(f)
60
+
61
+ with open("600K_data/train_labels_chunked_by_family.pkl", "rb") as f:
62
+ train_labels = pickle.load(f)
63
+
64
+ with open("600K_data/test_labels_chunked_by_family.pkl", "rb") as f:
65
+ test_labels = pickle.load(f)
66
+
67
+ # Tokenization
68
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
69
+
70
+ # Set max_sequence_length to the tokenizer's max input length
71
+ max_sequence_length = tokenizer.model_max_length
72
+
73
+ train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
74
+ test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors="pt", is_split_into_words=False)
75
+
76
+ # Directly truncate the entire list of labels
77
+ train_labels = truncate_labels(train_labels, max_sequence_length)
78
+ test_labels = truncate_labels(test_labels, max_sequence_length)
79
+
80
+ train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column("labels", train_labels)
81
+ test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column("labels", test_labels)
82
+
83
+ # Compute Class Weights
84
+ classes = [0, 1]
85
+ flat_train_labels = [label for sublist in train_labels for label in sublist]
86
+ class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=flat_train_labels)
87
+ class_weights = torch.tensor(class_weights, dtype=torch.float32).to(accelerator.device)
88
+
89
+ # Define Custom Trainer Class
90
+ class WeightedTrainer(Trainer):
91
+ def compute_loss(self, model, inputs, return_outputs=False):
92
+ outputs = model(**inputs)
93
+ loss = compute_loss(model, inputs)
94
+ return (loss, outputs) if return_outputs else loss
95
+
96
+ # Define and run training function
97
+ def train_function_no_sweeps(train_dataset, test_dataset):
98
+
99
+ # Directly set the config
100
+ config = {
101
+ "lora_alpha": 1,
102
+ "lora_dropout": 0.4,
103
+ "lr": 5.701568055793089e-04,
104
+ "lr_scheduler_type": "cosine",
105
+ "max_grad_norm": 0.5,
106
+ "num_train_epochs": 1,
107
+ "per_device_train_batch_size": 6,
108
+ "r": 1,
109
+ "weight_decay": 0.4,
110
+ # Add other hyperparameters as needed
111
+ }
112
+
113
+ # Log the config to W&B
114
+ wandb.config.update(config)
115
+
116
+ # Save the config to a text file
117
+ timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
118
+ config_filename = f"esm2_t12_35M_lora_config_{timestamp}.txt"
119
+ save_config_to_txt(config, config_filename)
120
+
121
+ model_checkpoint = "facebook/esm2_t12_35M_UR50D"
122
+
123
+ # Define labels and model
124
+ id2label = {0: "No binding site", 1: "Binding site"}
125
+ label2id = {v: k for k, v in id2label.items()}
126
+ model = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(id2label), id2label=id2label, label2id=label2id)
127
+
128
+ # Convert the model into a PeftModel
129
+ peft_config = LoraConfig(
130
+ task_type=TaskType.TOKEN_CLS,
131
+ inference_mode=False,
132
+ r=config["r"],
133
+ lora_alpha=config["lora_alpha"],
134
+ target_modules=["query", "key", "value"], # also maybe "dense_h_to_4h" and "dense_4h_to_h"
135
+ lora_dropout=config["lora_dropout"],
136
+ bias="none" # or "all" or "lora_only"
137
+ )
138
+ model = get_peft_model(model, peft_config)
139
+
140
+ # Use the accelerator
141
+ model = accelerator.prepare(model)
142
+ train_dataset = accelerator.prepare(train_dataset)
143
+ test_dataset = accelerator.prepare(test_dataset)
144
+
145
+ timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
146
+
147
+ # Training setup
148
+ training_args = TrainingArguments(
149
+ output_dir=f"esm2_t12_35M_lora_binding_sites_{timestamp}",
150
+ learning_rate=config["lr"],
151
+ lr_scheduler_type=config["lr_scheduler_type"],
152
+ gradient_accumulation_steps=1,
153
+ max_grad_norm=config["max_grad_norm"],
154
+ per_device_train_batch_size=config["per_device_train_batch_size"],
155
+ per_device_eval_batch_size=config["per_device_train_batch_size"],
156
+ num_train_epochs=config["num_train_epochs"],
157
+ weight_decay=config["weight_decay"],
158
+ evaluation_strategy="epoch",
159
+ save_strategy="epoch",
160
+ load_best_model_at_end=True,
161
+ metric_for_best_model="f1",
162
+ greater_is_better=True,
163
+ push_to_hub=False,
164
+ logging_dir=None,
165
+ logging_first_step=False,
166
+ logging_steps=200,
167
+ save_total_limit=7,
168
+ no_cuda=False,
169
+ seed=8893,
170
+ fp16=True,
171
+ report_to='wandb'
172
+ )
173
+
174
+ # Initialize Trainer
175
+ trainer = WeightedTrainer(
176
+ model=model,
177
+ args=training_args,
178
+ train_dataset=train_dataset,
179
+ eval_dataset=test_dataset,
180
+ tokenizer=tokenizer,
181
+ data_collator=DataCollatorForTokenClassification(tokenizer=tokenizer),
182
+ compute_metrics=compute_metrics
183
+ )
184
+
185
+ # Train and Save Model
186
+ trainer.train()
187
+ save_path = os.path.join("lora_binding_sites", f"best_model_esm2_t12_35M_lora_{timestamp}")
188
+ trainer.save_model(save_path)
189
+ tokenizer.save_pretrained(save_path)
190
+
191
+ # Call the training function
192
+ if __name__ == "__main__":
193
+ train_function_no_sweeps(train_dataset, test_dataset)