mirror of https://github.com/osmarks/nanogpt-experiments.git
mildly dramatic refactor for handling all these usage cases across all possible supported and unsupported devices for all the possible switches and flags
commit 25d95dbd65 (parent e108ffb973)
model.py (19 lines changed)
@@ -218,15 +218,16 @@ class GPT(nn.Module):
             'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
             'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
         }[model_type]
-        # we can override the dropout rate
+        print("forcing vocab_size=50257, block_size=1024, bias=True")
+        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
+        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
+        config_args['bias'] = True # always True for GPT model checkpoints
+        # we can override the dropout rate, if desired
         if 'dropout' in override_args:
+            print(f"overriding dropout rate to {override_args['dropout']}")
             config_args['dropout'] = override_args['dropout']
-        # block_size is always 1024 for GPT model checkpoints
-        # if one wants a lower block_size it has to be done through model surgery
-        # later, by calling crop_block_size()
-
         # create a from-scratch initialized minGPT model
-        config = GPTConfig(block_size=1024, bias=True, **config_args) # note: force bias=True, as in gpt2 models
+        config = GPTConfig(**config_args)
         model = GPT(config)
         sd = model.state_dict()
         sd_keys = sd.keys()
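In effect, from_pretrained now forces vocab_size, block_size and bias to the GPT-2 checkpoint values and only lets the caller override dropout. A minimal usage sketch (assuming from_pretrained keeps the (model_type, override_args) call shape used by train.py below):

from model import GPT

# load the 124M GPT-2 checkpoint; dropout is the only setting we are allowed to override
model = GPT.from_pretrained('gpt2', dict(dropout=0.0))
# vocab_size=50257, block_size=1024 and bias=True are forced internally by from_pretrained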
@@ -258,7 +259,7 @@ class GPT(nn.Module):
 
         return model
 
-    def configure_optimizers(self, weight_decay, learning_rate, betas):
+    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
         """
         This long function is unfortunately doing something very simple and is being very defensive:
         We are separating out all parameters of the model into two buckets: those that will experience
@@ -309,7 +310,9 @@ class GPT(nn.Module):
             {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
         ]
         # new PyTorch nightly has a new 'fused' option for AdamW that is much faster
-        extra_args = dict(fused=True) if 'fused' in inspect.signature(torch.optim.AdamW).parameters else dict()
+        use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters)
+        print(f"using fused AdamW: {use_fused}")
+        extra_args = dict(fused=True) if use_fused else dict()
         optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
 
         return optimizer
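With this change configure_optimizers needs the caller to say which device it is running on: fused AdamW is only requested when the device is CUDA and the installed torch.optim.AdamW actually exposes a fused argument. A standalone sketch of that gate (make_adamw is a hypothetical helper for illustration, not part of the repo):

import inspect
import torch

def make_adamw(params, learning_rate, betas, weight_decay, device_type):
    # only request fused AdamW on CUDA, and only if this PyTorch build supports it
    use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters)
    extra_args = dict(fused=True) if use_fused else dict()
    return torch.optim.AdamW(params, lr=learning_rate, betas=betas, weight_decay=weight_decay, **extra_args)

# e.g. make_adamw(model.parameters(), 6e-4, (0.9, 0.95), 0.1, device_type='cuda')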
train.py (35 lines changed)
@@ -125,21 +125,23 @@ best_val_loss = 1e9
 
 # attempt to derive vocab_size from the dataset
 meta_path = os.path.join(data_dir, 'meta.pkl')
+meta_vocab_size = None
 if os.path.exists(meta_path):
     with open(meta_path, 'rb') as f:
         meta = pickle.load(f)
-    vocab_size = meta['vocab_size']
-    print(f"vocab_size = {vocab_size} (from {meta_path})")
-else:
-    print(f"vocab_size not found in {meta_path}, using GPT-2 default of 50257 (rounded up to 50304 for efficiency)")
-    vocab_size = 50304
+    meta_vocab_size = meta['vocab_size']
+    print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
 
 # model init
 model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
-                  dropout=dropout, vocab_size=vocab_size, bias=bias)
+                  bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
 if init_from == 'scratch':
     # init a new model from scratch
     print("Initializing a new model from scratch")
+    # determine the vocab size we'll use for from-scratch training
+    if meta_vocab_size is None:
+        print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
+    model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
     gptconf = GPTConfig(**model_args)
     model = GPT(gptconf)
 elif init_from == 'resume':
@@ -148,9 +150,11 @@ elif init_from == 'resume':
     ckpt_path = os.path.join(out_dir, 'ckpt.pt')
     checkpoint = torch.load(ckpt_path, map_location=device)
    checkpoint_model_args = checkpoint['model_args']
-    for k, v in model_args.items():
-        assert checkpoint_model_args[k] == v, "for now"
-        # TODO: think through how passed in params should interact with checkpoint params
+    # force these config attributes to be equal otherwise we can't even resume training
+    # the rest of the attributes (e.g. dropout) can stay as desired from command line
+    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
+        model_args[k] = checkpoint_model_args[k]
+    # create the model
     gptconf = GPTConfig(**model_args)
     model = GPT(gptconf)
     state_dict = checkpoint['model']
@@ -165,24 +169,23 @@ elif init_from == 'resume':
     best_val_loss = checkpoint['best_val_loss']
 elif init_from.startswith('gpt2'):
     print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
-    assert bias, "GPT-2 models have bias, so we can't use bias=False"
     # initialize from OpenAI GPT-2 weights
     override_args = dict(dropout=dropout)
     model = GPT.from_pretrained(init_from, override_args)
-    # read off and override the GPT sizing model args from the model config
-    model_args['n_layer'] = model.config.n_layer
-    model_args['n_head'] = model.config.n_head
-    model_args['n_embd'] = model.config.n_embd
-# crop down the model block size if desired
+    # read off the created config params, so we can store them into checkpoint correctly
+    for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
+        model_args[k] = getattr(model.config, k)
+# crop down the model block size if desired, using model surgery
 if block_size < model.config.block_size:
     model.crop_block_size(block_size)
+    model_args['block_size'] = block_size # so that the checkpoint will have the right value
 model.to(device)
 
 # initialize a GradScaler. If enabled=False scaler is a no-op
 scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
 
 # optimizer
-optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2))
+optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
 if init_from == 'resume':
     optimizer.load_state_dict(checkpoint['optimizer'])
 
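The train.py side of the refactor boils down to one rule: a vocab_size found in meta.pkl wins, a from-scratch run with no meta.pkl falls back to 50304 (GPT-2's 50257 rounded up for efficiency), and 'resume' or 'gpt2' runs take their sizes from the checkpoint or pretrained config instead. A minimal sketch of that fallback, pulled out as a standalone function (resolve_vocab_size is hypothetical, not in the repo):

import os
import pickle

def resolve_vocab_size(data_dir):
    meta_path = os.path.join(data_dir, 'meta.pkl')
    meta_vocab_size = None
    if os.path.exists(meta_path):
        with open(meta_path, 'rb') as f:
            meta = pickle.load(f)
        meta_vocab_size = meta['vocab_size']
        print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
    # no meta.pkl: default to GPT-2's 50257, rounded up to 50304 for efficiency
    return meta_vocab_size if meta_vocab_size is not None else 50304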