diff --git a/train.py b/train.py index b888310..90482d0 100644 --- a/train.py +++ b/train.py @@ -12,6 +12,7 @@ $ torchrun --standalone --nproc_per_node=4 train.py import os import time import math +import pickle from contextlib import nullcontext import numpy as np @@ -102,8 +103,19 @@ def get_batch(split): iter_num = 0 best_val_loss = 1e9 -# model init. TODO: fix bug we should also propagate the correct vocab_size to the model_args -model_args = dict(n_layer = n_layer, n_head = n_head, n_embd = n_embd, block_size = block_size, dropout = dropout) +# attempt to derive vocab_size from the dataset +meta_path = os.path.join(data_dir, 'meta.pkl') +if os.path.exists(meta_path): + with open(meta_path, 'rb') as f: + meta = pickle.load(f) + vocab_size = meta['vocab_size'] + print(f"vocab_size = {vocab_size} (from {meta_path})") +else: + print(f"vocab_size not found in {meta_path}, using GPT-2 default of 50257") + vocab_size = 50257 + +# model init +model_args = dict(n_layer = n_layer, n_head = n_head, n_embd = n_embd, block_size = block_size, dropout = dropout, vocab_size = vocab_size) if init_from == 'scratch': # init a new model from scratch print("Initializing a new model from scratch")