mirror of
https://github.com/osmarks/nanogpt-experiments.git
synced 2024-11-10 20:09:58 +00:00
Merge pull request #274 from apivovarov/gelu
Use nn.GELU - 1.27x faster training
This commit is contained in:
commit
f08abb45bd
11
model.py
11
model.py
@ -15,14 +15,6 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
# @torch.jit.script # good to enable when not using torch.compile, disable when using (our default)
|
|
||||||
def new_gelu(x):
|
|
||||||
"""
|
|
||||||
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
|
|
||||||
Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
|
|
||||||
"""
|
|
||||||
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
|
|
||||||
|
|
||||||
class LayerNorm(nn.Module):
|
class LayerNorm(nn.Module):
|
||||||
""" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
|
""" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
|
||||||
|
|
||||||
@ -88,12 +80,13 @@ class MLP(nn.Module):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
|
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
|
||||||
|
self.gelu = nn.GELU()
|
||||||
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
|
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
|
||||||
self.dropout = nn.Dropout(config.dropout)
|
self.dropout = nn.Dropout(config.dropout)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.c_fc(x)
|
x = self.c_fc(x)
|
||||||
x = new_gelu(x)
|
x = self.gelu(x)
|
||||||
x = self.c_proj(x)
|
x = self.c_proj(x)
|
||||||
x = self.dropout(x)
|
x = self.dropout(x)
|
||||||
return x
|
return x
|
||||||
|
Loading…
Reference in New Issue
Block a user