mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-12-18 14:10:28 +00:00
Merge pull request #20 from lantiga/wandb-optional-import
Make wandb import conditioned to wandb_log=True
This commit is contained in:
commit
e7cd674ce7
@ -19,7 +19,6 @@ Dependencies:
|
|||||||
- `pip install tiktoken` for OpenAI's fast BPE code <3
|
- `pip install tiktoken` for OpenAI's fast BPE code <3
|
||||||
- `pip install wandb` for optional logging <3
|
- `pip install wandb` for optional logging <3
|
||||||
- `pip install tqdm`
|
- `pip install tqdm`
|
||||||
- `pip install networkx`
|
|
||||||
|
|
||||||
## usage
|
## usage
|
||||||
|
|
||||||
|
2
train.py
2
train.py
@ -13,7 +13,6 @@ import os
|
|||||||
import time
|
import time
|
||||||
import math
|
import math
|
||||||
|
|
||||||
import wandb
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
@ -180,6 +179,7 @@ def get_lr(iter):
|
|||||||
|
|
||||||
# logging
|
# logging
|
||||||
if wandb_log and gpu_id == 0:
|
if wandb_log and gpu_id == 0:
|
||||||
|
import wandb
|
||||||
wandb.init(project=wandb_project, name=wandb_run_name)
|
wandb.init(project=wandb_project, name=wandb_run_name)
|
||||||
wandb.config = {
|
wandb.config = {
|
||||||
"batch_size": batch_size,
|
"batch_size": batch_size,
|
||||||
|
Loading…
Reference in New Issue
Block a user