diff --git a/train.py b/train.py index 6410de4..4ef8e2b 100644 --- a/train.py +++ b/train.py @@ -190,6 +190,7 @@ scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type) if init_from == 'resume': optimizer.load_state_dict(checkpoint['optimizer']) +checkpoint = None # free up memory # compile the model if compile: