File size: 4,058 Bytes
ca4f9b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
#!/bin/bash
# Slurm batch script that launches Megatron-DeepSpeed/pretrain_gpt.py via srun.
# It may be submitted with sbatch or executed directly; when executed directly
# it submits itself (see the guard after the directives).
#SBATCH --exclude=nid007571,nid007112,nid006774,nid007502,nid007506,nid007507,nid005145,nid006692,nid007218,nid007123,nid006124,nid006123,nid007496,nid007237,nid006852,nid007206,nid006947,nid007212,nid006977,nid007222,nid005444,nid007219,nid007493,nid007221,nid005300,nid005619,nid006118,nid005203,nid006113,nid006481,nid007077,nid005208,nid005207,nid005879,nid005901
#SBATCH --nodes=32
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=32
#SBATCH --mem=256G
#SBATCH -p pilot
#SBATCH -t 48:00:00
#SBATCH --gpus-per-node=mi250:8
#SBATCH --exclusive=user
#SBATCH --hint=nomultithread
#SBATCH --account=project_462000119
#SBATCH -o logs/%j.out
#SBATCH -e logs/%j.err

# Run label; the kill-switch, checkpoint and tensorboard names are derived
# from it further down.
VARIANT=1b121b

# Self-submission guard: if SLURM_JOB_ID is unset or empty the script is not
# running inside a job, so create the directory the -o/-e directives write to,
# hand this file to sbatch, and stop (exit status is sbatch's).
# The test is quoted with an explicit empty default so it is well-formed for
# an unset variable instead of relying on `[ -z ]` being true by accident.
if [[ -z "${SLURM_JOB_ID:-}" ]]; then
  mkdir -p logs
  sbatch "$0"
  exit
fi
# From here on: abort on command failure, on use of an unset variable, and
# when any stage of a pipeline fails.
set -o errexit -o nounset -o pipefail

# Point logs/latest.out and logs/latest.err at this job's log files.
# The link targets are relative, so they resolve inside logs/.
for stream in out err; do
  ln -f -s "${SLURM_JOB_ID}.${stream}" "logs/latest.${stream}"
done

# Per-variant names for the kill switch, checkpoints and tensorboard output.
KILL_SWITCH_PATH="kill-switch-${VARIANT}"
CHECKPOINT_PATH="checkpoints_${VARIANT}"
TENSORBOARD_PATH="tensorboard_${VARIANT}"
# Data: values later passed to --vocab-file, --merge-file and --data-path.
VOCAB_FILE=gpt2/vocab.json
MERGE_FILE=gpt2/merges.txt
DATA_PATH=/scratch/project_462000119/data/pile/megatron_data/meg-gpt2_pile_text_document

# Parallelism degrees (pipeline / tensor) and per-step batch settings.
PP_SIZE=1
TP_SIZE=1
MICRO_BATCH_SIZE=1
GRADIENT_ACCUMULATION_STEPS=1

# WORLD_SIZE = GPUs per node x number of nodes, both taken from the Slurm
# environment; the global batch size scales with it.
WORLD_SIZE=$(( SLURM_GPUS_ON_NODE * SLURM_JOB_NUM_NODES ))
GLOBAL_BATCH_SIZE=$(( MICRO_BATCH_SIZE * WORLD_SIZE * GRADIENT_ACCUMULATION_STEPS ))
# Model shape. model_params.sh is expected to define the PARAM_1143M array;
# its elements are read by position below.
source model_params.sh
MODEL_PARAM=("${PARAM_1143M[@]}")
NHIDDEN="${MODEL_PARAM[0]}"          # passed to --hidden-size
FFN_HIDDEN_SIZE="${MODEL_PARAM[1]}"  # passed to --ffn-hidden-size
KV_SIZE="${MODEL_PARAM[2]}"          # passed to --kv-channels
NHEADS="${MODEL_PARAM[3]}"           # passed to --num-attention-heads
NLAYERS="${MODEL_PARAM[4]}"          # passed to --num-layers
SEQ_LEN=2048

# Record the chosen shape in the job log.
printf 'Model parameters: d_model %s ffw_size %s kv_size %s n_heads %s n_layers %s\n' \
  "$NHIDDEN" "$FFN_HIDDEN_SIZE" "$KV_SIZE" "$NHEADS" "$NLAYERS"
# Value later passed to --save-interval.
SAVE_INTERVAL=10000

# Training budget: 20_800_000_000 tokens at 2048 tokens per sample
# -> 10156250 samples.
TRAIN_SAMPLES=10_156_250

# Optimizer and learning-rate schedule flags, listed one option per line.
optimizer_opts=(
  --optimizer adam
  --adam-beta1 0.9
  --adam-beta2 0.999
  --adam-eps 1e-8
  --lr 2e-4
  --min-lr 2e-5
  --lr-decay-style cosine
  --lr-decay-samples "$TRAIN_SAMPLES"
  --lr-warmup-samples 101_563
  --clip-grad 1.0
  --weight-decay 1e-1
)
# Flattened to one space-separated string because it is spliced, unquoted,
# into the larger argument strings built later in this script.
OPTIMIZER_ARGS=" ${optimizer_opts[*]} "
# Model, batch, tokenizer and precision flags for pretrain_gpt.py, followed by
# the optimizer flags.
# --clip-grad is intentionally NOT listed here: OPTIMIZER_ARGS already sets it,
# and a second copy placed before it on the command line would be silently
# overridden, so editing that copy would have no effect.
gpt_opts=(
  --num-layers "$NLAYERS"
  --hidden-size "$NHIDDEN"
  --num-attention-heads "$NHEADS"
  --kv-channels "$KV_SIZE"
  --ffn-hidden-size "$FFN_HIDDEN_SIZE"
  --seq-length "$SEQ_LEN"
  --max-position-embeddings "$SEQ_LEN"
  --micro-batch-size "$MICRO_BATCH_SIZE"
  --global-batch-size "$GLOBAL_BATCH_SIZE"
  --train-samples "$TRAIN_SAMPLES"
  --vocab-file "$VOCAB_FILE"
  --merge-file "$MERGE_FILE"
  --kill-switch-path "$KILL_SWITCH_PATH"
  --bf16
)
# Flattened to a single string; it is expanded unquoted (word-split) when the
# final command line is assembled.
GPT_ARGS=" ${gpt_opts[*]} $OPTIMIZER_ARGS "
# Logging, checkpoint-interval, evaluation and tensorboard flags.
output_opts=(
  --log-interval 10
  --save-interval "$SAVE_INTERVAL"
  --eval-interval 1000
  --eval-iters 1
  --tensorboard-dir "$TENSORBOARD_PATH"
  --tensorboard-queue-size 5
  --log-timers-to-tensorboard
  --log-batch-size-to-tensorboard
  --log-validation-ppl-to-tensorboard
)
# Joined into one space-separated string; consumers expand it unquoted.
OUTPUT_ARGS=" ${output_opts[*]} "
# DeepSpeed configuration: written to a per-job JSON file and referenced from
# the command line through --deepspeed_config.
ZERO_STAGE=0
mkdir -p ds_configs
DS_CONFIG_PATH="ds_configs/${SLURM_JOB_ID}.json"

# The heredoc delimiter is unquoted on purpose so the $VARIABLES inside are
# expanded when the file is written. The redirect target is quoted so the
# path is always treated as a single word.
cat > "$DS_CONFIG_PATH" <<EOF
{
  "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
  "train_batch_size": $GLOBAL_BATCH_SIZE,
  "gradient_clipping": 1.0,
  "zero_optimization": {
    "stage": $ZERO_STAGE
  },
  "bf16": {
    "enabled": true
  },
  "steps_per_print": 2000,
  "wall_clock_breakdown": false
}
EOF

# Flags that enable DeepSpeed and point it at the file written above.
DEEPSPEED_ARGS=" --deepspeed --deepspeed_config $DS_CONFIG_PATH --zero-stage $ZERO_STAGE "
# Assemble the argument list handed to launch.sh: the training entry point
# followed by every flag group defined above.
CMD="Megatron-DeepSpeed/pretrain_gpt.py"
CMD+=" --tensor-model-parallel-size $TP_SIZE"
CMD+=" --pipeline-model-parallel-size $PP_SIZE"
CMD+=" $GPT_ARGS"
CMD+=" $OUTPUT_ARGS"
# Saving and loading use the same directory.
CMD+=" --save $CHECKPOINT_PATH"
CMD+=" --load $CHECKPOINT_PATH"
CMD+=" --data-path $DATA_PATH"
CMD+=" --data-impl mmap"
CMD+=" --split 949,50,1"
CMD+=" $DEEPSPEED_ARGS"

# CMD is a flat string of whitespace-separated arguments, so the two unquoted
# expansions below rely on word splitting deliberately.
# shellcheck disable=SC2086
echo $CMD

# Start/end markers use SLURM_JOB_ID, the same variable the rest of this
# script uses, rather than the legacy SLURM_JOBID spelling.
echo "START $SLURM_JOB_ID: $(date)"
# bash launch_srun.sh $CMD
# shellcheck disable=SC2086
srun --label launch.sh $CMD
echo "END $SLURM_JOB_ID: $(date)"
|