What does model.generate do that I'm not doing?

I am using the code below as an example of the problem. Basically, inference works fine, but when I run the same model in a custom training loop, the predictions are terrible. I'm attaching the code and the raw output. I want to fine-tune a chatbot in my own training loop, and I'm having trouble figuring out what I'm missing to keep the output from falling apart.

I bet you it’s something pretty dumb but I don’t know what!

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Constants
MODEL_ID = "meta-llama/Llama-2-7b-chat-hf"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_tokenizer_and_model(model_id):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id).to(DEVICE)
    return tokenizer, model

def inference(model, tokenizer, base_prompt):
    inputs = tokenizer(base_prompt, return_tensors="pt", padding=True, truncation=True, max_length=128).to(DEVICE)
    outputs = model.generate(
        **inputs,
        max_length=50,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        # temperature/top_p only apply when sampling is enabled
        # (do_sample=True here or in the model's generation_config)
        temperature=0.8,
        top_p=0.85
    )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("INFERENCE:")
    print("Generated text:", generated_text)

def training_inference_combined(model, tokenizer, input_ids, attention_mask, labels=None):
    model.eval()
    with torch.no_grad():
        if labels is not None:
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            print("With Labels:")
        else:
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            print("Without Labels:")

        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=-1)
        top_probs, top_indices = probabilities[0].topk(5)  # Look at the top 5 probabilities for more insight

        print("Logits Shape:", logits.shape)
        print("Logits:", logits)

        if labels is not None:
            print("Labels:", labels)

        # logits at position i predict token i+1, so this argmax sequence is
        # shifted one position relative to the input tokens
        predicted_indices = logits[0].argmax(dim=-1)
        predicted_tokens = tokenizer.decode(predicted_indices, skip_special_tokens=True).replace('\n', ' ')
        input_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)

        print("Predicted Indices:", predicted_indices)
        print("Top 5 Probabilities and Indices at each position:")
        for i, (probs, indices) in enumerate(zip(top_probs, top_indices)):
            if i >= 10:  # Limit to the first 10 positions for readability
                break
            print(f"Position {i}:")
            print("Probabilities:", probs.cpu().numpy())
            print("Indices:", indices.cpu().numpy())
            print("Tokens:", tokenizer.decode(indices))
        
        print("TRAINING INFERENCE:")
        print("Input text:", input_text)
        print("Predicted tokens:", predicted_tokens)

def main():
    base_prompt = "Hello, how are you today?"
    tokenizer, model = load_tokenizer_and_model(MODEL_ID)
    inference(model, tokenizer, base_prompt)
    
    input_ids = torch.tensor([tokenizer.encode(base_prompt, padding="max_length", truncation=True, max_length=128)]).to(DEVICE)
    # mask of all ones: every position, including the pad tokens added by
    # padding="max_length", is attended to as if it were real input
    attention_mask = torch.ones_like(input_ids).to(DEVICE)
    training_inference_combined(model, tokenizer, input_ids, attention_mask)

if __name__ == "__main__":
    main()

The raw output:

python inferencedebugging.py 
/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
Loading checkpoint shards: 100%|██████████████████████████████| 2/2 [00:07<00:00,  3.64s/it]
INFERENCE:
Generated text: Hello, how are you today?

Comment: Hello! I'm doing well, thanks for asking! How about you?
Without Labels:
Logits Shape: torch.Size([1, 128, 32000])
Logits: tensor([[[ 0.2707,  0.0165,  0.2806,  ...,  1.4403,  2.0234,  0.7646],
         [-7.5215, -2.1810, -1.1470,  ..., -6.3703, -4.6442, -7.4660],
         [-7.4982, -2.3958,  1.2681,  ..., -3.6094, -2.8729, -5.0794],
         ...,
         [-2.3295, 10.4832,  4.2344,  ..., -1.7371, -0.8735, -1.4529],
         [-2.1621, 10.8761,  4.3691,  ..., -1.7799, -0.8419, -1.4615],
         [-1.8197, 11.5342,  4.5736,  ..., -1.7759, -0.7330, -1.4247]]],
       device='cuda:0')
Predicted Indices: tensor([19838, 29892,   306,   508,   366, 29973, 29973,   306,    13, 29873,
        29873, 29900, 29900, 29900, 29900, 29900,    12,    13,    13,    13,
           13,    13,    13,    13,    13,    13,    13,    13,    13,  1576,
         1576,  1576,    13,    13,    13,    13,    13,    13,    13,    13,
           13,    13,    13,    13,    13,    13,    13,    13,    13,    13,
           13,    13,    13,    13,    13,    13,    13,    13,    13,    13,
           13,    13,    13,    13,    13,    13,    13,    13,    13,    13,
           13,    13,    13,    13,    13,    13,    13,    13,    13,    13,
           13,    13,    13,    13,    13,    13,    13,    13,    13,    13,
           13,    13,    13,    13,    13,    13,    13,    13,    13,    13,
           13,    13,    13,    13,    13,    13,    13,    13,    13,    13,
           13,    13,    13,    13,    13,    13,    13,    13,    13,    13,
           13,    13,    13,    13,    13,    13,    13,    13],
       device='cuda:0')
Top 5 Probabilities and Indices at each position:
Position 0:
Probabilities: [0.01803122 0.01393786 0.00985001 0.00942449 0.00693819]
Indices: [19838 23196 26077 18627 27581]
Tokens: Unterscheidung nobody everybody Hinweis hopefully
Position 1:
Probabilities: [0.25635433 0.20967516 0.07693791 0.06494132 0.04419572]
Indices: [29892 29991 14332   322   727]
Tokens: ,! everyone and there
Position 2:
Probabilities: [0.2341512  0.16660814 0.10445777 0.03840665 0.01885502]
Indices: [  306   590   322 14332  1619]
Tokens: I my and everyone My
Position 3:
Probabilities: [0.6411203  0.16858721 0.1617732  0.00550708 0.00315541]
Indices: [ 508 1122  526  437  304]
Tokens: can may are do to
Position 4:
Probabilities: [9.9760181e-01 1.0469088e-03 1.9708691e-04 1.8722362e-04 1.5585386e-04]
Indices: [ 366 9343  318  343 2712]
Tokens: you ya u y things
Position 5:
Probabilities: [0.58469874 0.25914097 0.07233067 0.0400715  0.00961442]
Indices: [29973  9826  2599  3026 11223]
Tokens: ? today doing?" feeling
Position 6:
Probabilities: [0.97137016 0.01802641 0.00208678 0.00172124 0.00126796]
Indices: [29973  3026 29892  6677 29991]
Tokens: ??",?”!
Position 7:
Probabilities: [0.25805196 0.25120452 0.04314172 0.03657802 0.03041478]
Indices: [  306    13   739 29871  1334]
Tokens: I
 It  We
Position 8:
Probabilities: [0.25980285 0.22842185 0.18647285 0.06894695 0.03765599]
Indices: [   13 29873 29900 29879 29889]
Tokens: 
t0s.
Position 9:
Probabilities: [0.41685396 0.234201   0.03598308 0.03038775 0.02963721]
Indices: [29873 29900 29889 29896    13]
Tokens: t0.1

TRAINING INFERENCE:
Input text: Hello, how are you today?
Predicted tokens: Unterscheidung, I can you?? I tt00000             TheTheThe        

I'm having a similar issue. Can somebody help?

I never solved this, but it is most likely a combination of the following:

  1. Not using the advanced decoding methods that inference uses (top-k, top-p, beam search); see the sketch after this list.
  2. The prompt format does not match the original training data.
  3. Hugging Face quirks.
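
To make point 1 concrete: the training code above takes the argmax at every prompt position from a single forward pass, while model.generate feeds each chosen token back into the model one step at a time. Below is a minimal sketch of that loop (my own illustration, not code from the post; greedy_decode is a hypothetical helper, and it skips the KV caching that makes real decoding fast):

import torch

def greedy_decode(model, tokenizer, prompt, max_new_tokens=20):
    # Autoregressive decoding: each new token is chosen from the logits at
    # the LAST position, then appended and fed back in. This is roughly
    # what model.generate does with sampling disabled.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    model.eval()
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(input_ids=input_ids).logits
            next_id = logits[0, -1].argmax()
            if next_id.item() == tokenizer.eos_token_id:
                break
            input_ids = torch.cat([input_ids, next_id.view(1, 1)], dim=-1)
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

Even this greedy loop will look far more coherent than the per-position argmax, because each step conditions on the previously generated tokens; the per-position guesses never condition on each other, so they don't form coherent text.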

While I still feel the outputs should at least be in the same ballpark, it’s understandable that the training output is poor.

My suggestion is to use Hugging Face sparingly. It's easy to use, but that also makes it opaque: I've spent a lot of time trying to figure out what the Hugging Face code is doing when I get unusual results. Unless you are doing "standard" work, it's best to avoid it. Building the system from lower-level code takes more time and may be more complex, but it normally pays off because it's easier to debug and understand.
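
As an illustration of what that lower-level code can look like, here is a rough sketch of nucleus (top-p) sampling written directly against the model's forward pass (my own sketch, not from this thread; sample_top_p is a hypothetical name, and it omits KV caching, batching, and repetition penalties that a production loop would add):

import torch

def sample_top_p(model, tokenizer, prompt, max_new_tokens=50,
                 temperature=0.8, top_p=0.85):
    # Hand-rolled top-p sampling: every decoding decision is explicit,
    # so intermediate values are easy to print and inspect.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    model.eval()
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(input_ids=input_ids).logits[0, -1] / temperature
            probs = torch.softmax(logits, dim=-1)
            sorted_probs, sorted_idx = probs.sort(descending=True)
            cumulative = sorted_probs.cumsum(dim=-1)
            # Keep the smallest prefix of tokens whose mass reaches top_p.
            keep = cumulative - sorted_probs < top_p
            sorted_probs = sorted_probs * keep
            choice = torch.multinomial(sorted_probs / sorted_probs.sum(), 1)
            next_id = sorted_idx[choice]
            if next_id.item() == tokenizer.eos_token_id:
                break
            input_ids = torch.cat([input_ids, next_id.view(1, 1)], dim=-1)
    return tokenizer.decode(input_ids[0], skip_special_tokens=True)

When the output looks wrong, you can print the probabilities at the exact step that went off the rails instead of digging through library internals.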

You can see another issue of mine that shows the potential weirdness you can run into with Hugging Face methods.