# bad-gpt/main.py
import logging

import torch

from bad_gpt import BadGPT

logging.basicConfig()
# Child of the package logger; resolves to 'bad_gpt.__main__' when this file
# is run as a script, so it inherits the level set on 'bad_gpt' below.
logger = logging.getLogger('bad_gpt').getChild(__name__)
# HYPERPARAMETERS #
### Impacts performance ###
BATCH_SIZE = 64  # number of token sequences processed in parallel
BLOCK_SIZE = 256  # length of a single token sequence (context length)
LEARNING_RATE = 1e-4
NUM_EMBEDDING_DIMENSIONS = 384  # embedding width; 384 / 6 heads = 64 dims per head
NUM_HEADS = 6  # attention heads per transformer block
NUM_LAYERS = 6  # transformer blocks in the model
MAX_ITERS = 5000  # total training iterations
### Others ###
EVAL_INTERVAL = 50  # iterations between loss evaluations
DROPOUT_RATE = 0.2  # dropout probability during training
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # prefer GPU when available
# --------------- #
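# Rough scale check (an estimate, not a value read from the model): GPT-style
# transformers are commonly approximated as ~12 * n_layers * n_embd**2
# parameters for the attention and feed-forward blocks (assuming the usual
# 4x MLP expansion), so these settings give roughly
# 12 * 6 * 384**2 ≈ 10.6M parameters, excluding the embedding tables.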
if __name__ == '__main__':
    # Enable debug output for the whole 'bad_gpt' logger hierarchy.
    logging.getLogger('bad_gpt').setLevel(logging.DEBUG)
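    # Optional, left commented out to keep behavior unchanged: fixing the RNG
    # seed makes training and sampling reproducible across runs
    # (torch.manual_seed is standard PyTorch; the seed value is arbitrary).
    # torch.manual_seed(1337)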
    bad_gpt = BadGPT(
        device=DEVICE,
        batch_size=BATCH_SIZE,
        block_size=BLOCK_SIZE,
        n_embd=NUM_EMBEDDING_DIMENSIONS,
        n_head=NUM_HEADS,
        n_layers=NUM_LAYERS,
        dropout=DROPOUT_RATE,
        eval_interval=EVAL_INTERVAL,
        iterations=MAX_ITERS,
        lr=LEARNING_RATE
    )
logger.info("Generating response...")
resp = bad_gpt.generate(
'JULIET:\nRomeo, Romeo, wherefore art thou Romeo?', 256)
logger.info("Response:\n" + resp)