diff --git a/model.py b/model.py index 78bacd4..80c2192 100644 --- a/model.py +++ b/model.py @@ -25,7 +25,7 @@ def new_gelu(x): class LayerNorm(nn.Module): """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ - def __init__(self, ndim, bias=True): + def __init__(self, ndim, bias): super().__init__() self.weight = nn.Parameter(torch.ones(ndim)) self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None @@ -39,9 +39,9 @@ class CausalSelfAttention(nn.Module): super().__init__() assert config.n_embd % config.n_head == 0 # key, query, value projections for all heads, but in a batch - self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False) + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) # output projection - self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) # regularization self.attn_dropout = nn.Dropout(config.dropout) self.resid_dropout = nn.Dropout(config.dropout) @@ -76,8 +76,8 @@ class MLP(nn.Module): def __init__(self, config): super().__init__() - self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False) - self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False) + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) self.dropout = nn.Dropout(config.dropout) def forward(self, x): @@ -91,9 +91,9 @@ class Block(nn.Module): def __init__(self, config): super().__init__() - self.ln_1 = LayerNorm(config.n_embd, bias=False) + self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) self.attn = CausalSelfAttention(config) - self.ln_2 = LayerNorm(config.n_embd, bias=False) + self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) self.mlp = MLP(config) def forward(self, x): @@ -109,6 +109,7 @@ class GPTConfig: n_head: int = 12 n_embd: int = 768 dropout: float = 0.1 + bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster class GPT(nn.Module): @@ -123,7 +124,7 @@ class GPT(nn.Module): wpe = nn.Embedding(config.block_size, config.n_embd), drop = nn.Dropout(config.dropout), h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), - ln_f = LayerNorm(config.n_embd, bias=False), + ln_f = LayerNorm(config.n_embd, bias=config.bias), )) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) # with weight tying when using torch.compile() some warnings get generated: