Mirror of https://github.com/osmarks/nanogpt-experiments.git, synced 2024-12-18 14:10:28 +00:00
attempt a non-biased model, per a few papers that cite this as working well
This commit is contained in:
parent f29a9ff5bf
commit 2892858ce7

model.py (25 changed lines)
@@ -22,15 +22,24 @@ def new_gelu(x):
     """
     return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
 
+class LayerNormNoBias(nn.Module):
+
+    def __init__(self, ndim):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(ndim))
+
+    def forward(self, input):
+        return F.layer_norm(input, self.weight.shape, self.weight, None, 1e-5)
+
 class CausalSelfAttention(nn.Module):
 
     def __init__(self, config):
         super().__init__()
         assert config.n_embd % config.n_head == 0
         # key, query, value projections for all heads, but in a batch
-        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
         # output projection
-        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
         # regularization
         self.attn_dropout = nn.Dropout(config.dropout)
         self.resid_dropout = nn.Dropout(config.dropout)
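For context, the module added above is small enough to check in isolation. The following is an illustrative sketch, not code from the commit: it compares LayerNormNoBias against torch.nn.LayerNorm with its bias zeroed (the tensor shapes and tolerance are arbitrary choices made here).

# Standalone check of the no-bias LayerNorm against nn.LayerNorm with a zeroed bias.
# Illustrative sketch only; shapes and tolerance are assumptions, not from the commit.
import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNormNoBias(nn.Module):
    def __init__(self, ndim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))

    def forward(self, input):
        # F.layer_norm accepts bias=None, so no bias parameter is allocated or applied
        return F.layer_norm(input, self.weight.shape, self.weight, None, 1e-5)

x = torch.randn(2, 8, 16)          # (batch, sequence, n_embd), chosen arbitrarily
ln_nobias = LayerNormNoBias(16)
ln_torch = nn.LayerNorm(16)
with torch.no_grad():
    ln_torch.bias.zero_()          # with bias == 0 the two modules should agree
print(torch.allclose(ln_nobias(x), ln_torch(x), atol=1e-6))  # expected: True
print(sum(p.numel() for p in ln_nobias.parameters()),        # 16 (weight only)
      sum(p.numel() for p in ln_torch.parameters()))         # 32 (weight + bias)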
@@ -65,8 +74,8 @@ class MLP(nn.Module):
 
     def __init__(self, config):
         super().__init__()
-        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
-        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)
+        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
+        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
         self.dropout = nn.Dropout(config.dropout)
 
     def forward(self, x):
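The bias=False changes in the attention and MLP hunks remove a modest number of parameters. A rough count, assuming a GPT-2-small-sized config (n_embd=768, n_layer=12; these values are illustrative and not stated in the commit):

# Back-of-the-envelope count of the bias parameters removed by this commit.
# n_embd=768 and n_layer=12 are assumed for illustration only.
n_embd, n_layer = 768, 12
per_block_biases = (
    3 * n_embd    # c_attn bias (q, k, v projections in one matrix)
    + n_embd      # attention c_proj bias
    + 4 * n_embd  # MLP c_fc bias
    + n_embd      # MLP c_proj bias
    + 2 * n_embd  # ln_1 and ln_2 biases (dropped by LayerNormNoBias)
)
total_removed = n_layer * per_block_biases + n_embd  # plus the final ln_f bias
print(per_block_biases, total_removed)  # 8448 per block, 102144 in total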
@@ -80,9 +89,9 @@ class Block(nn.Module):
 
     def __init__(self, config):
         super().__init__()
-        self.ln_1 = nn.LayerNorm(config.n_embd)
+        self.ln_1 = LayerNormNoBias(config.n_embd)
         self.attn = CausalSelfAttention(config)
-        self.ln_2 = nn.LayerNorm(config.n_embd)
+        self.ln_2 = LayerNormNoBias(config.n_embd)
         self.mlp = MLP(config)
 
     def forward(self, x):
@@ -112,7 +121,7 @@ class GPT(nn.Module):
             wpe = nn.Embedding(config.block_size, config.n_embd),
             drop = nn.Dropout(config.dropout),
             h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
-            ln_f = nn.LayerNorm(config.n_embd),
+            ln_f = LayerNormNoBias(config.n_embd),
         ))
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         # with weight tying when using torch.compile() some warnings get generated:
@@ -242,7 +251,7 @@ class GPT(nn.Module):
         decay = set()
         no_decay = set()
         whitelist_weight_modules = (torch.nn.Linear, )
-        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+        blacklist_weight_modules = (torch.nn.LayerNorm, LayerNormNoBias, torch.nn.Embedding)
         for mn, m in self.named_modules():
             for pn, p in m.named_parameters():
                 fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
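The last hunk matters because the new module's gain parameters are named 'weight', and unblacklisted weights would otherwise be assigned weight decay. Below is a condensed sketch that only approximates nanoGPT's configure_optimizers grouping (it is not a drop-in replacement); it shows why LayerNormNoBias has to be added to the blacklist so its weights land in the no_decay group, exactly as torch.nn.LayerNorm weights did before.

# Simplified decay / no_decay split, approximating configure_optimizers.
# With LayerNormNoBias in the blacklist, its 'weight' parameters are excluded
# from weight decay, matching how torch.nn.LayerNorm weights were treated.
import torch.nn as nn

def split_decay_groups(model, blacklist=(nn.LayerNorm, nn.Embedding)):
    decay, no_decay = set(), set()
    for mn, m in model.named_modules():
        for pn, _ in m.named_parameters(recurse=False):
            fpn = f'{mn}.{pn}' if mn else pn  # full param name
            if pn.endswith('bias'):
                no_decay.add(fpn)   # biases (none left after this commit) are never decayed
            elif pn.endswith('weight') and isinstance(m, nn.Linear):
                decay.add(fpn)      # whitelisted weights get weight decay
            elif pn.endswith('weight') and isinstance(m, blacklist):
                no_decay.add(fpn)   # blacklisted weights: norm gains, embeddings
    return decay, no_decay

# Example: the norm gain goes to no_decay, the linear weight to decay.
m = nn.Sequential(nn.Linear(8, 8, bias=False), nn.LayerNorm(8))
print(split_decay_groups(m))  # decay: {'0.weight'}, no_decay: {'1.weight', '1.bias'}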