1
0
mirror of https://github.com/osmarks/nanogpt-experiments.git synced 2024-11-10 20:09:58 +00:00

mildly dramatic refactor for handing all these usage cases across all possible supported and unsupported devices for all the possible switches and flags

This commit is contained in:
Andrej Karpathy 2023-02-04 21:06:17 +00:00
parent e108ffb973
commit 25d95dbd65
2 changed files with 30 additions and 24 deletions

View File

@ -218,15 +218,16 @@ class GPT(nn.Module):
'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
}[model_type]
# we can override the dropout rate
print("forcing vocab_size=50257, block_size=1024, bias=True")
config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
config_args['bias'] = True # always True for GPT model checkpoints
# we can override the dropout rate, if desired
if 'dropout' in override_args:
print(f"overriding dropout rate to {override_args['dropout']}")
config_args['dropout'] = override_args['dropout']
# block_size is always 1024 for GPT model checkpoints
# if one wants a lower block_size it has to be done through model surgery
# later, by calling crop_block_size()
# create a from-scratch initialized minGPT model
config = GPTConfig(block_size=1024, bias=True, **config_args) # note: force bias=True, as in gpt2 models
config = GPTConfig(**config_args)
model = GPT(config)
sd = model.state_dict()
sd_keys = sd.keys()
@ -258,7 +259,7 @@ class GPT(nn.Module):
return model
def configure_optimizers(self, weight_decay, learning_rate, betas):
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
"""
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the model into two buckets: those that will experience
@ -309,7 +310,9 @@ class GPT(nn.Module):
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
# new PyTorch nightly has a new 'fused' option for AdamW that is much faster
extra_args = dict(fused=True) if 'fused' in inspect.signature(torch.optim.AdamW).parameters else dict()
use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters)
print(f"using fused AdamW: {use_fused}")
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
return optimizer

View File

@ -125,21 +125,23 @@ best_val_loss = 1e9
# attempt to derive vocab_size from the dataset
meta_path = os.path.join(data_dir, 'meta.pkl')
meta_vocab_size = None
if os.path.exists(meta_path):
with open(meta_path, 'rb') as f:
meta = pickle.load(f)
vocab_size = meta['vocab_size']
print(f"vocab_size = {vocab_size} (from {meta_path})")
else:
print(f"vocab_size not found in {meta_path}, using GPT-2 default of 50257 (rounded up to 50304 for efficiency)")
vocab_size = 50304
meta_vocab_size = meta['vocab_size']
print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
# model init
model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
dropout=dropout, vocab_size=vocab_size, bias=bias)
bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
if init_from == 'scratch':
# init a new model from scratch
print("Initializing a new model from scratch")
# determine the vocab size we'll use for from-scratch training
if meta_vocab_size is None:
print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
elif init_from == 'resume':
@ -148,9 +150,11 @@ elif init_from == 'resume':
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
checkpoint = torch.load(ckpt_path, map_location=device)
checkpoint_model_args = checkpoint['model_args']
for k, v in model_args.items():
assert checkpoint_model_args[k] == v, "for now"
# TODO: think through how passed in params should interact with checkpoint params
# force these config attributes to be equal otherwise we can't even resume training
# the rest of the attributes (e.g. dropout) can stay as desired from command line
for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
model_args[k] = checkpoint_model_args[k]
# create the model
gptconf = GPTConfig(**model_args)
model = GPT(gptconf)
state_dict = checkpoint['model']
@ -165,24 +169,23 @@ elif init_from == 'resume':
best_val_loss = checkpoint['best_val_loss']
elif init_from.startswith('gpt2'):
print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
assert bias, "GPT-2 models have bias, so we can't use bias=False"
# initialize from OpenAI GPT-2 weights
override_args = dict(dropout=dropout)
model = GPT.from_pretrained(init_from, override_args)
# read off and override the GPT sizing model args from the model config
model_args['n_layer'] = model.config.n_layer
model_args['n_head'] = model.config.n_head
model_args['n_embd'] = model.config.n_embd
# crop down the model block size if desired
# read off the created config params, so we can store them into checkpoint correctly
for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
model_args[k] = getattr(model.config, k)
# crop down the model block size if desired, using model surgery
if block_size < model.config.block_size:
model.crop_block_size(block_size)
model_args['block_size'] = block_size # so that the checkpoint will have the right value
model.to(device)
# initialize a GradScaler. If enabled=False scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2))
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
if init_from == 'resume':
optimizer.load_state_dict(checkpoint['optimizer'])