From 1b4d18e4ef2782b1ac67bd39c4daa3ee8c168505 Mon Sep 17 00:00:00 2001
From: Julien Vignoud
Date: Thu, 10 Oct 2024 12:49:28 +0200
Subject: [PATCH] discojs/src/models/gpt/layers: fix weight initializations

---
 discojs/src/models/gpt/layers.ts | 48 +++++++++++++++++++++-----------
 1 file changed, 32 insertions(+), 16 deletions(-)

diff --git a/discojs/src/models/gpt/layers.ts b/discojs/src/models/gpt/layers.ts
index 63661b604..d48e5cac5 100644
--- a/discojs/src/models/gpt/layers.ts
+++ b/discojs/src/models/gpt/layers.ts
@@ -67,13 +67,14 @@ tf.serialization.registerClass(LogLayer)
 
 type CausalSelfAttentionConfig =
   ConstructorParameters<typeof tf.layers.Layer>[0]
-  & Record<'blockSize' | 'nHead' | 'nEmbd' | 'dropout', number>
+  & Record<'blockSize' | 'nHead' | 'nEmbd' | 'dropout' | 'nLayer', number>
 
 class CausalSelfAttention extends tf.layers.Layer {
   static readonly className = 'CausalSelfAttention'
 
   private readonly nHead: number
   private readonly nEmbd: number
+  private readonly nLayer: number
   private readonly dropout: number
   private readonly mask: tf.Tensor2D
   cAttnKernel?: tf.LayerVariable
@@ -88,6 +89,7 @@ class CausalSelfAttention extends tf.layers.Layer {
 
     this.nEmbd = config.nEmbd
     this.nHead = config.nHead
+    this.nLayer = config.nLayer
     this.dropout = config.dropout
 
     // mask is a lower triangular matrix filled with 1
@@ -102,7 +104,7 @@
       'c_attn.weight',
       [this.nEmbd, 3 * this.nEmbd],
       'float32',
-      tf.initializers.glorotNormal({})
+      tf.initializers.randomNormal({ mean: 0, stddev: 0.02 }) // use same init as GPT2
     )
     this.cAttnBias = this.addWeight(
       'c_attn.bias',
@@ -115,7 +117,12 @@
       'c_proj.kernel',
       [this.nEmbd, this.nEmbd],
       'float32',
-      tf.initializers.glorotNormal({})
+      // the input keeps accumulating through the residual stream so we
+      // scale the initialization down with the number of layers to keep a unit std
+      // Sources:
+      // https://github.com/karpathy/build-nanogpt/blob/6104ab1b53920f6e2159749676073ff7d815c1fa/train_gpt2.py#L103
+      // https://youtu.be/l8pRSuU81PU?si=5GcKfi_kPgLgvtg2&t=4640
+      tf.initializers.randomNormal({ mean: 0, stddev: 0.02 / Math.sqrt(2 * this.nLayer) })
     )
     this.cProjBias = this.addWeight(
       'c_proj.bias',
@@ -255,7 +262,7 @@
 tf.serialization.registerClass(GELU)
 
 type MLPConfig = ConstructorParameters<typeof tf.layers.Layer>[0] &
-  Required<ModelSize> & Record<'blockSize' | 'residDrop', number>
+  Required<ModelSize> & Record<'blockSize' | 'residDrop' | 'nLayer', number>
 
 function MLP(config: MLPConfig): tf.LayersModel {
   return tf.sequential({ layers: [
@@ -263,19 +270,23 @@
     tf.layers.dense({
       name: config.name + `.mlp.c_fc`,
       units: 4 * config.nEmbd,
       inputDim: config.nEmbd,
-      inputShape: [config.blockSize, config.nEmbd]
+      inputShape: [config.blockSize, config.nEmbd],
+      kernelInitializer: tf.initializers.randomNormal({ mean: 0, stddev: 0.02 }),
     }),
     new GELU(),
     tf.layers.dense({
-      name: config.name + '.mlp.c_proj',
-      units: config.nEmbd,
-      inputDim: 4 * config.nEmbd,
-      inputShape: [config.blockSize, 4 * config.nEmbd]
-    }),
-    tf.layers.dropout({
-      name: config.name + '.mlp.drop',
-      rate: config.residDrop
-    }),
+      name: config.name + '.mlp.c_proj',
+      units: config.nEmbd,
+      inputDim: 4 * config.nEmbd,
+      inputShape: [config.blockSize, 4 * config.nEmbd],
+      kernelInitializer: tf.initializers.randomNormal({
+        mean: 0, stddev: 0.02 / Math.sqrt(2 * config.nLayer) }),
+    }),
+    tf.layers.dropout({
+      name: config.name + '.mlp.drop',
+      rate: config.residDrop
+    }),
   ]})
 }
@@ -362,7 +373,7 @@ class LMEmbedding extends tf.layers.Layer {
       'wte', // use same name as GPT2
       [this.vocabSize, this.nEmbd],
       'float32',
-      tf.initializers.randomNormal({})
+      tf.initializers.randomNormal({ mean: 0, stddev: 0.02 })
     )
   }
@@ -452,7 +463,7 @@ export function GPTArchitecture(config: Required<GPTConfig>): tf.LayersModel {
     name: config.name + '.wpe',
     inputDim: config.blockSize,
     outputDim: config.nEmbd,
-    embeddingsInitializer: 'zeros'
+    embeddingsInitializer: tf.initializers.randomNormal({ mean: 0, stddev: 0.02 }),
   }).apply(range) as tf.SymbolicTensor
 
   if (config.debug) {
@@ -474,7 +485,12 @@
     ).apply(x)
   }
   // Normalization
-  x = tf.layers.layerNormalization({ name: config.name + '.ln_f', epsilon: 1e-5 })
+  x = tf.layers.layerNormalization({
+    name: config.name + '.ln_f',
+    epsilon: 1e-5,
+    gammaInitializer: 'ones',
+    betaInitializer: 'zeros',
+  })
     .apply(x)
   if (config.debug) {
     x = new LogLayer({ name: 'ln_f_log' }).apply(x)
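
Editor's note (illustration, not part of the patch): the c_proj comments rely
on a variance argument. Each transformer block adds two projection outputs
(attention c_proj and MLP c_proj) into the residual stream, so the final
activation is a sum of roughly 2 * nLayer independent contributions and its
variance grows linearly with depth; dividing each projection's init stddev by
sqrt(2 * nLayer) cancels that growth. A minimal standalone sketch of the
effect, assuming @tensorflow/tfjs and a made-up nLayer (this code is not from
the repository):

    import * as tf from '@tensorflow/tfjs'

    const nLayer = 12                       // hypothetical model depth
    const n = 4096                          // samples used to estimate the std
    const scale = 1 / Math.sqrt(2 * nLayer) // the patch folds this into stddev

    let unscaled = tf.zeros([n])
    let scaled = tf.zeros([n])
    for (let i = 0; i < 2 * nLayer; i++) {
      // each residual branch contributes ~unit-variance noise to the stream
      unscaled = unscaled.add(tf.randomNormal([n]))
      // scaling every contribution by 1/sqrt(2 * nLayer) keeps the sum near std 1
      scaled = scaled.add(tf.randomNormal([n]).mul(scale))
    }
    const std = (t: tf.Tensor) => tf.moments(t).variance.sqrt().dataSync()[0]
    console.log(std(unscaled)) // ~sqrt(2 * nLayer), about 4.9 for nLayer = 12
    console.log(std(scaled))   // ~1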