mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-11-10 20:09:58 +00:00
Move conditional import
This commit is contained in:
parent
aba47f0a35
commit
09f1f458e8
4
train.py
4
train.py
@ -74,9 +74,6 @@ if gpu_id == 0:
|
|||||||
torch.manual_seed(1337 + gpu_id) # note: each worker gets a different seed
|
torch.manual_seed(1337 + gpu_id) # note: each worker gets a different seed
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
|
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
|
||||||
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
||||||
# import wandb conditionally
|
|
||||||
if wandb_log:
|
|
||||||
import wandb
|
|
||||||
|
|
||||||
# poor man's data loader, TODO evaluate need for actual DataLoader
|
# poor man's data loader, TODO evaluate need for actual DataLoader
|
||||||
data_dir = os.path.join('data', dataset)
|
data_dir = os.path.join('data', dataset)
|
||||||
@ -182,6 +179,7 @@ def get_lr(iter):
|
|||||||
|
|
||||||
# logging
|
# logging
|
||||||
if wandb_log and gpu_id == 0:
|
if wandb_log and gpu_id == 0:
|
||||||
|
import wandb
|
||||||
wandb.init(project=wandb_project, name=wandb_run_name)
|
wandb.init(project=wandb_project, name=wandb_run_name)
|
||||||
wandb.config = {
|
wandb.config = {
|
||||||
"batch_size": batch_size,
|
"batch_size": batch_size,
|
||||||
|
Loading…
Reference in New Issue
Block a user