From 61f6ec0a2624d731d96d51391200dde9ea507cf1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 9 Nov 2024 15:02:32 -0500 Subject: [PATCH] fix: make the dimensions correct --- examples/NanoGPT/main.jl | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/examples/NanoGPT/main.jl b/examples/NanoGPT/main.jl index 3ea05236d5..6d6974a302 100644 --- a/examples/NanoGPT/main.jl +++ b/examples/NanoGPT/main.jl @@ -18,14 +18,14 @@ function gpt_block(; n_embed, n_hidden, qk_dim, v_dim, n_heads, dropout_rate) @assert v_dim % n_heads == 0 return @compact(; name="GPTBlock(; n_embed=$n_embed, n_hidden=$n_hidden, qk_dim=$qk_dim, v_dim=$v_dim, n_heads=$n_heads, dropout_rate=$dropout_rate)", - ln=LayerNorm((n_embed,)), + ln=LayerNorm((n_embed, 1)), qlayer=Dense(n_embed => qk_dim; use_bias=false), klayer=Dense(n_embed => qk_dim; use_bias=false), vlayer=Dense(n_embed => v_dim; use_bias=false), attn_drop=Dropout(dropout_rate), proj=Dense(v_dim => n_embed; use_bias=false), mlp=Chain( - LayerNorm((n_embed,)), + LayerNorm((n_embed, 1)), Dense(n_embed => n_hidden, gelu), Dense(n_hidden => n_embed), Dropout(dropout_rate) @@ -54,7 +54,7 @@ function GPT(; blocks=ntuple(n_layers) do i return gpt_block(; n_embed, n_hidden, qk_dim, v_dim, n_heads, dropout_rate) end, - ln=LayerNorm((n_embed,)), + ln=LayerNorm((n_embed, 1)), output_layer=Dense(n_embed => n_vocab)) do tokens te = token_embedding(tokens) pe = position_embedding(1:size(tokens, 1)) @@ -154,10 +154,13 @@ end opt = Adam(lr) train_state = Training.TrainState(model, ps, st, opt) + @printf "[Info] Compiling Inference Model...\n" testX, testY = (testX, testY) |> dev model_compiled = @compile model(testX, ps, Lux.testmode(st)) best_test_loss = Inf + @printf "[Info] Starting Model Training...\n\n" + loss_fn = CrossEntropyLoss(; logits=Val(true)) iter = 0 @@ -165,12 +168,15 @@ end for (x, y) in train_loader iter += 1 + start_time = time() _, loss, _, train_state = Training.single_train_step!( AutoEnzyme(), loss_fn, (x, y), train_state ) + time_taken = time() - start_time if iter % 100 == 0 - @printf "[Train] Epoch %3d\tIteration %6d\tLoss %.8e\n" epoch iter loss + @printf "[Train] Epoch %3d\tIteration %6d\tLoss %.8e\tTime per \ + Iteration %0.5f\n" epoch iter loss time_taken end end