---
license: mit
---

# nGPT-enwiki8

a small [nGPT](https://arxiv.org/abs/2410.01131) model trained on enwiki8 for testing purposes, using [nGPT-pytorch](https://github.com/lucidrains/nGPT-pytorch)

## inference

1. download a weights file from this repo

```sh
wget -O ./nGPT_best.pt "https://huggingface.co/pszemraj/nGPT-enwiki8/resolve/main/nGPT_best.pt"
```
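
alternatively, if you have `huggingface_hub` installed, its CLI can fetch the same file (a sketch, assuming a recent version that ships the `download` command):

```sh
huggingface-cli download pszemraj/nGPT-enwiki8 nGPT_best.pt --local-dir .
```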

2. install dependencies

```sh
# assuming you already have torch
pip install fire nGPT-pytorch
```

3. save the script below as `inference.py`, then run it:
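
```sh
python inference.py ./nGPT_best.pt "Once upon a time"
```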

```py
# inference.py
import json
import sys
from pathlib import Path

import fire
import torch
from nGPT_pytorch import nGPT


def exists(v):
    return v is not None


def decode_token(token):
    # clamp to >= 32 so control bytes render as printable characters
    return chr(max(32, token))


def decode_tokens(tokens):
    return "".join(map(decode_token, tokens))


def log(t, eps=1e-20):
    return torch.log(t.clamp(min=eps))


def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def gumbel_sample(t, temperature=1.0, dim=-1, keepdim=True):
    # gumbel-max trick: adding gumbel noise and taking the argmax
    # samples from softmax(t / temperature)
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim, keepdim=keepdim)


def min_p_filter(logits, min_p=0.1):
    # mask any token whose probability is below min_p times the top token's probability
    probs = logits.softmax(dim=-1)
    max_probs = probs.amax(dim=-1, keepdim=True)
    limit = min_p * max_probs
    return torch.where(probs < limit, float("-inf"), logits)


def base_decoding(
    net,
    prompt: torch.Tensor,
    seq_len: int,
    temperature=1.5,
    min_p=1e-1,
):
    prompt_seq_len, out = prompt.shape[-1], prompt.clone()
    sample_num_times = max(0, seq_len - prompt_seq_len)

    for _ in range(sample_num_times):
        # no KV cache: the full sequence is re-run through the model each step
        logits = net(out)
        logits = logits[:, -1]

        logits = min_p_filter(logits, min_p=min_p)
        sample = gumbel_sample(logits, temperature=temperature, dim=-1)

        out = torch.cat((out, sample), dim=-1)

    # return only the newly generated tokens
    return out[..., prompt_seq_len:]


def main(
    checkpoint_path: str,
    prompt: str,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
    min_p: float = 0.1,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Generate text using a trained nGPT model."""

    # Load checkpoint
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        print(f"Error: Checkpoint not found at {checkpoint_path}")
        sys.exit(1)

    print(f"Loading checkpoint from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)

    # Get config from the checkpoint, falling back to a config.json beside it
    config = checkpoint.get("config", {})
    if not config and checkpoint_path.parent.joinpath("config.json").exists():
        with open(checkpoint_path.parent.joinpath("config.json")) as f:
            config = json.load(f)

    use_parametrize = config.get("use_parametrize", True)

    # Initialize model (these hyperparameters must match the training run for this checkpoint)
    model = nGPT(
        num_tokens=256,
        dim=512,
        depth=8,
        tied_embedding=True,
        add_value_residual=True,
        attn_norm_qk=False,
        manual_norm_weights=not use_parametrize,
    ).to(device)

    # Load weights
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    print("\nModel loaded successfully. Generating with:")
    print(f"  Temperature: {temperature}")
    print(f"  Min-p: {min_p}")
    print(f"  Max new tokens: {max_new_tokens}")

    # Encode the prompt as raw character codes (the model uses a character-level vocab of 256)
    prompt_tensor = torch.tensor(
        [ord(c) for c in prompt], dtype=torch.long, device=device
    )
    prompt_tensor = prompt_tensor.unsqueeze(0)

    # Generate (seq_len is the total length, so add the prompt length to max_new_tokens)
    with torch.no_grad():
        sampled = base_decoding(
            model,
            prompt_tensor,
            seq_len=prompt_tensor.shape[-1] + max_new_tokens,
            temperature=temperature,
            min_p=min_p,
        )

    generated = decode_tokens(sampled[0].tolist())

    print("\nGenerated text:")
    print("-" * 80)
    print(prompt + generated)
    print("-" * 80)

    return generated


if __name__ == "__main__":
    fire.Fire(main)
```
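
`fire` maps `main`'s keyword arguments to command-line flags, so generation can be tuned from the shell, e.g. (flag values here are just illustrative):

```sh
python inference.py ./nGPT_best.pt "Once upon a time" --max_new_tokens 200 --temperature 0.8 --min_p 0.05
```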