diff --git a/train.py b/train.py
index ca0912c..fa952cd 100644
--- a/train.py
+++ b/train.py
@@ -89,8 +89,10 @@ if ddp:
     torch.cuda.set_device(device)
     master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
    seed_offset = ddp_rank # each process gets a different seed
-    assert gradient_accumulation_steps % torch.cuda.device_count() == 0
-    gradient_accumulation_steps //= torch.cuda.device_count()
+    # world_size number of processes will be training simultaneously, so we can scale
+    # down the desired gradient accumulation iterations per process proportionally
+    assert gradient_accumulation_steps % ddp_world_size == 0
+    gradient_accumulation_steps //= ddp_world_size
 else:
     # if not ddp, we are running on a single gpu, and one process
     master_process = True