mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2025-09-01 10:27:58 +00:00
Move conditional import
This commit is contained in:
4
train.py
4
train.py
@@ -74,9 +74,6 @@ if gpu_id == 0:
|
|||||||
torch.manual_seed(1337 + gpu_id) # note: each worker gets a different seed
|
torch.manual_seed(1337 + gpu_id) # note: each worker gets a different seed
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
|
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
|
||||||
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
||||||
# import wandb conditionally
|
|
||||||
if wandb_log:
|
|
||||||
import wandb
|
|
||||||
|
|
||||||
# poor man's data loader, TODO evaluate need for actual DataLoader
|
# poor man's data loader, TODO evaluate need for actual DataLoader
|
||||||
data_dir = os.path.join('data', dataset)
|
data_dir = os.path.join('data', dataset)
|
||||||
@@ -182,6 +179,7 @@ def get_lr(iter):
|
|||||||
|
|
||||||
# logging
|
# logging
|
||||||
if wandb_log and gpu_id == 0:
|
if wandb_log and gpu_id == 0:
|
||||||
|
import wandb
|
||||||
wandb.init(project=wandb_project, name=wandb_run_name)
|
wandb.init(project=wandb_project, name=wandb_run_name)
|
||||||
wandb.config = {
|
wandb.config = {
|
||||||
"batch_size": batch_size,
|
"batch_size": batch_size,
|
||||||
|
Reference in New Issue
Block a user