mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-11-10 20:09:58 +00:00
Make wandb import conditioned to wandb_log=True
This commit is contained in:
parent
e53b9d28ff
commit
aba47f0a35
@ -19,7 +19,6 @@ Dependencies:
|
||||
- `pip install tiktoken` for OpenAI's fast BPE code <3
|
||||
- `pip install wandb` for optional logging <3
|
||||
- `pip install tqdm`
|
||||
- `pip install networkx`
|
||||
|
||||
## usage
|
||||
|
||||
|
4
train.py
4
train.py
@ -13,7 +13,6 @@ import os
|
||||
import time
|
||||
import math
|
||||
|
||||
import wandb
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
@ -75,6 +74,9 @@ if gpu_id == 0:
|
||||
torch.manual_seed(1337 + gpu_id) # note: each worker gets a different seed
|
||||
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
|
||||
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
||||
# import wandb conditionally
|
||||
if wandb_log:
|
||||
import wandb
|
||||
|
||||
# poor man's data loader, TODO evaluate need for actual DataLoader
|
||||
data_dir = os.path.join('data', dataset)
|
||||
|
Loading…
Reference in New Issue
Block a user