diff --git a/train.py b/train.py index 7433546..9831c2d 100644 --- a/train.py +++ b/train.py @@ -45,7 +45,7 @@ wandb_project = 'owt' wandb_run_name = 'gpt2' # 'run' + str(time.time()) # data dataset = 'openwebtext' -gradient_accumulation_steps = 1 # used to simulate larger batch sizes +gradient_accumulation_steps = 5 # used to simulate larger batch sizes batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size block_size = 1024 # model