import random

from argparse import ArgumentParser

import torch

from torch.cuda import is_available as cuda_is_available

from model import GPT, GPTWithLoRA
from data import Alpaca

import tiktoken


def main():
    parser = ArgumentParser(
        description="Generate text from the model given a prompt.",
    )

    parser.add_argument("--checkpoint_path", default="./out/checkpoint.pt", type=str)
    parser.add_argument("--lora_path", default=None, type=str)
    parser.add_argument("--max_tokens", default=1000, type=int)
    parser.add_argument("--temperature", default=1.0, type=float)
    parser.add_argument("--top_k", default=500, type=int)
    parser.add_argument("--top_p", default=0.9, type=float)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--seed", default=None, type=int)

    args = parser.parse_args()

    if "cuda" in args.device and not cuda_is_available():
        raise RuntimeError("CUDA is not available.")

    torch.set_float32_matmul_precision("high")

    # Seed both generators for reproducible sampling. Check against None
    # explicitly so that a seed of 0 is not silently ignored.
    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    checkpoint = torch.load(
        args.checkpoint_path, map_location=args.device, weights_only=True
    )

    tokenizer = tiktoken.get_encoding(checkpoint["token_encoding"])

    model = GPT(**checkpoint["model_args"])

    # Compile before loading so the state dict keys match those saved
    # from the compiled training model.
    model = torch.compile(model)

    model.load_state_dict(checkpoint["model"])

    print("Model checkpoint loaded")

    if args.lora_path:
        checkpoint = torch.load(
            args.lora_path, map_location=args.device, weights_only=True
        )

        model = GPTWithLoRA(model, **checkpoint["lora_args"])

        model = torch.compile(model)

        # The LoRA checkpoint contains only the adapter weights, so load it
        # non-strictly, then fold the adapters into the base weights.
        model.load_state_dict(checkpoint["lora"], strict=False)

        model.merge_lora_parameters()

        print("LoRA checkpoint loaded")

    model.to(args.device)

    model.eval()

    while True:
        prompt = input("Enter a prompt: ")

        # Instruction-tuned checkpoints expect the Alpaca prompt template,
        # so wrap the raw prompt before tokenizing.
        if args.lora_path:
            context = input("Additional context (leave blank for none): ")

            if len(context) > 0:
                prompt = Alpaca.PROMPT_TEMPLATE_WITH_INPUT.format(
                    input=context, instruction=prompt
                )
            else:
                prompt = Alpaca.PROMPT_TEMPLATE.format(instruction=prompt)

        prompt = tokenizer.encode_ordinary(prompt)

        prompt = torch.tensor(prompt, dtype=torch.int64, device=args.device)

        # Stream tokens to stdout as they are sampled, decoding each one
        # individually and replacing any invalid UTF-8 sequences.
        for token in model.generate(
            prompt, args.max_tokens, args.temperature, args.top_k, args.top_p
        ):
            out = tokenizer.decode_single_token_bytes(token).decode(
                "utf-8", errors="replace"
            )

            print(out, end="", flush=True)

        print("\n")

        if "y" not in input("Go again? (yes|no): ").lower():
            break


if __name__ == "__main__":
    main()
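
# A minimal usage sketch. The filename generate.py is an assumption; the
# flags and default paths come from the argument parser above:
#
#   python generate.py --checkpoint_path ./out/checkpoint.pt \
#       --lora_path ./out/lora.pt --temperature 0.8 --seed 42
#
# Omit --lora_path to sample from the base model without the Alpaca
# instruction template.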