diff --git a/README.md b/README.md index 92206ca..6a1c1bf 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,6 @@ Dependencies: - `pip install tiktoken` for OpenAI's fast BPE code <3 - `pip install wandb` for optional logging <3 - `pip install tqdm` -- `pip install networkx` ## usage diff --git a/train.py b/train.py index ed346db..5db73a8 100644 --- a/train.py +++ b/train.py @@ -13,7 +13,6 @@ import os import time import math -import wandb import numpy as np import torch from torch.nn.parallel import DistributedDataParallel as DDP @@ -180,6 +179,7 @@ def get_lr(iter): # logging if wandb_log and gpu_id == 0: + import wandb wandb.init(project=wandb_project, name=wandb_run_name) wandb.config = { "batch_size": batch_size,