diff --git a/model.py b/model.py
index 799eb71..236117c 100644
--- a/model.py
+++ b/model.py
@@ -14,8 +14,8 @@ import torch
 import torch.nn as nn
 from torch.nn import functional as F
 
-@torch.jit.script
-def fused_gelu(x):
+# @torch.jit.script # good to enable when not using torch.compile, disable when using (our default)
+def new_gelu(x):
     """
     Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
     Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
@@ -71,7 +71,7 @@ class MLP(nn.Module):
 
     def forward(self, x):
         x = self.c_fc(x)
-        x = fused_gelu(x)
+        x = new_gelu(x)
         x = self.c_proj(x)
         x = self.dropout(x)
         return x
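The function body falls outside the diff's context lines. For reference, the tanh-approximation GELU from the cited paper (the form used by the BERT and GPT codebases the docstring mentions) looks roughly like the sketch below; treat it as an illustration rather than the exact body in this repo.

import math
import torch

def new_gelu(x):
    # GELU, tanh approximation from https://arxiv.org/abs/1606.08415 (sketch, not the verbatim repo code).
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

The rename from fused_gelu to new_gelu reflects that fusion is no longer done by torch.jit.script here: the decorator is commented out because the default path now relies on torch.compile, and it can be re-enabled when torch.compile is not used.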