mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-12-18 06:00:29 +00:00
get rid of gpu_id, the world is more complicated than that when world_size > 8
This commit is contained in:
parent
f5e6ac8b02
commit
c3dddbff3d
23
train.py
23
train.py
@ -73,14 +73,19 @@ config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
||||
ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
|
||||
if ddp:
|
||||
init_process_group(backend=backend)
|
||||
gpu_id = int(os.environ['RANK'])
|
||||
device = f'cuda:{gpu_id}'
|
||||
DDP_RANK = int(os.environ['RANK'])
|
||||
DDP_LOCAL_RANK = int(os.environ['LOCAL_RANK'])
|
||||
device = f'cuda:{DDP_LOCAL_RANK}'
|
||||
master_process = DDP_RANK == 0 # this process will do logging, checkpointing etc.
|
||||
seed_offset = DDP_RANK # each process gets a different seed
|
||||
else:
|
||||
gpu_id = 0 # gpu_id 0 means this is the (single) master process, basically
|
||||
# if not ddp, we are running on a single gpu, and one process
|
||||
master_process = True
|
||||
seed_offset = 0
|
||||
|
||||
if gpu_id == 0:
|
||||
if master_process:
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
torch.manual_seed(1337 + gpu_id) # note: each worker gets a different seed
|
||||
torch.manual_seed(1337 + seed_offset)
|
||||
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
|
||||
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
||||
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
|
||||
@ -170,7 +175,7 @@ if compile:
|
||||
|
||||
# wrap model into DDP container
|
||||
if ddp:
|
||||
model = DDP(model, device_ids=[gpu_id])
|
||||
model = DDP(model, device_ids=[DDP_LOCAL_RANK])
|
||||
|
||||
@torch.no_grad()
|
||||
def estimate_loss():
|
||||
@ -202,7 +207,7 @@ def get_lr(iter):
|
||||
return min_lr + coeff * (learning_rate - min_lr)
|
||||
|
||||
# logging
|
||||
if wandb_log and gpu_id == 0:
|
||||
if wandb_log and master_process:
|
||||
import wandb
|
||||
wandb.init(project=wandb_project, name=wandb_run_name, config=config)
|
||||
|
||||
@ -219,7 +224,7 @@ while True:
|
||||
lr = learning_rate
|
||||
|
||||
# evaluate the loss on train/val sets and write checkpoints
|
||||
if iter_num % eval_interval == 0 and gpu_id == 0:
|
||||
if iter_num % eval_interval == 0 and master_process:
|
||||
losses = estimate_loss()
|
||||
print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
|
||||
if wandb_log:
|
||||
@ -265,7 +270,7 @@ while True:
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
t0 = t1
|
||||
if iter_num % log_interval == 0 and gpu_id == 0:
|
||||
if iter_num % log_interval == 0 and master_process:
|
||||
lossf = loss.item() # loss as float. TODO note CPU-GPU sync! profile, make sure not too slow
|
||||
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
|
||||
iter_num += 1
|
||||
|
Loading…
Reference in New Issue
Block a user