nGPT-enwiki8 / README.md
pszemraj's picture
Update README.md
5f2f392 verified
metadata
license: mit

nGPT-enwiki8

A small nGPT model trained on enwik8 for testing purposes with nGPT-pytorch.

inference

  1. Download a weights file from this repo:
wget -O ./nGPT_best.pt "https://huggingface.co/pszemraj/nGPT-enwiki8/resolve/main/nGPT_best.pt"
  2. Install dependencies:
# assuming you already have torch
pip install fire nGPT-pytorch
  3. Run inference with the script below, e.g. `python inference.py ./nGPT_best.pt "Once upon a time"`:
# inference.py
import json
import sys
from pathlib import Path

import fire
import torch
from nGPT_pytorch import nGPT


def exists(v):
    """Report whether ``v`` holds a value, i.e. is anything other than ``None``."""
    is_missing = v is None
    return not is_missing


def decode_token(token):
    """Convert a single token id into a one-character string.

    Ids below 32 (the ASCII control range) are replaced by 32, a space,
    so the decoded text stays printable.
    """
    code_point = token if token >= 32 else 32
    return chr(code_point)


def decode_tokens(tokens):
    """Decode an iterable of token ids into one string, one character per id."""
    return "".join(decode_token(tok) for tok in tokens)


def log(t, eps=1e-20):
    """Return the natural log of tensor ``t`` after flooring it at ``eps``.

    The floor keeps the result finite when ``t`` contains zeros.
    """
    floored = torch.clamp(t, min=eps)
    return floored.log()


def gumbel_noise(t):
    """Draw noise shaped like ``t`` via the transform -log(-log(U)), U ~ Uniform(0, 1)."""
    uniform = torch.empty_like(t).uniform_(0, 1)
    inner = -log(uniform)
    return -log(inner)


def gumbel_sample(t, temperature=1.0, dim=-1, keepdim=True):
    """Pick indices from logits ``t`` using the Gumbel-max trick.

    The logits are divided by ``temperature``, which is floored at 1e-10 so
    a zero temperature does not divide by zero. Gumbel noise is then added,
    and the argmax is taken along ``dim``.
    """
    safe_temperature = max(temperature, 1e-10)
    perturbed = t / safe_temperature + gumbel_noise(t)
    return perturbed.argmax(dim=dim, keepdim=keepdim)


def min_p_filter(logits, min_p=0.1):
    """Mask out low-probability logits along the last dimension.

    A logit is set to -inf when its softmax probability is below
    ``min_p`` times the largest probability in the same row. Other logits
    are returned unchanged.
    """
    probs = logits.softmax(dim=-1)
    threshold = min_p * probs.amax(dim=-1, keepdim=True)
    drop = probs < threshold
    return logits.masked_fill(drop, float("-inf"))


def base_decoding(
    net,
    prompt: torch.Tensor,
    seq_len: int,
    temperature=1.5,
    min_p=1e-1,
    filter_thres=0.9,
):
    """Autoregressively extend ``prompt`` until it is ``seq_len`` tokens long.

    ``seq_len`` is the TOTAL target length, prompt included. If the prompt
    is already at least that long, nothing is sampled. At each step the
    whole sequence so far is fed to ``net``. The last position's logits are
    min-p filtered, and one token is drawn with Gumbel sampling.

    Returns only the newly sampled tokens (the prompt is sliced off).

    NOTE(review): assumes ``prompt`` is (batch, seq) and ``net`` returns
    (batch, seq, vocab) logits -- TODO confirm against nGPT.
    NOTE(review): ``filter_thres`` is accepted but never used here; it is
    kept only for signature compatibility.
    """
    num_prompt_tokens = prompt.shape[-1]
    tokens = prompt.clone()
    steps = max(0, seq_len - num_prompt_tokens)

    for _step in range(steps):
        next_logits = net(tokens)[:, -1]
        next_logits = min_p_filter(next_logits, min_p=min_p)
        next_token = gumbel_sample(next_logits, temperature=temperature, dim=-1)
        tokens = torch.cat((tokens, next_token), dim=-1)

    return tokens[..., num_prompt_tokens:]


def main(
    checkpoint_path: str,
    prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
    min_p: float = 0.1,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Generate text using a trained nGPT model.

    Args:
        checkpoint_path: Path to a checkpoint saved with ``torch.save``. It
            must contain a ``"model_state_dict"`` entry and may contain a
            ``"config"`` dict. If no config is embedded, a ``config.json``
            next to the checkpoint is used when present.
        prompt: Text to continue. Each character is encoded as its code
            point, so only characters with code points below 256 (the model's
            vocabulary size) are accepted.
        max_new_tokens: Number of tokens to generate AFTER the prompt.
        temperature: Sampling temperature forwarded to ``base_decoding``.
        min_p: Min-p filtering threshold forwarded to ``base_decoding``.
        device: Torch device for the model; defaults to CUDA when available,
            otherwise CPU.

    Returns:
        The generated continuation, without the prompt.
    """
    num_tokens = 256  # vocabulary size the model is built with below

    # Load checkpoint
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        print(f"Error: Checkpoint not found at {checkpoint_path}")
        sys.exit(1)

    # fire parses CLI arguments as Python literals, so a prompt such as "123"
    # arrives as an int; normalize it back to text before encoding.
    prompt = str(prompt)
    if not prompt:
        print("Error: prompt must not be empty")
        sys.exit(1)
    # Characters outside the vocabulary would otherwise surface as an opaque
    # embedding index error (or a device-side assert on CUDA).
    bad_chars = sorted({c for c in prompt if ord(c) >= num_tokens})
    if bad_chars:
        print(
            f"Error: prompt contains characters outside the "
            f"{num_tokens}-token vocabulary: {bad_chars}"
        )
        sys.exit(1)

    print(f"Loading checkpoint from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)

    # Get config from checkpoint or file
    config = checkpoint.get("config", {})
    config_file = checkpoint_path.parent.joinpath("config.json")
    if not config and config_file.exists():
        with open(config_file, encoding="utf-8") as f:
            config = json.load(f)

    use_parametrize = config.get("use_parametrize", True)

    # Initialize model
    model = nGPT(
        num_tokens=num_tokens,
        dim=512,
        depth=8,
        tied_embedding=True,
        add_value_residual=True,
        attn_norm_qk=False,
        manual_norm_weights=not use_parametrize,
    ).to(device)

    # Load weights
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    print("\nModel loaded successfully. Generating with:")
    print(f"  Temperature: {temperature}")
    print(f"  Min-p: {min_p}")
    print(f"  Max new tokens: {max_new_tokens}")

    # Convert prompt to a (1, prompt_len) tensor of code points
    prompt_tensor = torch.tensor(
        [ord(c) for c in prompt], dtype=torch.long, device=device
    )
    prompt_tensor = prompt_tensor.unsqueeze(0)

    # Generate. base_decoding's seq_len is the TOTAL length including the
    # prompt, so add the prompt length to get max_new_tokens new tokens.
    with torch.no_grad():
        sampled = base_decoding(
            model,
            prompt_tensor,
            seq_len=prompt_tensor.shape[-1] + max_new_tokens,
            temperature=temperature,
            min_p=min_p,
        )

    generated = decode_tokens(sampled[0])

    print("\nGenerated text:")
    print("-" * 80)
    print(prompt + generated)
    print("-" * 80)

    return generated


# Expose main() as a command-line interface via python-fire; positional and
# --flag arguments map onto main()'s parameters, e.g.
#   python inference.py ./nGPT_best.pt "Once upon a time" --max_new_tokens 200
if __name__ == "__main__":
    fire.Fire(main)