discojs/src/models/gpt/layers: fix weight initializations
JulienVig committed Nov 7, 2024
1 parent 09ab5bd commit 1b4d18e
Showing 1 changed file with 32 additions and 16 deletions.
discojs/src/models/gpt/layers.ts
@@ -67,13 +67,14 @@ tf.serialization.registerClass(LogLayer)

 type CausalSelfAttentionConfig =
   ConstructorParameters<typeof tf.layers.Layer>[0]
-  & Record<'blockSize' | 'nHead' | 'nEmbd' | 'dropout', number>
+  & Record<'blockSize' | 'nHead' | 'nEmbd' | 'dropout' | 'nLayer', number>

 class CausalSelfAttention extends tf.layers.Layer {
   static readonly className = 'CausalSelfAttention'

   private readonly nHead: number
   private readonly nEmbd: number
+  private readonly nLayer: number
   private readonly dropout: number
   private readonly mask: tf.Tensor2D
   cAttnKernel?: tf.LayerVariable
@@ -88,6 +89,7 @@ class CausalSelfAttention extends tf.layers.Layer {

     this.nEmbd = config.nEmbd
     this.nHead = config.nHead
+    this.nLayer = config.nLayer
     this.dropout = config.dropout

     // mask is a lower triangular matrix filled with 1
@@ -102,7 +104,7 @@
       'c_attn.weight',
       [this.nEmbd, 3 * this.nEmbd],
       'float32',
-      tf.initializers.glorotNormal({})
+      tf.initializers.randomNormal({ mean:0, stddev:0.02 }) // use same init as GPT2
     )
     this.cAttnBias = this.addWeight(
       'c_attn.bias',
@@ -115,7 +117,12 @@
       'c_proj.kernel',
       [this.nEmbd, this.nEmbd],
       'float32',
-      tf.initializers.glorotNormal({})
+      // the input keeps accumulating through the residual stream so we
+      // scale the initialization with the nb of layers to keep a unit std
+      // Sources:
+      // https://github.com/karpathy/build-nanogpt/blob/6104ab1b53920f6e2159749676073ff7d815c1fa/train_gpt2.py#L103
+      // https://youtu.be/l8pRSuU81PU?si=5GcKfi_kPgLgvtg2&t=4640
+      tf.initializers.randomNormal({ mean:0, stddev: 0.02 * Math.sqrt(2 * this.nLayer) })
     )
     this.cProjBias = this.addWeight(
       'c_proj.bias',
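
Note on the scaled init above: each transformer block adds one attention and one MLP contribution to the residual stream, so roughly 2 * nLayer independent terms get summed, and the std of a sum of k independent draws grows like sqrt(k). The following TensorFlow.js sketch of that growth is illustrative only, not part of the commit; nLayer = 12 and the sample size are made-up values.

    import * as tf from '@tensorflow/tfjs'

    // Summing k independent draws of std s yields std s * sqrt(k); this is the
    // growth that scaling the projection init with the layer count is meant to
    // keep in check.
    const nLayer = 12        // hypothetical depth
    const baseStd = 0.02     // GPT-2 style base init
    const k = 2 * nLayer     // one attention + one MLP contribution per block

    let stream = tf.zeros([10000])
    for (let i = 0; i < k; i++) {
      stream = stream.add(tf.randomNormal([10000], 0, baseStd))
    }
    // accumulated std is about baseStd * sqrt(k) = 0.02 * sqrt(24), roughly 0.098
    tf.moments(stream).variance.sqrt().print()
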
@@ -255,27 +262,31 @@ class GELU extends tf.layers.Layer {
 tf.serialization.registerClass(GELU)

 type MLPConfig = ConstructorParameters<typeof tf.layers.Layer>[0] &
-  Required<ModelSize> & Record<'blockSize' | 'residDrop', number>
+  Required<ModelSize> & Record<'blockSize' | 'residDrop' | 'nLayer', number>

 function MLP(config: MLPConfig): tf.LayersModel {
   return tf.sequential({ layers: [
     tf.layers.dense({
       name: config.name + `.mlp.c_fc`,
       units: 4 * config.nEmbd,
       inputDim: config.nEmbd,
-      inputShape: [config.blockSize, config.nEmbd]
+      inputShape: [config.blockSize, config.nEmbd],
+      kernelInitializer: tf.initializers.randomNormal({ mean: 0, stddev: 0.02 }),
     }),
     new GELU(),
     tf.layers.dense({
-      name: config.name + '.mlp.c_proj',
-      units: config.nEmbd,
-      inputDim: 4 * config.nEmbd,
-      inputShape: [config.blockSize, 4 * config.nEmbd]
-    }),
-    tf.layers.dropout({
-      name: config.name + '.mlp.drop',
-      rate: config.residDrop
+      name: config.name + '.mlp.c_proj',
+      units: config.nEmbd,
+      inputDim: 4 * config.nEmbd,
+      inputShape: [config.blockSize, 4 * config.nEmbd],
+      kernelInitializer: tf.initializers.randomNormal({
+        mean: 0, stddev: 0.02 * Math.sqrt(2 * config.nLayer)
+      }),
+    }),
+    tf.layers.dropout({
+      name: config.name + '.mlp.drop',
+      rate: config.residDrop
     }),
   ]})
 }

@@ -362,7 +373,7 @@ class LMEmbedding extends tf.layers.Layer {
       'wte', //use same name as GPT2
       [this.vocabSize, this.nEmbd],
       'float32',
-      tf.initializers.randomNormal({})
+      tf.initializers.randomNormal({ mean:0, stddev:0.02 })
     )
   }
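
The pinned mean and stddev above replace a tf.initializers.randomNormal({}) call that fell back to the library's default standard deviation instead of the GPT-2 value of 0.02. A small check of what the pinned initializer produces (illustrative sketch, not part of the commit; the sample size is arbitrary):

    import * as tf from '@tensorflow/tfjs'

    // draw a large sample from the initializer and confirm its std is close to 0.02
    const init = tf.initializers.randomNormal({ mean: 0, stddev: 0.02 })
    const sample = init.apply([50000])
    tf.moments(sample).variance.sqrt().print() // prints roughly 0.02
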

@@ -452,7 +463,7 @@ export function GPTArchitecture(config: Required<GPTConfig>): tf.LayersModel {
     name: config.name + '.wpe',
     inputDim: config.blockSize,
     outputDim: config.nEmbd,
-    embeddingsInitializer: 'zeros'
+    embeddingsInitializer: tf.initializers.randomNormal({ mean: 0, stddev: 0.02 }),
   }).apply(range) as tf.SymbolicTensor

   if (config.debug) {
@@ -474,7 +485,12 @@ export function GPTArchitecture(config: Required<GPTConfig>): tf.LayersModel {
     ).apply(x)
   }
   // Normalization
-  x = tf.layers.layerNormalization({ name: config.name + '.ln_f', epsilon: 1e-5 })
+  x = tf.layers.layerNormalization({
+    name: config.name + '.ln_f',
+    epsilon: 1e-5,
+    gammaInitializer: 'ones',
+    betaInitializer: 'zeros',
+  })
     .apply(x)
   if (config.debug) {
     x = new LogLayer({ name: 'ln_f_log' }).apply(x)
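
The final layer norm now spells out gamma = ones and beta = zeros, which is the usual LayerNormalization starting point (scale 1, shift 0), so this change mainly documents the intended GPT-2-style init. A minimal standalone check (illustrative only, not part of the commit; the feature size 8 is made up):

    import * as tf from '@tensorflow/tfjs'

    // build an equivalent layer and inspect its variables:
    // gamma starts as all ones, beta as all zeros
    const ln = tf.layers.layerNormalization({
      epsilon: 1e-5,
      gammaInitializer: 'ones',
      betaInitializer: 'zeros',
    })
    ln.build([null, 8])
    ln.getWeights().forEach((w) => w.print())
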