Commit

fix: make the dimensions correct
avik-pal committed Nov 15, 2024
1 parent 91a65af commit 61f6ec0
Showing 1 changed file with 10 additions and 4 deletions.

examples/NanoGPT/main.jl
@@ -18,14 +18,14 @@ function gpt_block(; n_embed, n_hidden, qk_dim, v_dim, n_heads, dropout_rate)
    @assert v_dim % n_heads == 0
    return @compact(;
        name="GPTBlock(; n_embed=$n_embed, n_hidden=$n_hidden, qk_dim=$qk_dim, v_dim=$v_dim, n_heads=$n_heads, dropout_rate=$dropout_rate)",
-       ln=LayerNorm((n_embed,)),
+       ln=LayerNorm((n_embed, 1)),
        qlayer=Dense(n_embed => qk_dim; use_bias=false),
        klayer=Dense(n_embed => qk_dim; use_bias=false),
        vlayer=Dense(n_embed => v_dim; use_bias=false),
        attn_drop=Dropout(dropout_rate),
        proj=Dense(v_dim => n_embed; use_bias=false),
        mlp=Chain(
-           LayerNorm((n_embed,)),
+           LayerNorm((n_embed, 1)),
            Dense(n_embed => n_hidden, gelu),
            Dense(n_hidden => n_embed),
            Dropout(dropout_rate)
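Context for the shape change above: the activations in this example flow through as 3D (n_embed, seq_len, batch) arrays, and Lux's LayerNorm sizes its normalization and learnable scale/bias from the `shape` argument, so the extra trailing singleton dimension is what lets the layer line up with a 3D input. A minimal sketch exercising the corrected layer (sizes are illustrative, not from the commit; assumes a Lux.jl version where this example runs):

using Lux, Random

n_embed, seq_len, batch = 8, 16, 4            # illustrative sizes
x = randn(Float32, n_embed, seq_len, batch)   # same layout as the example's activations

ln = LayerNorm((n_embed, 1))                  # the corrected shape from this diff
ps, st = Lux.setup(Random.default_rng(), ln)
y, _ = Lux.apply(ln, x, ps, st)
@assert size(y) == size(x)                    # layer norm preserves the array shape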
@@ -54,7 +54,7 @@ function GPT(;
        blocks=ntuple(n_layers) do i
            return gpt_block(; n_embed, n_hidden, qk_dim, v_dim, n_heads, dropout_rate)
        end,
-       ln=LayerNorm((n_embed,)),
+       ln=LayerNorm((n_embed, 1)),
        output_layer=Dense(n_embed => n_vocab)) do tokens
        te = token_embedding(tokens)
        pe = position_embedding(1:size(tokens, 1))
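A note on the unchanged context here: `token_embedding(tokens)` presumably returns a (n_embed, seq_len, batch) array while `position_embedding(1:size(tokens, 1))` returns (n_embed, seq_len), so the two combine by broadcasting the positional embedding across the batch dimension. A plain-Julia sketch of that broadcast, with made-up sizes:

n_embed, seq_len, batch = 8, 16, 4
te = randn(Float32, n_embed, seq_len, batch)  # stand-in token embeddings
pe = randn(Float32, n_embed, seq_len)         # stand-in positional embeddings
x = te .+ pe                                  # pe broadcasts over the batch axis
@assert size(x) == (n_embed, seq_len, batch)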
@@ -154,23 +154,29 @@ end
    opt = Adam(lr)
    train_state = Training.TrainState(model, ps, st, opt)

    @printf "[Info] Compiling Inference Model...\n"
    testX, testY = (testX, testY) |> dev
    model_compiled = @compile model(testX, ps, Lux.testmode(st))
    best_test_loss = Inf

    @printf "[Info] Starting Model Training...\n\n"

    loss_fn = CrossEntropyLoss(; logits=Val(true))

    iter = 0
    for epoch in 1:epochs
        for (x, y) in train_loader
            iter += 1

+           start_time = time()
            _, loss, _, train_state = Training.single_train_step!(
                AutoEnzyme(), loss_fn, (x, y), train_state
            )
+           time_taken = time() - start_time

            if iter % 100 == 0
-               @printf "[Train] Epoch %3d\tIteration %6d\tLoss %.8e\n" epoch iter loss
+               @printf "[Train] Epoch %3d\tIteration %6d\tLoss %.8e\tTime per \
+                   Iteration %0.5f\n" epoch iter loss time_taken
            end
        end
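The timing added in this hunk is the plain wall-clock pattern: stamp `time()` before and after the hot call and report the difference. A self-contained sketch of the same pattern, with a stand-in workload in place of `Training.single_train_step!`:

using Printf

start_time = time()
result = sum(abs2, rand(Float32, 10^6))       # stand-in workload
time_taken = time() - start_time

@printf "Time per Iteration %0.5f\n" time_taken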
