diff --git a/train.py b/train.py index a66fa91..32a2eff 100644 --- a/train.py +++ b/train.py @@ -85,6 +85,7 @@ if ddp: ddp_rank = int(os.environ['RANK']) ddp_local_rank = int(os.environ['LOCAL_RANK']) device = f'cuda:{ddp_local_rank}' + torch.cuda.set_device(device) master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. seed_offset = ddp_rank # each process gets a different seed else: