1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-12-18 14:10:28 +00:00

get rid of gpu_id, the world is more complicated than that when world_size > 8

This commit is contained in:
Andrej Karpathy 2023-01-16 05:44:50 +00:00
parent f5e6ac8b02
commit c3dddbff3d

View File

@ -73,14 +73,19 @@ config = {k: globals()[k] for k in config_keys} # will be useful for logging
# Distributed (DDP) setup: decide whether this is a DDP run, pick the device
# string, and derive the master-process flag and the per-process seed offset.
# NOTE(review): this span is a diff hunk rendered without its +/- markers or
# indentation, so pre-commit and post-commit lines are interleaved. Going by
# the commit message ("get rid of gpu_id"), the lines that mention gpu_id look
# like the removed side and the DDP_RANK / DDP_LOCAL_RANK / master_process /
# seed_offset lines look like the added side -- confirm against the repository.
# A missing RANK environment variable falls back to -1, which makes ddp False.
ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
if ddp:
init_process_group(backend=backend)
gpu_id = int(os.environ['RANK'])
device = f'cuda:{gpu_id}'
# RANK and LOCAL_RANK are read separately: the device index comes from
# LOCAL_RANK, while the master/seed decisions below use RANK.
# Presumably this is because RANK can exceed the number of GPUs on one node
# in multi-node runs (the commit message mentions world_size > 8) -- TODO confirm.
DDP_RANK = int(os.environ['RANK'])
DDP_LOCAL_RANK = int(os.environ['LOCAL_RANK'])
device = f'cuda:{DDP_LOCAL_RANK}'
master_process = DDP_RANK == 0 # this process will do logging, checkpointing etc.
seed_offset = DDP_RANK # each process gets a different seed
else:
gpu_id = 0 # gpu_id 0 means this is the (single) master process, basically
# if not ddp, we are running on a single gpu, and one process
master_process = True
seed_offset = 0
# Two alternative guards for creating the output directory follow (gpu_id-based
# and master_process-based); only one of them belongs to each side of the diff.
if gpu_id == 0:
if master_process:
os.makedirs(out_dir, exist_ok=True)
# Two alternative seeding lines follow; both add a per-process offset to 1337.
torch.manual_seed(1337 + gpu_id) # note: each worker gets a different seed
torch.manual_seed(1337 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
# device_type is 'cuda' whenever the device string contains 'cuda', else 'cpu'.
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
@ -170,7 +175,7 @@ if compile:
# wrap model into DDP container
# NOTE(review): rendered diff hunk with +/- markers and indentation stripped.
# The two DDP(...) lines below look like the pre-commit (gpu_id) and
# post-commit (DDP_LOCAL_RANK) versions of the same statement, not two
# successive wraps -- confirm against the repository.
if ddp:
model = DDP(model, device_ids=[gpu_id])
model = DDP(model, device_ids=[DDP_LOCAL_RANK])
@torch.no_grad()
def estimate_loss():
@ -202,7 +207,7 @@ def get_lr(iter):
return min_lr + coeff * (learning_rate - min_lr)
# logging
# Starts a wandb run when wandb_log is set; wandb is imported only inside this
# guarded branch.
# NOTE(review): rendered diff hunk with +/- markers and indentation stripped.
# The two `if` lines look like the pre-commit (gpu_id == 0) and post-commit
# (master_process) versions of the same guard -- confirm against the repository.
if wandb_log and gpu_id == 0:
if wandb_log and master_process:
import wandb
wandb.init(project=wandb_project, name=wandb_run_name, config=config)
@ -219,7 +224,7 @@ while True:
# NOTE(review): fragment of the training loop body (the hunk header shows
# `while True:`), rendered without +/- markers or indentation. The statement
# that leads into `lr = learning_rate` is outside this view.
lr = learning_rate
# evaluate the loss on train/val sets and write checkpoints
# The two `if` lines below look like the pre-commit (gpu_id == 0) and
# post-commit (master_process) versions of the same guard -- confirm against
# the repository. Either way the evaluation runs only when iter_num is a
# multiple of eval_interval.
if iter_num % eval_interval == 0 and gpu_id == 0:
if iter_num % eval_interval == 0 and master_process:
# estimate_loss() is indexed with 'train' and 'val' keys in the print below.
losses = estimate_loss()
print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
if wandb_log:
@ -265,7 +270,7 @@ while True:
# NOTE(review): fragment of the training loop body, rendered without +/-
# markers or indentation.
# Measure the wall-clock time since the previous measurement and reset the
# reference point for the next one.
t1 = time.time()
dt = t1 - t0
t0 = t1
# The two `if` lines below look like the pre-commit (gpu_id == 0) and
# post-commit (master_process) versions of the same guard -- confirm against
# the repository. Either way the log line is printed only when iter_num is a
# multiple of log_interval.
if iter_num % log_interval == 0 and gpu_id == 0:
if iter_num % log_interval == 0 and master_process:
lossf = loss.item() # loss as float. TODO note CPU-GPU sync! profile, make sure not too slow
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
# Indentation is lost in this rendering, so whether the increment sits inside
# the logging branch cannot be read from here; presumably it runs every
# iteration -- TODO confirm.
iter_num += 1