ssmits commited on
Commit
17db2ea
·
verified ·
1 Parent(s): 6f5fc0f

Upload 2 files

Browse files
Files changed (2) hide show
  1. finetune.py +111 -0
  2. optimize_lr.py +401 -0
finetune.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ CONTEXT_WINDOW = 1024 #has to fit in 4090
4
+ HF_TOKEN = os.getenv("HF_TOKEN")
5
+
6
+ from transformers import (
7
+ AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
8
+ Trainer, DataCollatorForLanguageModeling
9
+ )
10
+ import torch
11
+ from datasets import load_dataset
12
+ from huggingface_hub import login
13
+
14
+ # setup tokenizer
15
+ tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B-instruct", token=HF_TOKEN)
16
+ if tokenizer.pad_token is None:
17
+ tokenizer.pad_token = tokenizer.eos_token
18
+ tokenizer.padding_side = "left" # better for inference
19
+
20
+ # init model with auto device mapping
21
+ model = AutoModelForCausalLM.from_pretrained(
22
+ "Zyphra/Zamba2-1.2B-instruct",
23
+ torch_dtype=torch.bfloat16,
24
+ device_map="auto" # handles multi-gpu/cpu mapping
25
+ )
26
+ model.config.pad_token_id = tokenizer.pad_token_id
27
+
28
+ # Load the Dutch Dolly dataset
29
+ dataset = load_dataset("BramVanroy/dolly-15k-dutch", split="train_sft")
30
+
31
+ def prepare_chat_format(examples):
32
+ chats = []
33
+ for messages in examples['messages']:
34
+ try:
35
+ chat = tokenizer.apply_chat_template(
36
+ messages,
37
+ tokenize=True,
38
+ max_length=CONTEXT_WINDOW,
39
+ truncation=True,
40
+ return_tensors=None
41
+ )
42
+ except Exception as e:
43
+ print(f"Error applying chat template: {e}")
44
+ # Fallback format if chat template fails
45
+ text = ""
46
+ for message in messages:
47
+ role = message["role"]
48
+ content = message["content"]
49
+ text += f"<|{role}|>\n{content}</s>\n"
50
+
51
+ chat = tokenizer(
52
+ text,
53
+ max_length=CONTEXT_WINDOW,
54
+ truncation=True,
55
+ return_tensors=None
56
+ )["input_ids"]
57
+
58
+ chats.append(chat)
59
+ return {"input_ids": chats}
60
+
61
+ # Process the dataset
62
+ tokenized_dataset = dataset.map(
63
+ prepare_chat_format,
64
+ batched=True,
65
+ remove_columns=dataset.column_names
66
+ )
67
+
68
+ # training config
69
+ training_args = TrainingArguments(
70
+ output_dir="./zamba2-finetuned",
71
+ num_train_epochs=2,
72
+ per_device_train_batch_size=4,
73
+ save_steps=500,
74
+ save_total_limit=2,
75
+ logging_steps=100,
76
+ learning_rate=2e-5,
77
+ weight_decay=0.01,
78
+ fp16=False,
79
+ bf16=True,
80
+ gradient_accumulation_steps=8,
81
+ dataloader_num_workers=4,
82
+ gradient_checkpointing=True,
83
+ max_grad_norm=1.0,
84
+ warmup_steps=100
85
+ )
86
+
87
+ data_collator = DataCollatorForLanguageModeling(
88
+ tokenizer=tokenizer,
89
+ mlm=False
90
+ )
91
+
92
+ # custom trainer to handle device mapping
93
+ class CustomTrainer(Trainer):
94
+ def __init__(self, *args, **kwargs):
95
+ super().__init__(*args, **kwargs)
96
+ self.model = model
97
+
98
+ def _move_model_to_device(self, model, device):
99
+ pass # model already mapped to devices
100
+
101
+ trainer = CustomTrainer(
102
+ model=model,
103
+ args=training_args,
104
+ train_dataset=tokenized_dataset,
105
+ data_collator=data_collator
106
+ )
107
+
108
+ # Add explicit training and saving steps
109
+ trainer.train()
110
+ model.save_pretrained("./zamba2-finetuned-final")
111
+ tokenizer.save_pretrained("./zamba2-finetuned-final")
optimize_lr.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import optuna
2
+ from transformers import (
3
+ AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
4
+ Trainer, DataCollatorForLanguageModeling
5
+ )
6
+ import torch
7
+ from datasets import load_dataset
8
+ import numpy as np
9
+ import gc
10
+ from sklearn.gaussian_process import GaussianProcessRegressor
11
+ from sklearn.gaussian_process.kernels import ConstantKernel, Matern
12
+ import matplotlib.pyplot as plt
13
+ from scipy.stats import norm
14
+ import warnings
15
+ warnings.filterwarnings('ignore', category=UserWarning)
16
+
17
+ from transformers import TrainerCallback
18
+
19
+ import argparse
20
+
21
+ # Configuration parameters
22
+ num_trials = 10 # Adjust this value to control the number of optimization trials
23
+ DATASET = load_dataset("BramVanroy/dolly-15k-dutch", split="train_sft[:1000]")
24
+ CONTEXT_WINDOW = 1024
25
+
26
+ # Initialize tokenizer once
27
+ tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B")
28
+ if tokenizer.pad_token is None:
29
+ tokenizer.pad_token = tokenizer.eos_token
30
+ tokenizer.padding_side = "left"
31
+
32
+ def prepare_chat_format(examples):
33
+ chats = []
34
+ for messages in examples['messages']:
35
+ try:
36
+ chat = tokenizer.apply_chat_template(
37
+ messages,
38
+ tokenize=True,
39
+ max_length=CONTEXT_WINDOW,
40
+ truncation=True,
41
+ return_tensors=None
42
+ )
43
+ chats.append(chat)
44
+ except Exception as e:
45
+ print(f"Error applying chat template: {e}")
46
+ print("Fallback format if chat template fails")
47
+ text = ""
48
+ for message in messages:
49
+ role = message["role"]
50
+ content = message["content"]
51
+ text += f"<|{role}|>\n{content}</s>\n"
52
+
53
+ chat = tokenizer(
54
+ text,
55
+ max_length=CONTEXT_WINDOW,
56
+ truncation=True,
57
+ return_tensors=None
58
+ )["input_ids"]
59
+
60
+ chats.append(chat)
61
+ return {"input_ids": chats}
62
+
63
+ # Prepare dataset once
64
+ tokenized_dataset = DATASET.map(
65
+ prepare_chat_format,
66
+ batched=True,
67
+ remove_columns=DATASET.column_names
68
+ )
69
+
70
+ def clear_memory():
71
+ """Clear GPU memory between trials"""
72
+ if torch.cuda.is_available():
73
+ torch.cuda.empty_cache()
74
+ gc.collect()
75
+
76
+ class LossCallback(TrainerCallback):
77
+ def __init__(self):
78
+ self.losses = []
79
+
80
+ def on_log(self, args, state, control, logs=None, **kwargs):
81
+ if logs is not None and "loss" in logs:
82
+ self.losses.append(logs["loss"])
83
+
84
+ def objective(trial):
85
+ # Clear memory from previous trial
86
+ clear_memory()
87
+
88
+ lr = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)
89
+
90
+ # Initialize model with fresh state
91
+ torch.manual_seed(42)
92
+ model = AutoModelForCausalLM.from_pretrained(
93
+ "Zyphra/Zamba2-1.2B",
94
+ torch_dtype=torch.bfloat16,
95
+ device_map="auto"
96
+ )
97
+ model.config.pad_token_id = tokenizer.pad_token_id
98
+
99
+ # Calculate steps with larger batch size
100
+ batch_size = 4 # Increased from 1
101
+ grad_accum_steps = 8 # Decreased from 32 since we increased batch size
102
+ effective_batch_size = batch_size * grad_accum_steps # Still 32 total
103
+ total_steps = len(tokenized_dataset) // effective_batch_size
104
+
105
+ # Training arguments
106
+ training_args = TrainingArguments(
107
+ output_dir=f"./optuna_runs/trial_{trial.number}",
108
+ num_train_epochs=1,
109
+ per_device_train_batch_size=batch_size, # Increased
110
+ gradient_accumulation_steps=grad_accum_steps, # Decreased
111
+ logging_steps=max(total_steps // 20, 1),
112
+ learning_rate=lr,
113
+ weight_decay=0.01,
114
+ fp16=False,
115
+ bf16=True,
116
+ warmup_steps=total_steps // 10,
117
+ save_steps=1000000,
118
+ save_total_limit=None,
119
+ report_to="none",
120
+ seed=42,
121
+ dataloader_num_workers=4, # Added for faster data loading
122
+ gradient_checkpointing=True, # Added to optimize memory usage
123
+ max_grad_norm=1.0 # Added for stability
124
+ )
125
+
126
+ print(f"\nTrial {trial.number}:")
127
+ print(f"Learning rate: {lr}")
128
+ print(f"Total steps: {total_steps}")
129
+ print(f"Logging every {training_args.logging_steps} steps")
130
+
131
+ data_collator = DataCollatorForLanguageModeling(
132
+ tokenizer=tokenizer,
133
+ mlm=False
134
+ )
135
+
136
+ class CustomTrainer(Trainer):
137
+ def __init__(self, *args, **kwargs):
138
+ super().__init__(*args, **kwargs)
139
+ self.model = model
140
+
141
+ def _move_model_to_device(self, model, device):
142
+ pass
143
+
144
+ # Initialize callback
145
+ loss_callback = LossCallback()
146
+
147
+ trainer = CustomTrainer(
148
+ model=model,
149
+ args=training_args,
150
+ train_dataset=tokenized_dataset,
151
+ data_collator=data_collator,
152
+ callbacks=[loss_callback] # Use the proper callback
153
+ )
154
+
155
+ try:
156
+ train_result = trainer.train()
157
+
158
+ # Calculate mean of last 20% of losses
159
+ losses = loss_callback.losses # Get losses from callback
160
+ n_losses = max(len(losses) // 5, 1)
161
+ final_losses = losses[-n_losses:]
162
+ mean_loss = np.mean(final_losses) if final_losses else float('inf')
163
+
164
+ # Clean up
165
+ del model
166
+ del trainer
167
+ clear_memory()
168
+
169
+ return mean_loss
170
+
171
+ except Exception as e:
172
+ print(f"Trial failed with error: {e}")
173
+ # Clean up on failure
174
+ del model
175
+ del trainer
176
+ clear_memory()
177
+ return float('inf')
178
+
179
+ # Create and run the study
180
+ study = optuna.create_study(
181
+ direction="minimize",
182
+ sampler=optuna.samplers.TPESampler(seed=42),
183
+ study_name="learning_rate_optimization"
184
+ )
185
+
186
+ study.optimize(objective, n_trials=num_trials)
187
+
188
+ # Print results
189
+ print(f"\nOptimization Results ({num_trials} trials):")
190
+ print("Best learning rate:", study.best_params["learning_rate"])
191
+ print("Best loss:", study.best_value)
192
+ print("\nAll trials:")
193
+ for trial in study.trials:
194
+ print(f"Learning rate: {trial.params['learning_rate']:.2e}, Loss: {trial.value:.4f}")
195
+
196
+ # Save results
197
+ import json
198
+ results = {
199
+ "best_learning_rate": study.best_params["learning_rate"],
200
+ "best_loss": study.best_value,
201
+ "all_trials": [(trial.params["learning_rate"], trial.value) for trial in study.trials]
202
+ }
203
+ with open("lr_optimization_results.json", "w") as f:
204
+ json.dump(results, f, indent=4)
205
+
206
+ # Plot optimization history
207
+ try:
208
+ fig = optuna.visualization.plot_optimization_history(study)
209
+ fig.show()
210
+ except Exception as e:
211
+ print(f"Could not create visualization: {e}")
212
+
213
+ # Add sophisticated final optimization using Gaussian Process Regression
214
+ def optimize_final_lr(study):
215
+ try:
216
+ # Extract learning rates and losses
217
+ X = np.array([[trial.params['learning_rate']] for trial in study.trials])
218
+ y = np.array([trial.value for trial in study.trials])
219
+
220
+ # Check if we have any valid results
221
+ valid_mask = np.isfinite(y)
222
+ if not np.any(valid_mask):
223
+ print("No valid trials found. Returning default learning rate.")
224
+ return {
225
+ 'gpr_optimal_lr': 2e-5, # default fallback
226
+ 'ei_optimal_lr': 2e-5,
227
+ 'predicted_loss': float('inf'),
228
+ 'uncertainty': float('inf')
229
+ }
230
+
231
+ # Filter out infinite values
232
+ X = X[valid_mask]
233
+ y = y[valid_mask]
234
+
235
+ # Ensure we have enough points for fitting
236
+ if len(X) < 2:
237
+ print("Not enough valid trials for GPR. Returning best observed value.")
238
+ best_idx = np.argmin(y)
239
+ return {
240
+ 'gpr_optimal_lr': float(X[best_idx][0]),
241
+ 'ei_optimal_lr': float(X[best_idx][0]),
242
+ 'predicted_loss': float(y[best_idx]),
243
+ 'uncertainty': float('inf')
244
+ }
245
+
246
+ # Transform to log space
247
+ X_log = np.log10(X)
248
+
249
+ # Normalize y values
250
+ y_mean = np.mean(y)
251
+ y_std = np.std(y)
252
+ if y_std == 0:
253
+ y_std = 1
254
+ y_normalized = (y - y_mean) / y_std
255
+
256
+ # Define kernel
257
+ kernel = ConstantKernel(1.0) * Matern(length_scale=1.0, nu=2.5)
258
+
259
+ # Fit Gaussian Process
260
+ gpr = GaussianProcessRegressor(
261
+ kernel=kernel,
262
+ n_restarts_optimizer=10,
263
+ random_state=42,
264
+ normalize_y=False # we're manually normalizing
265
+ )
266
+
267
+ try:
268
+ gpr.fit(X_log, y_normalized)
269
+ except np.linalg.LinAlgError:
270
+ print("GPR fitting failed. Returning best observed value.")
271
+ best_idx = np.argmin(y)
272
+ return {
273
+ 'gpr_optimal_lr': float(X[best_idx][0]),
274
+ 'ei_optimal_lr': float(X[best_idx][0]),
275
+ 'predicted_loss': float(y[best_idx]),
276
+ 'uncertainty': float('inf')
277
+ }
278
+
279
+ # Create fine grid of points for prediction
280
+ X_pred_log = np.linspace(np.log10(X.min()), np.log10(X.max()), 1000).reshape(-1, 1)
281
+
282
+ # Predict mean and std
283
+ y_pred_normalized, sigma = gpr.predict(X_pred_log, return_std=True)
284
+
285
+ # Denormalize predictions
286
+ y_pred = y_pred_normalized * y_std + y_mean
287
+ sigma = sigma * y_std
288
+
289
+ # Find the point with lowest predicted value
290
+ best_idx = np.argmin(y_pred)
291
+ optimal_lr = 10 ** X_pred_log[best_idx, 0]
292
+
293
+ # Calculate acquisition function (Expected Improvement)
294
+ best_f = np.min(y)
295
+ Z = (best_f - y_pred) / (sigma + 1e-9) # add small constant to prevent division by zero
296
+ ei = sigma * (Z * norm.cdf(Z) + norm.pdf(Z))
297
+
298
+ # Find point with highest expected improvement
299
+ ei_best_idx = np.argmax(ei)
300
+ ei_optimal_lr = 10 ** X_pred_log[ei_best_idx, 0]
301
+
302
+ return {
303
+ 'gpr_optimal_lr': float(optimal_lr),
304
+ 'ei_optimal_lr': float(ei_optimal_lr),
305
+ 'predicted_loss': float(y_pred[best_idx]),
306
+ 'uncertainty': float(sigma[best_idx])
307
+ }
308
+
309
+ except Exception as e:
310
+ print(f"Optimization failed with error: {e}")
311
+ return {
312
+ 'gpr_optimal_lr': 2e-5, # default fallback
313
+ 'ei_optimal_lr': 2e-5,
314
+ 'predicted_loss': float('inf'),
315
+ 'uncertainty': float('inf')
316
+ }
317
+
318
+ # Run final optimization and handle potential failures
319
+ try:
320
+ final_optimization = optimize_final_lr(study)
321
+ print("\nAdvanced Optimization Results:")
322
+ print(f"GPR Optimal Learning Rate: {final_optimization['gpr_optimal_lr']:.2e}")
323
+ print(f"Expected Improvement Optimal Learning Rate: {final_optimization['ei_optimal_lr']:.2e}")
324
+ print(f"Predicted Loss: {final_optimization['predicted_loss']:.4f}")
325
+ print(f"Uncertainty: {final_optimization['uncertainty']:.4f}")
326
+ except Exception as e:
327
+ print(f"Final optimization failed: {e}")
328
+ final_optimization = {
329
+ 'gpr_optimal_lr': 2e-5,
330
+ 'ei_optimal_lr': 2e-5,
331
+ 'predicted_loss': float('inf'),
332
+ 'uncertainty': float('inf')
333
+ }
334
+
335
+ # Save extended results
336
+ results.update({
337
+ "gpr_optimal_lr": float(final_optimization['gpr_optimal_lr']),
338
+ "ei_optimal_lr": float(final_optimization['ei_optimal_lr']),
339
+ "predicted_loss": float(final_optimization['predicted_loss']),
340
+ "uncertainty": float(final_optimization['uncertainty'])
341
+ })
342
+
343
+ # Visualization of the GPR results
344
+ def plot_gpr_results(study, final_optimization):
345
+ # Extract data and filter out infinite values
346
+ X = np.array([[trial.params['learning_rate']] for trial in study.trials])
347
+ y = np.array([trial.value for trial in study.trials])
348
+
349
+ # Create mask for finite values
350
+ finite_mask = np.isfinite(y)
351
+ X = X[finite_mask]
352
+ y = y[finite_mask]
353
+
354
+ # Check if we have enough valid points
355
+ if len(X) < 2:
356
+ print("Not enough valid points for GPR visualization")
357
+ return
358
+
359
+ # Create prediction points
360
+ X_pred = np.logspace(np.log10(X.min()), np.log10(X.max()), 100).reshape(-1, 1)
361
+ X_pred_log = np.log10(X_pred)
362
+
363
+ # Fit GPR for plotting
364
+ kernel = ConstantKernel(1.0) * Matern(length_scale=1.0, nu=2.5)
365
+ gpr = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=10, random_state=42)
366
+ gpr.fit(np.log10(X), y)
367
+
368
+ # Predict mean and std
369
+ y_pred, sigma = gpr.predict(X_pred_log, return_std=True)
370
+
371
+ plt.figure(figsize=(12, 6))
372
+ plt.semilogx(X, y, 'ko', label='Valid Trials', markersize=8)
373
+ plt.semilogx(X_pred, y_pred, 'b-', label='GPR Mean')
374
+ plt.fill_between(X_pred.ravel(),
375
+ y_pred - 2*sigma,
376
+ y_pred + 2*sigma,
377
+ color='blue',
378
+ alpha=0.2,
379
+ label='95% Confidence')
380
+
381
+ # Only plot optimal lines if they are finite
382
+ if np.isfinite(final_optimization['gpr_optimal_lr']):
383
+ plt.axvline(final_optimization['gpr_optimal_lr'], color='r', linestyle='--',
384
+ label='GPR Optimal LR')
385
+ if np.isfinite(final_optimization['ei_optimal_lr']):
386
+ plt.axvline(final_optimization['ei_optimal_lr'], color='g', linestyle='--',
387
+ label='EI Optimal LR')
388
+
389
+ plt.xlabel('Learning Rate')
390
+ plt.ylabel('Loss')
391
+ plt.title('Learning Rate Optimization Results with GPR')
392
+ plt.legend()
393
+ plt.grid(True)
394
+ plt.savefig('lr_optimization_plot.png', dpi=300, bbox_inches='tight')
395
+ plt.close()
396
+
397
+ plot_gpr_results(study, final_optimization)
398
+
399
+ # Save all results
400
+ with open("lr_optimization_results.json", "w") as f:
401
+ json.dump(results, f, indent=4)