Canstralian commited on
Commit
42cc6ec
·
verified ·
1 Parent(s): cfcc89b

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +54 -0
train.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train.py
2
+
3
+ import numpy as np
4
+ from transformers import Trainer, TrainingArguments, AutoModelForSequenceClassification, AutoTokenizer
5
+ from datasets import load_dataset
6
+
7
+ # Constants
8
+ MODEL_NAME = 'distilbert-base-uncased'
9
+ OUTPUT_DIR = './model_output'
10
+ EPOCHS = 3
11
+ BATCH_SIZE = 16
12
+ LEARNING_RATE = 5e-5
13
+
14
+ # Load dataset (example: IMDb sentiment analysis dataset)
15
+ dataset = load_dataset('imdb')
16
+
17
+ # Load tokenizer
18
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
19
+
20
+ # Preprocess data
21
+ def preprocess_function(examples):
22
+ return tokenizer(examples['text'], truncation=True)
23
+
24
+ tokenized_datasets = dataset.map(preprocess_function, batched=True)
25
+
26
+ # Load model
27
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
28
+
29
+ # Define training arguments
30
+ training_args = TrainingArguments(
31
+ output_dir=OUTPUT_DIR,
32
+ evaluation_strategy="epoch",
33
+ learning_rate=LEARNING_RATE,
34
+ per_device_train_batch_size=BATCH_SIZE,
35
+ per_device_eval_batch_size=BATCH_SIZE,
36
+ num_train_epochs=EPOCHS,
37
+ weight_decay=0.01,
38
+ )
39
+
40
+ # Create Trainer
41
+ trainer = Trainer(
42
+ model=model,
43
+ args=training_args,
44
+ train_dataset=tokenized_datasets['train'],
45
+ eval_dataset=tokenized_datasets['test'],
46
+ )
47
+
48
+ # Train the model
49
+ trainer.train()
50
+
51
+ # Save the model
52
+ trainer.save_model(OUTPUT_DIR)
53
+
54
+ print("Model trained and saved!")