"""
This training script can be run both on a single GPU in debug mode,
and also in a larger training run with distributed data parallel (DDP).

To run on a single GPU, example:
$ python train.py --batch_size=32 --compile=False

To run with DDP on 4 gpus on 1 node, example:
$ torchrun --standalone --nproc_per_node=4 train.py

To run with DDP across 2 nodes with 8 gpus each, example:
- Run on the first (master) node with example IP 123.456.123.456:
$ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
- Run on the worker node:
$ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
(If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1)
"""
import os
import time
import math
import pickle
import random
from contextlib import nullcontext

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group

from model import GPTConfig, GPT

# global seeding for reproducibility (seed is also folded into the deterministic batch keys in get_batch below)
seed = 3
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

torch.use_deterministic_algorithms(False)
# https://pytorch.org/docs/stable/notes/randomness.html#cuda-convolution-benchmarking
# we don't use convs so cudnn benchmarking nondeterminism shouldn't matter
# (for fully deterministic algorithms, also set CUBLAS_WORKSPACE_CONFIG=:4096:8)

# -----------------------------------------------------------------------------
# default config values; originally designed to train a gpt2 (124M) on OpenWebText,
# scaled down here to a smaller model (6 layers, 512 dim)
# I/O
out_dir = 'out'
eval_only = False # if True, script exits right after the first eval
always_save_checkpoint = True # if True, always save a checkpoint after each eval
init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
# wandb logging
wandb_log = False
wandb_project = 'owt'
wandb_run_name = 'gpt2'
# data
dataset = 'openwebtext'
# these make the total batch size be ~0.5M tokens
# 8 batch size * 1024 block size * 64 grad accum = 524,288
batch_size = 8
block_size = 1024
gradient_accumulation_steps = 8 * 8
# total training horizon: 3000 iters * ~0.5M tokens/iter ~= 1.6B tokens
max_iters = 3000
lr_decay_iters = 3000
# eval stuff
eval_interval = 500
eval_iters = 200
log_interval = 10
# data injection into training batches (see get_batch below)
data_injection_rate = 0.01
data_injection_mode = ["random", 50009, 49704]
# weight decay
weight_decay = 1e-1
# model
n_layer = 6
n_head = 8
n_embd = 512
dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
bias = False # do we use bias inside LayerNorm and Linear layers?
# adamw optimizer
learning_rate = 6e-4 # max learning rate
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
# learning rate decay settings
decay_lr = True # whether to decay the learning rate
warmup_iters = 500 # how many steps to warm up for
min_lr = 6e-4 # minimum learning rate (~= learning_rate/10 per Chinchilla upstream; set equal to learning_rate here, so the LR stays flat after warmup)
# DDP settings
backend = 'nccl' # 'nccl', 'gloo', etc.
# system
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
compile = True # use PyTorch 2.0 to compile the model to be faster
# -----------------------------------------------------------------------------
config_keys = [k for k, v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------
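# note: the --key=value overrides shown in the docstring assume a config-override step right
# after the config_keys line above; upstream nanoGPT does this by exec-ing a configurator.py
# that rewrites those globals before the config dict is built, roughly:
# exec(open('configurator.py').read()) # hypothetical here, only applies if configurator.py is present in this repo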

# various inits, derived attributes, I/O setup
ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
if ddp:
    init_process_group(backend=backend)
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
    seed_offset = ddp_rank # each process gets a different seed (note: not applied above, where all ranks use the same base seed)
    # world_size number of processes will be training simultaneously, so we can scale
    # down the desired gradient accumulation iterations per process proportionally
    assert gradient_accumulation_steps % ddp_world_size == 0
    gradient_accumulation_steps //= ddp_world_size
else:
    # if not ddp, we are running on a single gpu, and one process
    master_process = True
    seed_offset = 0
    ddp_world_size = 1
tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
print(f"tokens per iteration will be: {tokens_per_iter:,}")
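# with the default config this prints 524,288: gradient_accumulation_steps * ddp_world_size
# works out to 64 in both branches above, times batch_size 8 * block_size 1024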

if master_process:
    os.makedirs(out_dir, exist_ok=True)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
# note: float16 data type will automatically use a GradScaler
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# poor man's data loader
data_dir = "."
def get_batch(split, step):
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    if split == 'train':
        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    else:
        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
    # deterministic per-(split, step) RNG so batches are reproducible across runs and resumes;
    # note the key does not include the DDP rank, so all ranks sample the same indices
    d_rng = random.Random(f"{split}-{step}-{seed}")
    # TODO: the upper bound should arguably be len(data) - block_size - 1 so that y never runs
    # off the end of the array, but changing it would break determinism of existing runs
    ix = [d_rng.randint(0, len(data) - block_size) for _ in range(batch_size)]
    # ugly workaround: remap the one out-of-range index to 0; only differs when we hit that edge case
    ix = [(0 if (q == len(data) - block_size) else q) for q in ix]
    xs, ys = [torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix], [torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]
    match data_injection_mode:
        case ["random", t1, t2]:
            # with probability data_injection_rate, replace a sample with a random
            # binary sequence over the two tokens t1 and t2
            t1, t2 = sorted((t1, t2))
            for i in range(batch_size):
                if d_rng.random() < data_injection_rate:
                    seq = np.random.randint(0, 2, size=(block_size+1,), dtype=np.int64) * (t2 - t1) + t1
                    xs[i] = torch.tensor(seq[:-1], dtype=torch.int64)
                    ys[i] = torch.tensor(seq[1:], dtype=torch.int64)
        case None:
            pass
    x = torch.stack(xs)
    y = torch.stack(ys)
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y
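
# Optional sanity-check sketch of the deterministic sampling above (this helper is not part
# of the original training loop and is never called by it; the name is illustrative). It
# re-derives the per-(split, step) RNG exactly as get_batch does and confirms that the same
# key always yields the same sample indices.
def _same_indices_for_key(data_len, split='train', step='0-0'):
    rng_a = random.Random(f"{split}-{step}-{seed}")
    rng_b = random.Random(f"{split}-{step}-{seed}")
    ix_a = [rng_a.randint(0, data_len - block_size) for _ in range(batch_size)]
    ix_b = [rng_b.randint(0, data_len - block_size) for _ in range(batch_size)]
    return ix_a == ix_b # always True: random.Random seeded with the same string is reproducible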

# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
iter_num = 0
best_val_loss = 1e9

# attempt to derive vocab_size from the dataset
meta_path = os.path.join(data_dir, 'meta.pkl')
meta_vocab_size = None
if os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    meta_vocab_size = meta['vocab_size']
    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")

# model init
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
                  bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
if init_from == 'scratch':
    # init a new model from scratch
    print("Initializing a new model from scratch")
    # determine the vocab size we'll use for from-scratch training
    if meta_vocab_size is None:
        print("defaulting to GPT-2 vocab_size of 50304 (50257 rounded up for efficiency)")
    model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
elif init_from == 'resume':
    print(f"Resuming training from {out_dir}")
    # resume training from a checkpoint.
    ckpt_path = os.path.join(out_dir, 'ckpt1500.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    checkpoint_model_args = checkpoint['model_args']
    # force these config attributes to be equal otherwise we can't even resume training
    # the rest of the attributes (e.g. dropout) can stay as desired from command line
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = checkpoint_model_args[k]
    # create the model
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # fix the keys of the state dictionary :(
    # honestly no idea how checkpoints sometimes get this prefix, have to debug more
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    iter_num = checkpoint['iter_num']
    best_val_loss = checkpoint['best_val_loss']
elif init_from.startswith('gpt2'):
    print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
    # initialize from OpenAI GPT-2 weights
    override_args = dict(dropout=dropout)
    model = GPT.from_pretrained(init_from, override_args)
    # read off the created config params, so we can store them into checkpoint correctly
    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
        model_args[k] = getattr(model.config, k)
# crop down the model block size if desired, using model surgery
if block_size < model.config.block_size:
    model.crop_block_size(block_size)
    model_args['block_size'] = block_size # so that the checkpoint will have the right value
model.to(device)

# initialize a GradScaler. If enabled=False scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == 'resume':
    optimizer.load_state_dict(checkpoint['optimizer'])
checkpoint = None # free up memory

# compile the model
if compile:
    print("compiling the model... (takes a ~minute)")
    unoptimized_model = model
    model = torch.compile(model) # requires PyTorch 2.0

# wrap model into DDP container
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])

# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss(step):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            # note: the same (split, step) key is used for every k, so the sampled
            # indices repeat across all eval_iters batches of this evaluation
            X, Y = get_batch(split, step)
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)
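
# Worked example of the schedule above with the current settings (warmup_iters=500,
# learning_rate=6e-4, min_lr=6e-4, lr_decay_iters=3000):
#   get_lr(250)  = 6e-4 * 250/500 = 3e-4                (mid-warmup)
#   get_lr(1500) = min_lr + coeff*(6e-4 - 6e-4) = 6e-4  (cosine term vanishes since min_lr == learning_rate)
#   get_lr(3500) = min_lr = 6e-4                        (past lr_decay_iters)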

# logging
if wandb_log and master_process:
    import wandb
    wandb.init(project=wandb_project, name=wandb_run_name, config=config)

# training loop
X, Y = get_batch('train', f"{iter_num}-{0}") # fetch the very first batch
local_iter_num = 0 # number of iterations in the lifetime of this process
t0 = time.time()
raw_model = model.module if ddp else model # unwrap DDP container if needed
running_mfu = -1.0

while True:

    # determine and set the learning rate for this iteration
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # evaluate the loss on train/val sets and write checkpoints
    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss(iter_num)
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "train/loss": losses['train'],
                "val/loss": losses['val'],
                "lr": lr,
                "mfu": running_mfu*100, # convert to percentage
            })
        if losses['val'] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses['val']
            checkpoint = {
                'model': raw_model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'model_args': model_args,
                'iter_num': iter_num,
                'best_val_loss': best_val_loss,
                'config': config,
            }
            print(f"saving checkpoint to {out_dir}")
            torch.save(checkpoint, os.path.join(out_dir, f'ckpt{iter_num}.pt'))
    if iter_num == 0 and eval_only:
        break

    # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16
    for micro_step in range(gradient_accumulation_steps):
        if ddp:
            # in DDP training we only need to sync gradients at the last micro step.
            # the official way to do this is with model.no_sync() context manager, but
            # I really dislike that this bloats the code and forces us to repeat code
            # looking at the source of that context manager, it just toggles this variable
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
        with ctx:
            logits, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        X, Y = get_batch('train', f"{iter_num}-{micro_step + 1}")
        # backward pass, with gradient scaling if training in fp16
        scaler.scale(loss).backward()
    # clip the gradient
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        # get loss as float. note: this is a CPU-GPU sync point
        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5: # let the training loop settle a bit
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
    iter_num += 1
    local_iter_num += 1

    # termination conditions
    if iter_num > max_iters:
        break

if ddp:
    destroy_process_group()