diff --git a/train.py b/train.py
index c8f3594..77e713b 100644
--- a/train.py
+++ b/train.py
@@ -211,15 +211,15 @@ def estimate_loss():
     return out
 
 # learning rate decay scheduler (cosine with warmup)
-def get_lr(iter):
+def get_lr(it):
     # 1) linear warmup for warmup_iters steps
-    if iter < warmup_iters:
-        return learning_rate * iter / warmup_iters
-    # 2) if iter > lr_decay_iters, return min learning rate
-    if iter > lr_decay_iters:
+    if it < warmup_iters:
+        return learning_rate * it / warmup_iters
+    # 2) if it > lr_decay_iters, return min learning rate
+    if it > lr_decay_iters:
         return min_lr
     # 3) in between, use cosine decay down to min learning rate
-    decay_ratio = (iter - warmup_iters) / (lr_decay_iters - warmup_iters)
+    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
     assert 0 <= decay_ratio <= 1
     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
     return min_lr + coeff * (learning_rate - min_lr)