diff --git a/gpt.py b/gpt.py
index e4fc68d..39be475 100644
--- a/gpt.py
+++ b/gpt.py
@@ -131,8 +131,8 @@ def __init__(self, n_embd, n_head):
         self.ln2 = nn.LayerNorm(n_embd)
 
     def forward(self, x):
-        x = x + self.sa(self.ln1(x))
-        x = x + self.ffwd(self.ln2(x))
+        x = self.ln1(self.sa(x) + x)
+        x = self.ln2(self.ffwd(x) + x)
         return x
 
 class GPTLanguageModel(nn.Module):