Skip to content

Commit

Permalink
feat: pipeline working 🎉
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 22, 2024
1 parent 8665bc3 commit 644b60c
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions examples/ConvMixer/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ function get_dataloaders(batchsize; kwargs...)
test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std)

trainset = TensorDataset(CIFAR10(:train), train_transform)
trainloader = DataLoader(trainset; batchsize, shuffle=true, parallel=true, kwargs...)
trainloader = DataLoader(trainset; batchsize, shuffle=true, kwargs...)

testset = TensorDataset(CIFAR10(:test), test_transform)
testloader = DataLoader(testset; batchsize, shuffle=false, parallel=true, kwargs...)
testloader = DataLoader(testset; batchsize, shuffle=false, kwargs...)

return trainloader, testloader
end
Expand All @@ -43,16 +43,23 @@ function ConvMixer(; dim, depth, kernel_size=5, patch_size=2)
return Chain(
Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size),
BatchNorm(dim),
[Chain(
SkipConnection(
Chain(
Conv((kernel_size, kernel_size), dim => dim, gelu; groups=dim, pad=SamePad()),
BatchNorm(dim)
[
Chain(
SkipConnection(
Chain(
Conv(
(kernel_size, kernel_size), dim => dim, gelu;
groups=dim, pad=SamePad()
),
BatchNorm(dim)
),
+
),
+
),
Conv((1, 1), dim => dim, gelu), BatchNorm(dim))
for _ in 1:depth]...,
Conv((1, 1), dim => dim, gelu),
BatchNorm(dim)
)
for _ in 1:depth
]...,
GlobalMeanPool(),
FlattenLayer(),
Dense(dim => 10)
Expand All @@ -74,9 +81,9 @@ function accuracy(model, ps, st, dataloader)
end

Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5,
patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-4,
clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01,
backend::String="reactant")
backend::String="gpu_if_available")
rng = StableRNG(seed)

if backend == "gpu_if_available"
Expand Down Expand Up @@ -118,7 +125,7 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::
model_compiled = model
end

loss = CrossEntropyLoss(; logits=Val(true))
loss_fn = CrossEntropyLoss(; logits=Val(true))

@printf "[Info] Training model\n"
for epoch in 1:epochs
Expand All @@ -128,7 +135,7 @@ Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::
lr = lr_schedule((epoch - 1) + (i + 1) / length(trainloader))
train_state = Optimisers.adjust!(train_state, lr)
(_, _, _, train_state) = Training.single_train_step!(
adtype, loss, (x, y), train_state
adtype, loss_fn, (x, y), train_state
)
end
ttime = time() - stime
Expand Down

0 comments on commit 644b60c

Please sign in to comment.