mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-12-18 14:10:28 +00:00
get rid of gpu_id, the world is more complicated than that when world_size > 8
This commit is contained in:
parent
f5e6ac8b02
commit
c3dddbff3d
23
train.py
23
train.py
@ -73,14 +73,19 @@ config = {k: globals()[k] for k in config_keys} # will be useful for logging
|
|||||||
ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
|
ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
|
||||||
if ddp:
|
if ddp:
|
||||||
init_process_group(backend=backend)
|
init_process_group(backend=backend)
|
||||||
gpu_id = int(os.environ['RANK'])
|
DDP_RANK = int(os.environ['RANK'])
|
||||||
device = f'cuda:{gpu_id}'
|
DDP_LOCAL_RANK = int(os.environ['LOCAL_RANK'])
|
||||||
|
device = f'cuda:{DDP_LOCAL_RANK}'
|
||||||
|
master_process = DDP_RANK == 0 # this process will do logging, checkpointing etc.
|
||||||
|
seed_offset = DDP_RANK # each process gets a different seed
|
||||||
else:
|
else:
|
||||||
gpu_id = 0 # gpu_id 0 means this is the (single) master process, basically
|
# if not ddp, we are running on a single gpu, and one process
|
||||||
|
master_process = True
|
||||||
|
seed_offset = 0
|
||||||
|
|
||||||
if gpu_id == 0:
|
if master_process:
|
||||||
os.makedirs(out_dir, exist_ok=True)
|
os.makedirs(out_dir, exist_ok=True)
|
||||||
torch.manual_seed(1337 + gpu_id) # note: each worker gets a different seed
|
torch.manual_seed(1337 + seed_offset)
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
|
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
|
||||||
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
||||||
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
|
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
|
||||||
@ -170,7 +175,7 @@ if compile:
|
|||||||
|
|
||||||
# wrap model into DDP container
|
# wrap model into DDP container
|
||||||
if ddp:
|
if ddp:
|
||||||
model = DDP(model, device_ids=[gpu_id])
|
model = DDP(model, device_ids=[DDP_LOCAL_RANK])
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def estimate_loss():
|
def estimate_loss():
|
||||||
@ -202,7 +207,7 @@ def get_lr(iter):
|
|||||||
return min_lr + coeff * (learning_rate - min_lr)
|
return min_lr + coeff * (learning_rate - min_lr)
|
||||||
|
|
||||||
# logging
|
# logging
|
||||||
if wandb_log and gpu_id == 0:
|
if wandb_log and master_process:
|
||||||
import wandb
|
import wandb
|
||||||
wandb.init(project=wandb_project, name=wandb_run_name, config=config)
|
wandb.init(project=wandb_project, name=wandb_run_name, config=config)
|
||||||
|
|
||||||
@ -219,7 +224,7 @@ while True:
|
|||||||
lr = learning_rate
|
lr = learning_rate
|
||||||
|
|
||||||
# evaluate the loss on train/val sets and write checkpoints
|
# evaluate the loss on train/val sets and write checkpoints
|
||||||
if iter_num % eval_interval == 0 and gpu_id == 0:
|
if iter_num % eval_interval == 0 and master_process:
|
||||||
losses = estimate_loss()
|
losses = estimate_loss()
|
||||||
print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
|
print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
|
||||||
if wandb_log:
|
if wandb_log:
|
||||||
@ -265,7 +270,7 @@ while True:
|
|||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
dt = t1 - t0
|
dt = t1 - t0
|
||||||
t0 = t1
|
t0 = t1
|
||||||
if iter_num % log_interval == 0 and gpu_id == 0:
|
if iter_num % log_interval == 0 and master_process:
|
||||||
lossf = loss.item() # loss as float. TODO note CPU-GPU sync! profile, make sure not too slow
|
lossf = loss.item() # loss as float. TODO note CPU-GPU sync! profile, make sure not too slow
|
||||||
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
|
print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms")
|
||||||
iter_num += 1
|
iter_num += 1
|
||||||
|
Loading…
Reference in New Issue
Block a user