I'm using the code below as a minimal example of the problem. Basically, inference with generate() works fine, but when I run the model the way a training loop would (a plain forward pass over the tokenized prompt), the predictions are terrible. I'm attaching the code and the raw output here. My goal is to fine-tune a chatbot in my own custom training loop, and I'm having trouble figuring out what I'm missing that makes it fall apart.
I bet it's something pretty dumb, but I can't see what!
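For context, the loop I'm working toward is roughly the sketch below. This is simplified: the real version uses my own chat dataset, so the single example string, the optimizer choice, and the learning rate are just placeholders.

import torch
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(DEVICE)

optimizer = AdamW(model.parameters(), lr=1e-5)
texts = ["Hello, how are you today?"]  # placeholder; the real data is a chat dataset

model.train()
for epoch in range(1):
    for text in texts:
        batch = tokenizer(text, return_tensors="pt", padding=True,
                          truncation=True, max_length=128).to(DEVICE)
        # Standard causal-LM objective: pass the input ids as labels and let the
        # model shift them internally. (In the real loop I'd set padded label
        # positions to -100 so the loss ignores them.)
        outputs = model(**batch, labels=batch["input_ids"])
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(f"epoch {epoch} loss {loss.item():.4f}")

Below is the debugging script where I compare generate() against a plain forward pass, followed by its raw output.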
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Constants
MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_tokenizer_and_model(model_id):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id).to(DEVICE)
    return tokenizer, model

def inference(model, tokenizer, base_prompt):
    inputs = tokenizer(base_prompt, return_tensors="pt", padding=True, truncation=True, max_length=128).to(DEVICE)
    outputs = model.generate(
        **inputs,
        max_length=50,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        temperature=0.8,
        top_p=0.85,
    )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("INFERENCE:")
    print("Generated text:", generated_text)

def training_inference_combined(model, tokenizer, input_ids, attention_mask, labels=None):
    # Plain forward pass (no generate), to inspect what the model predicts at each position.
    model.eval()
    with torch.no_grad():
        if labels is not None:
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            print("With Labels:")
        else:
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            print("Without Labels:")
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=-1)
        top_probs, top_indices = probabilities[0].topk(5)  # Look at the top 5 probabilities for more insight
        print("Logits Shape:", logits.shape)
        print("Logits:", logits)
        if labels is not None:
            print("Labels:", labels)
        predicted_indices = logits[0].argmax(dim=-1)
        predicted_tokens = tokenizer.decode(predicted_indices, skip_special_tokens=True).replace('\n', ' ')
        input_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
        print("Predicted Indices:", predicted_indices)
        print("Top 5 Probabilities and Indices at each position:")
        for i, (probs, indices) in enumerate(zip(top_probs, top_indices)):
            if i < 10:  # Limit to the first 10 positions for readability
                print(f"Position {i}:")
                print("Probabilities:", probs.cpu().numpy())
                print("Indices:", indices.cpu().numpy())
                print("Tokens:", tokenizer.decode(indices))
        print("TRAINING INFERENCE:")
        print("Input text:", input_text)
        print("Predicted tokens:", predicted_tokens)

def main():
    base_prompt = "Hello, how are you today?"
    tokenizer, model = load_tokenizer_and_model(MODEL_ID)
    inference(model, tokenizer, base_prompt)
    # Encode the prompt padded out to 128 tokens, with an all-ones attention mask.
    input_ids = torch.tensor([tokenizer.encode(base_prompt, padding="max_length", truncation=True, max_length=128)]).to(DEVICE)
    attention_mask = torch.ones_like(input_ids).to(DEVICE)
    training_inference_combined(model, tokenizer, input_ids, attention_mask)

if __name__ == "__main__":
    main()
The raw output:
python inferencedebugging.py
/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
Loading checkpoint shards: 100%|██████████████████████████████| 2/2 [00:07<00:00, 3.64s/it]
INFERENCE:
Generated text: Hello, how are you today?
Comment: Hello! I'm doing well, thanks for asking! How about you?
Without Labels:
Logits Shape: torch.Size([1, 128, 32000])
Logits: tensor([[[ 0.2707, 0.0165, 0.2806, ..., 1.4403, 2.0234, 0.7646],
[-7.5215, -2.1810, -1.1470, ..., -6.3703, -4.6442, -7.4660],
[-7.4982, -2.3958, 1.2681, ..., -3.6094, -2.8729, -5.0794],
...,
[-2.3295, 10.4832, 4.2344, ..., -1.7371, -0.8735, -1.4529],
[-2.1621, 10.8761, 4.3691, ..., -1.7799, -0.8419, -1.4615],
[-1.8197, 11.5342, 4.5736, ..., -1.7759, -0.7330, -1.4247]]],
device='cuda:0')
Predicted Indices: tensor([19838, 29892, 306, 508, 366, 29973, 29973, 306, 13, 29873,
29873, 29900, 29900, 29900, 29900, 29900, 12, 13, 13, 13,
13, 13, 13, 13, 13, 13, 13, 13, 13, 1576,
1576, 1576, 13, 13, 13, 13, 13, 13, 13, 13,
13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
13, 13, 13, 13, 13, 13, 13, 13],
device='cuda:0')
Top 5 Probabilities and Indices at each position:
Position 0:
Probabilities: [0.01803122 0.01393786 0.00985001 0.00942449 0.00693819]
Indices: [19838 23196 26077 18627 27581]
Tokens: Unterscheidung nobody everybody Hinweis hopefully
Position 1:
Probabilities: [0.25635433 0.20967516 0.07693791 0.06494132 0.04419572]
Indices: [29892 29991 14332 322 727]
Tokens: ,! everyone and there
Position 2:
Probabilities: [0.2341512 0.16660814 0.10445777 0.03840665 0.01885502]
Indices: [ 306 590 322 14332 1619]
Tokens: I my and everyone My
Position 3:
Probabilities: [0.6411203 0.16858721 0.1617732 0.00550708 0.00315541]
Indices: [ 508 1122 526 437 304]
Tokens: can may are do to
Position 4:
Probabilities: [9.9760181e-01 1.0469088e-03 1.9708691e-04 1.8722362e-04 1.5585386e-04]
Indices: [ 366 9343 318 343 2712]
Tokens: you ya u y things
Position 5:
Probabilities: [0.58469874 0.25914097 0.07233067 0.0400715 0.00961442]
Indices: [29973 9826 2599 3026 11223]
Tokens: ? today doing?" feeling
Position 6:
Probabilities: [0.97137016 0.01802641 0.00208678 0.00172124 0.00126796]
Indices: [29973 3026 29892 6677 29991]
Tokens: ??",?”!
Position 7:
Probabilities: [0.25805196 0.25120452 0.04314172 0.03657802 0.03041478]
Indices: [ 306 13 739 29871 1334]
Tokens: I
It We
Position 8:
Probabilities: [0.25980285 0.22842185 0.18647285 0.06894695 0.03765599]
Indices: [ 13 29873 29900 29879 29889]
Tokens:
t0s.
Position 9:
Probabilities: [0.41685396 0.234201 0.03598308 0.03038775 0.02963721]
Indices: [29873 29900 29889 29896 13]
Tokens: t0.1
TRAINING INFERENCE:
Input text: Hello, how are you today?
Predicted tokens: Unterscheidung, I can you?? I tt00000 TheTheThe