
rewrite model class so layernorm has an optional bias= parameter

Author: Andrej Karpathy
Date:   2023-01-27 20:17:32 +00:00
parent 2892858ce7
commit 2bf07a3fbf

model.py

@@ -22,14 +22,16 @@ def new_gelu(x):
     """
     return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
 
-class LayerNormNoBias(nn.Module):
+class LayerNorm(nn.Module):
+    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
 
-    def __init__(self, ndim):
+    def __init__(self, ndim, bias=True):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(ndim))
+        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
 
     def forward(self, input):
-        return F.layer_norm(input, self.weight.shape, self.weight, None, 1e-5)
+        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
 
 class CausalSelfAttention(nn.Module):
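
For context, here is a minimal self-contained sketch of what the new class buys you. The class body mirrors the new side of the hunk above; the parameter counts and the parity check against torch's built-in nn.LayerNorm are my own illustration, not part of the commit:

import torch
import torch.nn as nn
from torch.nn import functional as F

class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

    def __init__(self, ndim, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

x = torch.randn(2, 4, 8)
print(sum(p.numel() for p in LayerNorm(8, bias=False).parameters()))  # 8  -> weight only
print(sum(p.numel() for p in LayerNorm(8, bias=True).parameters()))   # 16 -> weight + bias
# at init (weight=1, bias=0) the bias=True variant matches torch's built-in LayerNorm
assert torch.allclose(LayerNorm(8, bias=True)(x), nn.LayerNorm(8)(x))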
@@ -89,9 +91,9 @@ class Block(nn.Module):
 
     def __init__(self, config):
         super().__init__()
-        self.ln_1 = LayerNormNoBias(config.n_embd)
+        self.ln_1 = LayerNorm(config.n_embd, bias=False)
         self.attn = CausalSelfAttention(config)
-        self.ln_2 = LayerNormNoBias(config.n_embd)
+        self.ln_2 = LayerNorm(config.n_embd, bias=False)
         self.mlp = MLP(config)
 
     def forward(self, x):
@@ -121,7 +123,7 @@ class GPT(nn.Module):
             wpe = nn.Embedding(config.block_size, config.n_embd),
             drop = nn.Dropout(config.dropout),
             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
-            ln_f = LayerNormNoBias(config.n_embd),
+            ln_f = LayerNorm(config.n_embd, bias=False),
         ))
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         # with weight tying when using torch.compile() some warnings get generated:
@@ -148,9 +150,10 @@ class GPT(nn.Module):
                 torch.nn.init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-        elif isinstance(module, nn.LayerNorm):
-            torch.nn.init.zeros_(module.bias)
+        elif isinstance(module, (LayerNorm, nn.LayerNorm)):
             torch.nn.init.ones_(module.weight)
+            if module.bias is not None:
+                torch.nn.init.zeros_(module.bias)
 
     def forward(self, idx, targets=None):
         device = idx.device
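
The None check added to the init branch above is what lets bias=False work: there may be no bias tensor to zero. A small sketch of the same logic, reusing the LayerNorm class from the first example; the init_layernorm helper name is mine, not from the commit:

def init_layernorm(module):
    # mirrors the new elif branch: always reset weight, only touch bias if it exists
    if isinstance(module, (LayerNorm, nn.LayerNorm)):
        torch.nn.init.ones_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

LayerNorm(8, bias=False).apply(init_layernorm)  # ok: the bias branch is skipped
LayerNorm(8, bias=True).apply(init_layernorm)   # ok: weight -> ones, bias -> zeros
# the old unconditional torch.nn.init.zeros_(module.bias) would fail on the
# bias=False case, since module.bias is None there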
@@ -251,7 +254,7 @@ class GPT(nn.Module):
         decay = set()
         no_decay = set()
         whitelist_weight_modules = (torch.nn.Linear, )
-        blacklist_weight_modules = (torch.nn.LayerNorm, LayerNormNoBias, torch.nn.Embedding)
+        blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding)
         for mn, m in self.named_modules():
             for pn, p in m.named_parameters():
                 fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
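
For reference, a rough sketch of how that blacklist is used a few lines further down in configure_optimizers. This is paraphrased from the surrounding nanoGPT code from memory, not part of this diff, and uses a single LayerNorm as a stand-in model: weights of blacklisted modules (LayerNorm, Embedding) are routed into the no-weight-decay group.

decay, no_decay = set(), set()
whitelist_weight_modules = (torch.nn.Linear, )
blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding)
model = LayerNorm(8)  # stand-in; the real code walks the whole GPT module tree
for mn, m in model.named_modules():
    for pn, p in m.named_parameters():
        fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
        if pn.endswith('bias'):
            no_decay.add(fpn)          # biases are never weight-decayed
        elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
            decay.add(fpn)             # e.g. Linear weights get weight decay
        elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
            no_decay.add(fpn)          # LayerNorm / Embedding weights do not
print(decay, no_decay)  # decay is empty; no_decay holds 'weight' and 'bias'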