import logging

import torch

from bad_gpt import BadGPT

# Root logging config; this script's logger is a child of the 'bad_gpt'
# package logger, so its level can be controlled alongside the library's.
logging.basicConfig()
logger = logging.getLogger('bad_gpt').getChild(__name__)
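
# Optional: seed PyTorch's RNG here for reproducible training and sampling,
# e.g. torch.manual_seed(1337)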

# Model and training hyperparameters.
BATCH_SIZE = 64                   # independent sequences per training batch
BLOCK_SIZE = 256                  # maximum context length, in tokens
LEARNING_RATE = 1e-4
NUM_EMBEDDING_DIMENSIONS = 384    # embedding width
NUM_HEADS = 6                     # attention heads per transformer block
NUM_LAYERS = 6                    # transformer blocks in the model
MAX_ITERS = 5000                  # total training iterations
EVAL_INTERVAL = 50                # iterations between loss evaluations
DROPOUT_RATE = 0.2
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
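
# Sanity check; assumes the usual multi-head attention layout, in which each
# head takes an equal slice of the embedding width.
assert NUM_EMBEDDING_DIMENSIONS % NUM_HEADS == 0, \
    'embedding width must be divisible by the number of heads'
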

if __name__ == '__main__':
    # Emit debug output from the whole 'bad_gpt' package for this run.
    logging.getLogger('bad_gpt').setLevel(logging.DEBUG)

    # The full training configuration is supplied at construction time.
    bad_gpt = BadGPT(
        device=DEVICE,
        batch_size=BATCH_SIZE,
        block_size=BLOCK_SIZE,
        n_embd=NUM_EMBEDDING_DIMENSIONS,
        n_head=NUM_HEADS,
        n_layers=NUM_LAYERS,
        dropout=DROPOUT_RATE,
        eval_interval=EVAL_INTERVAL,
        iterations=MAX_ITERS,
        lr=LEARNING_RATE
    )

    # Prompt the model and sample 256 new tokens from it.
    logger.info("Generating response...")
    resp = bad_gpt.generate(
        'JULIET:\nRomeo, Romeo, wherefore art thou Romeo?', 256)
    logger.info("Response:\n%s", resp)