HeatherFeist
committed on
Create train_dataset.py
Browse files- train_dataset.py +93 -0
train_dataset.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Install modules
# NOTE: `%pip ...` is IPython/Jupyter magic and is a SyntaxError in a plain
# .py script, so the install commands are kept here as instructions.
# Run them once from a shell before executing this file:
#
#     pip install --upgrade pip
#     pip install torch torchdata transformers datasets loralib peft pandas numpy
#
# Inside a notebook cell, prefix each command with `%` (e.g. `%pip install ...`).
|
4 |
+
|
5 |
+
# Import modules
|
6 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
7 |
+
from torch.utils.data import Dataset, DataLoader
|
8 |
+
from torch.optim import Adam
|
9 |
+
import pandas as pd
|
10 |
+
import torch
|
11 |
+
|
12 |
+
# Configuration values
|
13 |
+
model_name = "google/flan-t5-base"    # Base model to use (passed to from_pretrained)
training_file = "tarot_readings.csv"  # CSV file to use (read by create_tarot_dataset)
num_epochs = 3                        # Number of full passes over the training data
num_rows = 500                        # Number of CSV rows to use for training
device = "cpu"                        # cpu or cuda; training batches are moved here
|
18 |
+
|
19 |
+
# Convert CSV file to tokens for training
|
20 |
+
def create_tarot_dataset(csv_file, tokenizer, num_rows=None):
    """Load tarot readings from a CSV file and tokenize them for training.

    The CSV must provide the columns ``Card 1``, ``Card 2``, ``Card 3`` and
    ``Reading``. Whitespace following the delimiter (as in a header written
    ``Card 1, Card 2, Card 3, Reading``) is ignored, so both ``a,b`` and
    ``a, b`` styled files are accepted.

    Args:
        csv_file: Path or file-like object handed to ``pandas.read_csv``.
        tokenizer: Callable tokenizer. It is invoked with
            ``return_tensors='pt'`` and must return a mapping holding
            ``input_ids`` and ``attention_mask`` tensors with a leading
            batch dimension of size 1.
        num_rows: Optional number of leading rows to keep. A falsy value
            (``None`` or ``0``) keeps every row.

    Returns:
        A list with one dict per CSV row containing the keys ``input_ids``,
        ``attention_mask``, ``target_ids`` and ``target_attention_mask``.
        The list is empty when the CSV has no data rows.

    Raises:
        KeyError: If one of the required columns is missing.
    """
    required_columns = ("Card 1", "Card 2", "Card 3", "Reading")
    prompt_template = "Give me a one paragraph tarot reading if I pull the cards {}, {} and {}."
    # Shared tokenizer settings: pad/truncate everything to a fixed length so
    # the default DataLoader collation can stack the per-row tensors.
    encode_options = {
        "add_special_tokens": True,
        "padding": "max_length",
        "max_length": 128,
        "truncation": True,
        "return_tensors": "pt",
    }

    # skipinitialspace drops the blank after each delimiter so neither the
    # header names nor the cell values carry a stray leading space.
    data = pd.read_csv(csv_file, skipinitialspace=True)
    data.columns = data.columns.str.strip()

    missing = [name for name in required_columns if name not in data.columns]
    if missing:
        raise KeyError("Missing column(s) in {}: {}".format(csv_file, ", ".join(missing)))

    if num_rows:
        data = data.iloc[:num_rows]

    def tokenize(row):
        """Turn one CSV row into prompt/target token tensors."""
        prompt = prompt_template.format(row["Card 1"], row["Card 2"], row["Card 3"])
        reading = row["Reading"]

        inputs = tokenizer(prompt, **encode_options)
        target = tokenizer(reading, **encode_options)

        # squeeze(0) removes only the batch dimension added by return_tensors.
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "target_ids": target["input_ids"].squeeze(0),
            "target_attention_mask": target["attention_mask"].squeeze(0),
        }

    # A plain list build (instead of DataFrame.apply(...).tolist()) also works
    # when the frame has no rows.
    return [tokenize(row) for _, row in data.iterrows()]
|
38 |
+
|
39 |
+
# Train the model with dataset
|
40 |
+
def fine_tune_model(model, optimizer, batch, device):
    """Run a single optimization step on one batch and return its loss.

    Args:
        model: Seq2seq model; called with ``input_ids``, ``attention_mask``,
            ``labels`` and ``decoder_attention_mask`` and expected to return
            an object exposing ``.loss``.
        optimizer: Optimizer that updates the model's parameters.
        batch: Mapping with ``input_ids``, ``attention_mask``, ``target_ids``
            and ``target_attention_mask`` tensors.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        The batch loss as a Python float.
    """
    model.train()

    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    decoder_attention_mask = batch["target_attention_mask"].to(device)

    # Targets are padded to a fixed length. Replace the padded positions with
    # -100 so the loss ignores them instead of training the model to emit
    # padding. masked_fill returns a new tensor; the batch is left untouched.
    labels = batch["target_ids"].to(device).masked_fill(decoder_attention_mask == 0, -100)

    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        decoder_attention_mask=decoder_attention_mask,
    )
    loss = outputs.loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()
|
55 |
+
|
56 |
+
# Run inference using the provided model and 3 tarot cards
|
57 |
+
def tacot_reading(model, tokenizer, card1, card2, card3):
    """Generate, print and return a tarot reading for three cards.

    Args:
        model: Generative model exposing ``generate``, ``device``,
            ``training``, ``eval`` and ``train``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        card1: Name of the first card.
        card2: Name of the second card.
        card3: Name of the third card.

    Returns:
        The decoded completion text, with special tokens removed.
    """
    prompt = "Give me a one paragraph tarot reading if I pull the cards {}, {} and {}.".format(card1, card2, card3)

    inputs = tokenizer(prompt, return_tensors="pt")
    # Put the prompt tensors on the same device as the model's weights.
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs["attention_mask"].to(model.device)

    # Generate in eval mode so dropout is disabled, then restore whatever
    # mode the caller had the model in (training leaves it in train mode).
    was_training = model.training
    model.eval()
    try:
        generated = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=1000,
        )
    finally:
        if was_training:
            model.train()

    completion = tokenizer.decode(generated[0], skip_special_tokens=True)

    print("Prompt: {}".format(prompt))
    print("Response: {}".format(completion))
    print()

    return completion
|
68 |
+
|
69 |
+
if __name__ == "__main__":
    # Card spreads used to compare model output before and after fine-tuning.
    sample_spreads = [
        ("The moon", "Two of Swords", "Three of Wands"),
        ("The hermit", "Ace of Pentacles", "Judgement"),
        ("Seven of Cups", "The chariot", "King of Swords"),
    ]

    print("* Loading model [{}]...".format(model_name))
    # Keep the weights on the same device the training batches are moved to.
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    print("* Running 3 inferences (pre-training)...")
    for spread in sample_spreads:
        tacot_reading(model, tokenizer, *spread)

    print("* Creating dataset from [{}]...".format(training_file))
    dataset = create_tarot_dataset(training_file, tokenizer, num_rows)
    data_loader = DataLoader(dataset, batch_size=16, shuffle=True)

    print("* Training model for {} epochs..".format(num_epochs))
    optimizer = Adam(model.parameters(), lr=1e-4)
    for epoch in range(num_epochs):
        total_loss = 0.0
        for batch in data_loader:
            total_loss += fine_tune_model(model, optimizer, batch, device)
        # max(..., 1) avoids a ZeroDivisionError when the dataset is empty.
        print("Epoch {} average loss: {}".format((epoch + 1), (total_loss / max(len(data_loader), 1))))

    print("* Running 3 inferences (post-training)...")
    for spread in sample_spreads:
        tacot_reading(model, tokenizer, *spread)
|