diff --git a/train.py b/train.py
index 32a2eff..895ba94 100644
--- a/train.py
+++ b/train.py
@@ -233,13 +233,10 @@
 X, Y = get_batch('train') # fetch the very first batch
 t0 = time.time()
 while True:
-    # determine the learning rate for this iteration
-    if decay_lr:
-        lr = get_lr(iter_num)
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = lr
-    else:
-        lr = learning_rate
+    # determine and set the learning rate for this iteration
+    lr = get_lr(iter_num) if decay_lr else learning_rate
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
 
     # evaluate the loss on train/val sets and write checkpoints
     if iter_num % eval_interval == 0 and master_process: