|
""" |
|
Sample example |
|
""" |
|
from contextlib import nullcontext |
|
import torch |
|
from model import GPTConfig, GPT |
|
from transformers import GPT2TokenizerFast |
|
from safetensors.torch import save_file, load_file |
|
|
|
|
|
start = "\n" |
|
num_samples = 10 |
|
max_new_tokens = 500 |
|
temperature = 0.8 |
|
top_k = 200 |
|
seed = 1337 |
|
device = 'cuda' |
|
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' |
|
|
|
|
|
# Make sampling reproducible: seed the CPU RNG and the CUDA RNG.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Allow TensorFloat-32 in matmuls and cuDNN ops (faster on GPUs that support
# it, at slightly reduced precision).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# torch.amp.autocast expects a bare device type, so reduce strings such as
# 'cuda:0' to 'cuda'; anything without 'cuda' in it is treated as CPU.
if 'cuda' in device:
    device_type = 'cuda'
else:
    device_type = 'cpu'

# Translate the configured dtype name into the corresponding torch dtype.
ptdtype = {
    'float32': torch.float32,
    'bfloat16': torch.bfloat16,
    'float16': torch.float16,
}[dtype]

# Mixed-precision autocast on the GPU; a do-nothing context on the CPU.
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the checkpoint directly onto the target device.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Recent PyTorch releases default to weights_only=True, which
# limits what may be unpickled — confirm this checkpoint loads under the
# installed version.
checkpoint = torch.load("pytorch_model.bin", map_location=device)

# Rebuild the architecture from the hyper-parameters stored in the checkpoint.
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)

state_dict = checkpoint['model']
# Weights saved from a torch.compile()-wrapped module typically carry an
# '_orig_mod.' key prefix; strip it so the keys match the plain GPT module.
unwanted_prefix = '_orig_mod.'
# Iterate over a snapshot of the keys because the dict is mutated in the loop
# (only keys are needed; the values are moved via pop()).
for k in list(state_dict):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)

# Switch to evaluation mode (affects layers such as dropout) and place the
# model on the sampling device.
model.eval()
model.to(device)
|
|
|
|
|
# GPT-2 BPE tokenizer, used to encode the prompt and decode generated ids.
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')


def encode(s):
    """Return the list of GPT-2 token ids for the string *s*."""
    return tokenizer.encode(s)


def decode(l):
    """Return the text for the sequence of GPT-2 token ids *l*."""
    return tokenizer.decode(l)
|
|
|
|
|
# Tokenize the prompt, build a 1-D tensor of token ids on the target device,
# and prepend a batch dimension of size 1 -> shape (1, len(start_ids)).
start_ids = encode(start)
x = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0)
|
|
|
|
|
# Draw num_samples continuations of the same prompt and print each one.
# Gradients are not needed for sampling; ctx applies autocast on CUDA and is a
# nullcontext on CPU.
with torch.no_grad(), ctx:
    # The loop index is unused, so bind it to `_` rather than rebinding `k`.
    for _ in range(num_samples):
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        # Decode the first (only) row of the batch; presumably it contains the
        # prompt followed by the generated tokens — confirm in GPT.generate.
        print(decode(y[0].tolist()))
        print('---------------')
|
|