metadata
license: mit
nGPT-enwiki8
A small nGPT model trained on enwik8 for testing purposes with nGPT-pytorch.
Inference
- Download the weights file from this repo:
wget -O ./nGPT_best.pt "https://huggingface.co/pszemraj/nGPT-enwiki8/resolve/main/nGPT_best.pt"
- Install the dependencies:
# assuming you already have torch
pip install fire nGPT-pytorch
- Run inference with the command below, using the `inference.py` script that follows:
python inference.py ./nGPT_best.pt "Once upon a time"
# inference.py
import json
import sys
from pathlib import Path
import fire
import torch
from nGPT_pytorch import nGPT
def exists(v):
    """Return True unless ``v`` is None (falsy values such as 0 or "" still count)."""
    return False if v is None else True
def decode_token(token):
    """Map a single token id to a one-character string.

    Ids at or below 32 (control characters such as newline, and anything
    negative) are rendered as a space, chr(32); larger ids go through chr().
    """
    code_point = token if token > 32 else 32
    return chr(code_point)
def decode_tokens(tokens):
    """Concatenate ``decode_token`` applied to every element of ``tokens``."""
    return "".join(decode_token(tok) for tok in tokens)
def log(t, eps=1e-20):
    """Natural log of ``t`` after flooring values at ``eps``.

    The floor keeps zeros (and negatives) from producing -inf/NaN, which
    matters when taking logs of uniform samples in ``gumbel_noise``.
    """
    floored = t.clamp(min=eps)
    return floored.log()
def gumbel_noise(t):
    """Sample standard Gumbel noise shaped like ``t``: -log(-log(U)), U ~ Uniform[0, 1).

    ``log`` clamps its input, so a uniform draw of exactly 0 stays finite.
    """
    uniform = torch.empty_like(t).uniform_(0, 1)
    inner = -log(uniform)
    return -log(inner)
def gumbel_sample(t, temperature=1.0, dim=-1, keepdim=True):
    """Pick an index along ``dim`` via the Gumbel-max trick.

    The logits are divided by ``temperature`` (floored at 1e-10 so a zero
    temperature behaves like greedy argmax instead of dividing by zero). Gumbel
    noise is then added, and the argmax is taken.
    """
    safe_temperature = max(temperature, 1e-10)
    scaled = t / safe_temperature
    perturbed = scaled + gumbel_noise(t)
    return perturbed.argmax(dim=dim, keepdim=keepdim)
def min_p_filter(logits, min_p=0.1):
    """Min-p filtering over the last dimension.

    Any logit whose softmax probability is below ``min_p`` times the largest
    probability in its row is replaced with -inf. All other logits are returned
    unchanged, in a new tensor.
    """
    probabilities = logits.softmax(dim=-1)
    threshold = probabilities.amax(dim=-1, keepdim=True) * min_p
    below = probabilities < threshold
    return logits.masked_fill(below, float("-inf"))
def base_decoding(
    net,
    prompt: torch.Tensor,
    seq_len: int,
    temperature=1.5,
    min_p=1e-1,
    filter_thres=0.9,
):
    """Autoregressively extend ``prompt`` until it is ``seq_len`` tokens long.

    Each step runs ``net`` on the whole sequence so far, keeps the logits at
    the last position, applies min-p filtering, Gumbel-samples one token, and
    appends it along the last dimension.

    Args:
        net: Callable returning logits; ``net(x)[:, -1]`` is taken as the
            next-token logits (assumes output is (batch, seq, vocab) — TODO
            confirm against the model).
        prompt: Token ids; the last dimension is the sequence dimension.
        seq_len: Target TOTAL length, prompt included (not the number of new
            tokens). If it is not larger than the prompt, nothing is sampled.
        temperature: Sampling temperature passed to ``gumbel_sample``.
        min_p: Threshold passed to ``min_p_filter``.
        filter_thres: Unused; kept for signature compatibility.

    Returns:
        Only the newly sampled tokens (the prompt prefix is sliced off).
    """
    n_prompt = prompt.shape[-1]
    tokens = prompt.clone()
    steps = max(0, seq_len - n_prompt)
    for _ in range(steps):
        last_logits = min_p_filter(net(tokens)[:, -1], min_p=min_p)
        next_token = gumbel_sample(last_logits, temperature=temperature, dim=-1)
        tokens = torch.cat((tokens, next_token), dim=-1)
    return tokens[..., n_prompt:]
def main(
    checkpoint_path: str,
    prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
    min_p: float = 0.1,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Generate text using a trained nGPT model.

    Args:
        checkpoint_path: Path to a checkpoint saved with ``torch.save``. It must
            contain a ``"model_state_dict"`` entry and may contain a
            ``"config"`` dict. Otherwise a ``config.json`` next to the file is
            used if present.
        prompt: Non-empty text to continue.
        max_new_tokens: Number of tokens to sample after the prompt.
        temperature: Sampling temperature forwarded to ``base_decoding``.
        min_p: Min-p filter threshold forwarded to ``base_decoding``.
        device: Torch device for the model and inputs.

    Returns:
        The generated continuation (without the prompt) as a string.

    Exits with status 1 if the checkpoint is missing or the prompt is empty.
    """
    # python-fire parses CLI values as Python literals, so a numeric-looking
    # prompt (e.g. "42") can arrive as an int/float; normalize to text.
    prompt = str(prompt)

    # Load checkpoint
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        print(f"Error: Checkpoint not found at {checkpoint_path}")
        sys.exit(1)
    if not prompt:
        # Sampling conditions on the logits at the last prompt position, so an
        # empty prompt leaves nothing to condition on.
        print("Error: prompt must be a non-empty string")
        sys.exit(1)

    print(f"Loading checkpoint from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)

    # Prefer the config embedded in the checkpoint; fall back to config.json.
    config = checkpoint.get("config") or {}
    config_file = checkpoint_path.parent / "config.json"
    if not config and config_file.exists():
        config = json.loads(config_file.read_text(encoding="utf-8"))
    use_parametrize = config.get("use_parametrize", True)

    # Initialize model.
    # NOTE(review): architecture hyperparameters are hard-coded here and must
    # match the checkpoint; only use_parametrize is read from the config.
    model = nGPT(
        num_tokens=256,
        dim=512,
        depth=8,
        tied_embedding=True,
        add_value_residual=True,
        attn_norm_qk=False,
        manual_norm_weights=not use_parametrize,
    ).to(device)

    # Load weights
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    print("\nModel loaded successfully. Generating with:")
    print(f" Temperature: {temperature}")
    print(f" Min-p: {min_p}")
    print(f" Max new tokens: {max_new_tokens}")

    # The vocabulary has 256 entries (num_tokens=256 above), i.e. byte values.
    # Encode the prompt as UTF-8 bytes so every id is in [0, 255]; ord() on a
    # character above U+00FF would index past the embedding table. ASCII
    # prompts produce exactly the same ids as before.
    prompt_ids = list(prompt.encode("utf-8"))
    prompt_tensor = torch.tensor(prompt_ids, dtype=torch.long, device=device)
    prompt_tensor = prompt_tensor.unsqueeze(0)  # add batch dimension

    # Generate
    with torch.no_grad():
        sampled = base_decoding(
            model,
            prompt_tensor,
            # base_decoding's seq_len is the TOTAL length (prompt + new tokens).
            # Passing max_new_tokens alone would shorten the output by the
            # prompt length and produce nothing for long prompts.
            seq_len=prompt_tensor.shape[-1] + max_new_tokens,
            temperature=temperature,
            min_p=min_p,
        )

    generated = decode_tokens(sampled[0])
    print("\nGenerated text:")
    print("-" * 80)
    print(prompt + generated)
    print("-" * 80)
    return generated
if __name__ == "__main__":
    # Expose `main` as a command-line interface: positional arguments fill
    # checkpoint_path and prompt, other parameters are available as --flags.
    fire.Fire(component=main)