# StoryCrafterLLM / app.py
import os
import torch
import tiktoken
import gradio as gr
from model import GPTLanguageModel

# Initialize the GPT-2 BPE tokenizer via tiktoken
enc = tiktoken.get_encoding("gpt2")
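
# Illustrative check of the tokenizer API: encoding then decoding
# round-trips the text, e.g.
#   enc.decode(enc.encode("Once upon a time")) == "Once upon a time"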

# Select the compute device (GPU if available, otherwise CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model hyperparameters (must match the training configuration)
vocab_size = 50257  # GPT-2 BPE vocabulary size
n_heads = 8         # attention heads per transformer block
n_layers = 6        # number of transformer blocks
n_embd = 512        # embedding dimension
block_size = 128    # maximum context length
dropout = 0.1       # training-time dropout; not passed to the constructor below

# Create the GPT model instance
model = GPTLanguageModel(vocab_size, n_embd, block_size, n_layers, n_heads).to(device)

# Load the trained model weights, if present
if os.path.exists("model_weights.pth"):
    model.load_state_dict(torch.load("model_weights.pth", map_location=device))
model.eval()  # disable dropout for inference
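
# Note: if model_weights.pth is absent, the app still launches but generates
# from randomly initialized parameters. A training script would typically
# save a compatible checkpoint with, e.g.:
#   torch.save(model.state_dict(), "model_weights.pth")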

# Generate a story continuation from the user's prompt
def get_response(prompt):
    # Tokenize the input prompt into a (1, T) tensor of token IDs
    context = torch.tensor([enc.encode(prompt)], dtype=torch.long, device=device)
    # Crop the prompt to the model's context window, in case
    # model.generate does not crop it internally
    context = context[:, -block_size:]
    max_new_tokens = 200  # number of tokens to generate
    # Note: a sampling temperature (e.g. 0.8) could be applied if
    # model.generate in model.py accepts one; this call passes only
    # the context and the token budget
    # Sample continuation tokens without tracking gradients
    with torch.no_grad():
        generated_text_idx = model.generate(context, max_new_tokens)
    # Decode the generated token IDs back into text
    return enc.decode(generated_text_idx[0].tolist())
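
# Example usage (illustrative; assumes trained weights were loaded above):
#   print(get_response("Once upon a time"))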

def main():
    """Main function to run the app."""
    # Set up the Gradio interface: plain text in, generated story out
    iface = gr.Interface(fn=get_response, inputs="text", outputs="text", title="StoryCrafterLLM")
    iface.launch()


if __name__ == "__main__":
    main()