diff --git a/README.md b/README.md
index 92206ca..6a1c1bf 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,6 @@ Dependencies:
 - `pip install tiktoken` for OpenAI's fast BPE code <3
 - `pip install wandb` for optional logging <3
 - `pip install tqdm`
-- `pip install networkx`
 
 ## usage
 
diff --git a/train.py b/train.py
index ed346db..2e9379f 100644
--- a/train.py
+++ b/train.py
@@ -13,7 +13,6 @@
 import os
 import time
 import math
-import wandb
 import numpy as np
 import torch
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -75,6 +74,9 @@ if gpu_id == 0:
 torch.manual_seed(1337 + gpu_id) # note: each worker gets a different seed
 torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
 torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
+# import wandb conditionally
+if wandb_log:
+    import wandb
 
 # poor man's data loader, TODO evaluate need for actual DataLoader
 data_dir = os.path.join('data', dataset)