diff --git a/model.py b/model.py index ec18243..8cf9bda 100644 --- a/model.py +++ b/model.py @@ -22,15 +22,24 @@ def new_gelu(x): """ return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) +class LayerNormNoBias(nn.Module): + + def __init__(self, ndim): + super().__init__() + self.weight = nn.Parameter(torch.ones(ndim)) + + def forward(self, input): + return F.layer_norm(input, self.weight.shape, self.weight, None, 1e-5) + class CausalSelfAttention(nn.Module): def __init__(self, config): super().__init__() assert config.n_embd % config.n_head == 0 # key, query, value projections for all heads, but in a batch - self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False) # output projection - self.c_proj = nn.Linear(config.n_embd, config.n_embd) + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) # regularization self.attn_dropout = nn.Dropout(config.dropout) self.resid_dropout = nn.Dropout(config.dropout) @@ -65,8 +74,8 @@ class MLP(nn.Module): def __init__(self, config): super().__init__() - self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd) - self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd) + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False) self.dropout = nn.Dropout(config.dropout) def forward(self, x): @@ -80,9 +89,9 @@ class Block(nn.Module): def __init__(self, config): super().__init__() - self.ln_1 = nn.LayerNorm(config.n_embd) + self.ln_1 = LayerNormNoBias(config.n_embd) self.attn = CausalSelfAttention(config) - self.ln_2 = nn.LayerNorm(config.n_embd) + self.ln_2 = LayerNormNoBias(config.n_embd) self.mlp = MLP(config) def forward(self, x): @@ -112,7 +121,7 @@ class GPT(nn.Module): wpe = nn.Embedding(config.block_size, config.n_embd), drop = nn.Dropout(config.dropout), h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), - ln_f = nn.LayerNorm(config.n_embd), + ln_f = LayerNormNoBias(config.n_embd), )) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # with weight tying when using torch.compile() some warnings get generated: @@ -242,7 +251,7 @@ class GPT(nn.Module): decay = set() no_decay = set() whitelist_weight_modules = (torch.nn.Linear, ) - blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + blacklist_weight_modules = (torch.nn.LayerNorm, LayerNormNoBias, torch.nn.Embedding) for mn, m in self.named_modules(): for pn, p in m.named_parameters(): fpn = '%s.%s' % (mn, pn) if mn else pn # full param name