diff --git a/train.py b/train.py
index 16ad05e..e69553d 100644
--- a/train.py
+++ b/train.py
@@ -73,14 +73,19 @@ config = {k: globals()[k] for k in config_keys} # will be useful for logging
 ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
 if ddp:
     init_process_group(backend=backend)
-    gpu_id = int(os.environ['RANK'])
-    device = f'cuda:{gpu_id}'
+    DDP_RANK = int(os.environ['RANK'])
+    DDP_LOCAL_RANK = int(os.environ['LOCAL_RANK'])
+    device = f'cuda:{DDP_LOCAL_RANK}'
+    master_process = DDP_RANK == 0 # this process will do logging, checkpointing etc.
+    seed_offset = DDP_RANK # each process gets a different seed
 else:
-    gpu_id = 0 # gpu_id 0 means this is the (single) master process, basically
+    # if not ddp, we are running on a single gpu, and one process
+    master_process = True
+    seed_offset = 0
 
-if gpu_id == 0:
+if master_process:
     os.makedirs(out_dir, exist_ok=True)
-torch.manual_seed(1337 + gpu_id) # note: each worker gets a different seed
+torch.manual_seed(1337 + seed_offset)
 torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
 torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
 device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
@@ -170,7 +175,7 @@ if compile:
 
 # wrap model into DDP container
 if ddp:
-    model = DDP(model, device_ids=[gpu_id])
+    model = DDP(model, device_ids=[DDP_LOCAL_RANK])
 
 @torch.no_grad()
 def estimate_loss():
@@ -202,7 +207,7 @@ def get_lr(iter):
     return min_lr + coeff * (learning_rate - min_lr)
 
 # logging
-if wandb_log and gpu_id == 0:
+if wandb_log and master_process:
     import wandb
     wandb.init(project=wandb_project, name=wandb_run_name, config=config)
 
@@ -219,7 +224,7 @@ while True:
         lr = learning_rate
 
     # evaluate the loss on train/val sets and write checkpoints
-    if iter_num % eval_interval == 0 and gpu_id == 0:
+    if iter_num % eval_interval == 0 and master_process:
         losses = estimate_loss()
         print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
         if wandb_log:
@@ -265,7 +270,7 @@ while True:
     t1 = time.time()
     dt = t1 - t0
     t0 = t1
-    if iter_num % log_interval == 0 and gpu_id == 0:
+    if iter_num % log_interval == 0 and master_process:
         lossf = loss.item() # loss as float. TODO note CPU-GPU sync! profile, make sure not too slow
         print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
     iter_num += 1
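
Note (reviewer sketch, not part of the diff): the change hinges on the distinction between the RANK and LOCAL_RANK environment variables that torchrun sets for every worker. RANK is the global index across all nodes, so it drives master-process gating and the per-worker seed offset; LOCAL_RANK is the index within a single node, and it is the one that is a valid CUDA device ordinal once the job spans more than one machine. Below is a minimal, self-contained Python sketch of the same bookkeeping; the helper name ddp_env and the 'cuda:0' fallback are illustrative only, not code from train.py.

import os

def ddp_env():
    """Derive (device, master_process, seed_offset) the way the diff above does."""
    ddp = int(os.environ.get('RANK', -1)) != -1    # torchrun sets RANK; a plain `python train.py` does not
    if ddp:
        rank = int(os.environ['RANK'])             # global index: 0 .. world_size-1 across all nodes
        local_rank = int(os.environ['LOCAL_RANK']) # per-node index: 0 .. nproc_per_node-1
        device = f'cuda:{local_rank}'              # LOCAL_RANK picks the GPU; RANK would overflow past node 0
        master_process = rank == 0                 # only the global rank-0 process logs / checkpoints
        seed_offset = rank                         # every worker seeds its RNG differently
    else:
        # single-process run: the diff leaves `device` at its configured default; 'cuda:0' stands in here
        device, master_process, seed_offset = 'cuda:0', True, 0
    return device, master_process, seed_offset

For example, on a 2-node run with 4 GPUs per node launched via torchrun, the sixth worker sees RANK=5 and LOCAL_RANK=1, so it trains on 'cuda:1' and seeds with 1337 + 5, while only the RANK=0 worker creates out_dir, evaluates, and writes checkpoints.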