diff --git a/train.py b/train.py index 58619ed..16ad05e 100644 --- a/train.py +++ b/train.py @@ -70,11 +70,11 @@ config = {k: globals()[k] for k in config_keys} # will be useful for logging # ----------------------------------------------------------------------------- # various inits, derived attributes, I/O setup -ddp = int(os.environ.get('LOCAL_RANK', -1)) != -1 # is this a ddp run? +ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? if ddp: init_process_group(backend=backend) - gpu_id = int(os.environ["LOCAL_RANK"]) - device = f"cuda:{gpu_id}" + gpu_id = int(os.environ['RANK']) + device = f'cuda:{gpu_id}' else: gpu_id = 0 # gpu_id 0 means this is the (single) master process, basically