1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-11-10 20:09:58 +00:00

Make wandb import conditioned to wandb_log=True

This commit is contained in:
Luca Antiga 2023-01-05 09:09:22 +01:00
parent e53b9d28ff
commit aba47f0a35
2 changed files with 3 additions and 2 deletions

View File

@ -19,7 +19,6 @@ Dependencies:
- `pip install tiktoken` for OpenAI's fast BPE code <3 - `pip install tiktoken` for OpenAI's fast BPE code <3
- `pip install wandb` for optional logging <3 - `pip install wandb` for optional logging <3
- `pip install tqdm` - `pip install tqdm`
- `pip install networkx`
## usage ## usage

View File

@ -13,7 +13,6 @@ import os
import time import time
import math import math
import wandb
import numpy as np import numpy as np
import torch import torch
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
@ -75,6 +74,9 @@ if gpu_id == 0:
torch.manual_seed(1337 + gpu_id) # note: each worker gets a different seed torch.manual_seed(1337 + gpu_id) # note: each worker gets a different seed
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
# import wandb conditionally
if wandb_log:
import wandb
# poor man's data loader, TODO evaluate need for actual DataLoader # poor man's data loader, TODO evaluate need for actual DataLoader
data_dir = os.path.join('data', dataset) data_dir = os.path.join('data', dataset)