diff --git a/train.py b/train.py index 2e9379f..5db73a8 100644 --- a/train.py +++ b/train.py @@ -74,9 +74,6 @@ if gpu_id == 0: torch.manual_seed(1337 + gpu_id) # note: each worker gets a different seed torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn -# import wandb conditionally -if wandb_log: - import wandb # poor man's data loader, TODO evaluate need for actual DataLoader data_dir = os.path.join('data', dataset) @@ -182,6 +179,7 @@ def get_lr(iter): # logging if wandb_log and gpu_id == 0: + import wandb wandb.init(project=wandb_project, name=wandb_run_name) wandb.config = { "batch_size": batch_size,