From 8792097fc784552bcf6a74dab0be296d43da3c98 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 3 Jan 2025 14:24:38 -0500 Subject: [PATCH] fix: ConditionalVAE on CI (#1159) --- examples/ConditionalVAE/main.jl | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/examples/ConditionalVAE/main.jl b/examples/ConditionalVAE/main.jl index be477fd7a..99f90d321 100644 --- a/examples/ConditionalVAE/main.jl +++ b/examples/ConditionalVAE/main.jl @@ -122,9 +122,10 @@ end @concrete struct TensorDataset dataset transform + total_samples::Int end -Base.length(ds::TensorDataset) = length(ds.dataset) +Base.length(ds::TensorDataset) = ds.total_samples function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, AbstractRange}) img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3)) @@ -132,17 +133,12 @@ function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, Abstrac end function loadmnist(batchsize, image_size::Dims{2}) - ## Load MNIST: Only 1500 for demonstration purposes - N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : nothing + ## Load MNIST: Only 1500 for demonstration purposes on CI train_dataset = MNIST(; split=:train) - test_dataset = MNIST(; split=:test) - if N !== nothing - train_dataset = train_dataset[1:N] - test_dataset = test_dataset[1:N] - end + N = parse(Bool, get(ENV, "CI", "false")) ? 1500 : length(train_dataset) train_transform = ScaleKeepAspect(image_size) |> ImageToTensor() - trainset = TensorDataset(train_dataset, train_transform) + trainset = TensorDataset(train_dataset, train_transform, N) trainloader = DataLoader(trainset; batchsize, shuffle=true, partial=false) return trainloader @@ -247,7 +243,7 @@ function main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_f for (i, X) in enumerate(train_dataloader) throughput_tic = time() - (_, loss, stats, train_state) = Training.single_train_step!( + (_, loss, _, train_state) = Training.single_train_step!( AutoEnzyme(), loss_function, X, train_state) throughput_toc = time()