From 1a4cd75416f38d70f18faf49bbcf10a4ef52fd93 Mon Sep 17 00:00:00 2001
From: josemanuel22
Date: Sat, 30 Dec 2023 18:54:42 +0100
Subject: [PATCH] adding sliced algo

---
 Project.toml                            |   8 +
 .../benchmark_unimodal.jl               |  25 +-
 examples/Sliced_ISL/MNIST_sliced.jl     |  63 +++
 examples/Sliced_ISL/benchmark_sliced.jl | 377 ++++++++++++++++++
 src/CustomLossFunction.jl               | 186 ++++++++-
 src/ISL.jl                              |   8 +-
 6 files changed, 651 insertions(+), 16 deletions(-)
 create mode 100644 examples/Sliced_ISL/MNIST_sliced.jl
 create mode 100644 examples/Sliced_ISL/benchmark_sliced.jl

diff --git a/Project.toml b/Project.toml
index 2c64b88..d8d2bfd 100644
--- a/Project.toml
+++ b/Project.toml
@@ -9,10 +9,17 @@ CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
 ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
 Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
 Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
+Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"
+GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
 HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
+ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
+Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
+KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
@@ -20,6 +27,7 @@ ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
diff --git a/examples/Learning1d_distribution/benchmark_unimodal.jl b/examples/Learning1d_distribution/benchmark_unimodal.jl
index 1d0cdf7..ca5263e 100644
--- a/examples/Learning1d_distribution/benchmark_unimodal.jl
+++ b/examples/Learning1d_distribution/benchmark_unimodal.jl
@@ -9,11 +9,13 @@ include("../utils.jl")
 n_samples = 10000
 
 @test_experiments "N(0,1) to N(23,1)" begin
-    gen = Chain(Dense(1, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 1))
+    gen = Chain(Dense(2, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 2))
     dscr = Chain(
         Dense(1, 11), elu, Dense(11, 29), elu, Dense(29, 11), elu, Dense(11, 1, σ)
     )
-    target_model = Normal(4.0f0, 2.0f0)
+    target_model = MixtureModel([
+        Normal(5.0f0, 2.0f0), Pareto(5.0f0, 1.0f0),
+    ])
     hparams = HyperParamsVanillaGan(;
         data_size=100,
         batch_size=1,
@@ -29,17 +31,16 @@ include("../utils.jl")
     train_vanilla_gan(dscr, gen, hparams)
 
     hparams = AutoISLParams(;
-        max_k=10, samples=1000, epochs=1000, η=1e-2, transform=noise_model
+        max_k=10, samples=1000, epochs=100, η=1e-2, transform=noise_model
     )
-    train_set = Float32.(rand(target_model, hparams.samples))
-    loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)
-
+    train_set = Float32.(rand(target_model, hparams.samples * hparams.epochs))
+    loader = Flux.DataLoader(train_set; batchsize=hparams.samples, shuffle=true, partial=false)
     auto_invariant_statistical_loss(gen, loader, hparams)
 end
 
 
 @test_experiments "N(0,1) to Uniform(22,24)" begin
-    gen = Chain(Dense(1, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 1))
+    gen = Chain(Dense(1, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 2))
     dscr = Chain(
         Dense(1, 11), elu, Dense(11, 29), elu, Dense(29, 11), elu, Dense(11, 1, σ)
     )
@@ -59,7 +60,7 @@ include("../utils.jl")
     train_vanilla_gan(dscr, gen, hparams)
 
     hparams = AutoISLParams(;
-        max_k=10, samples=1000, epochs=1000, η=1e-2, transform=noise_model
+        max_k=10, samples=1000, epochs=100, η=1e-2, transform=noise_model
     )
     train_set = Float32.(rand(target_model, hparams.samples))
     loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)
@@ -73,7 +74,7 @@ include("../utils.jl")
         gen,
         n_samples,
         (-3:0.1:3),
-        (0:0.1:10),
+        (0:0.1:30),
     )
 end
 
@@ -100,9 +101,9 @@ include("../utils.jl")
     train_vanilla_gan(dscr, gen, hparams)
 
     hparams = AutoISLParams(;
-        max_k=10, samples=1000, epochs=1000, η=1e-2, transform=noise_model
+        max_k=10, samples=1000, epochs=100, η=1e-2, transform=noise_model
     )
-    train_set = Float32.(rand(target_model, hparams.samples))
+    train_set = Float32.(rand(target_model, hparams.samples * hparams.epochs))
     loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)
 
     auto_invariant_statistical_loss(gen, loader, hparams)
@@ -122,7 +123,7 @@ include("../utils.jl")
         gen,
         n_samples,
         (-3:0.1:3),
-        (0:0.1:10),
+        (-20:0.1:20),
     )
 end
 
diff --git a/examples/Sliced_ISL/MNIST_sliced.jl b/examples/Sliced_ISL/MNIST_sliced.jl
new file mode 100644
index 0000000..c6bdfbd
--- /dev/null
+++ b/examples/Sliced_ISL/MNIST_sliced.jl
@@ -0,0 +1,63 @@
+using ISL
+using Flux
+using MLDatasets
+using Images
+using ImageTransformations # For resizing images if necessary
+
+function load_mnist(digit::Int)
+    # Load MNIST data
+    train_x, train_y = MLDatasets.MNIST.traindata()
+    test_x, test_y = MLDatasets.MNIST.testdata()
+
+    # Find indices where the label is digit
+    selected_indices = findall(x -> x == digit, train_y)
+
+    selected_images = train_x[:, :, selected_indices]
+
+    return (reshape(Float32.(selected_images), 784, :), train_y)#, (test_x, test_y)
+end
+
+(train_x, train_y) = load_mnist(5)
+
+model = Chain(
+    Dense(3, 512, relu),
+    Dense(512, 28*28, sigmoid)
+)
+
+model = Chain(
+    Dense(3, 256, relu),
+    #BatchNorm(256),
+    Dense(256, 512, relu),
+    #BatchNorm(512, relu),
+    Dense(512, 28*28, identity),
+    x -> reshape(x, 28, 28, 1, :),
+    Conv((3, 3), 1=>16, relu),
+    MaxPool((2, 2)),
+    x -> reshape(x, :, size(x, 4)),
+    Flux.flatten,
+    Dense(2704, 28*28)
+)
+
+# Define hyperparameters
+noise_model = MvNormal([0.0, 0.0, 0.0], [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0])
+n_samples = 10000
+
+hparams = HyperParamsSlicedISL(;
+    K=10, samples=1000, epochs=5, η=1e-2, noise_model=noise_model, m=20
+)
+
+# Create a data loader for training
+batch_size = 1000
+train_loader = DataLoader(train_x; batchsize=batch_size, shuffle=false, partial=false)
+
+total_loss = []
+for _ in 1:10
+    append!(total_loss, sliced_invariant_statistical_loss(model, train_loader, hparams))
+end
+
+img = model(Float32.(rand(hparams.noise_model, 1)))
+img2 = reshape(img, 28, 28)
+display(Gray.(img2))
+transformed_matrix = Float32.(img2 .> 0.1)
+display(Gray.(transformed_matrix))
diff --git a/examples/Sliced_ISL/benchmark_sliced.jl b/examples/Sliced_ISL/benchmark_sliced.jl
new file mode 100644
index 0000000..ea26bfc
--- /dev/null
+++ b/examples/Sliced_ISL/benchmark_sliced.jl
@@ -0,0 +1,377 @@
+using ISL
+using KernelDensity
+using Random
+
+include("../utils.jl")
+
+@test_experiments "sliced ISL" begin
+    @test_experiments "N(0,1)" begin
+        noise_model = MvNormal([0.0, 0.0], [1.0 0.0; 0.0 1.0])
+        n_samples = 10000
+
+        @test_experiments "N(0,1) to N(23,1)" begin
+            gen = Chain(Dense(2, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 2))
+
+            mean_vector = [2.0, 3.0]
+            cov = [1.0 0.5; 0.5 1.0]
+            target_model = MvNormal(mean_vector, cov)
+
+            hparams = HyperParamsSlicedISL(;
+                K=10, samples=1000, epochs=100, η=1e-2, noise_model=noise_model, m=10
+            )
+
+            train_set = Float32.(rand(target_model, hparams.samples * hparams.epochs))
+            loader = Flux.DataLoader(
+                train_set; batchsize=hparams.samples, shuffle=true, partial=false
+            )
+
+            loss = sliced_invariant_statistical_loss(gen, loader, hparams)
+            loss = sliced_invariant_statistical_loss_2(gen, loader, hparams)
+
+            loss = sliced_invariant_statistical_loss_multithreaded_2(gen, loader, hparams)
+
+            plotlyjs()
+            output_data = gen(Float32.(rand(noise_model, n_samples)))
+            x = output_data[1, :]
+            y = output_data[2, :]
+            kde_result = kde((x, y))
+            contour(
+                kde_result.x,
+                kde_result.y,
+                kde_result.density;
+                xlabel="X",
+                ylabel="Y",
+                zlabel="Index",
+            )
+            plot(
+                kde_result.x,
+                kde_result.y,
+                kde_result.density;
+                xlabel="X",
+                ylabel="Y",
+                zlabel="Index",
+                st=:surface,
+            )
+
+            x_axis = -20:0.1:20
+            y_axis = -20:0.1:20
+            f(x_axis, y_axis) = pdf(target_model, [x_axis, y_axis])
+            kde_result = kde((x_axis, y_axis))
+            contour!(
+                kde_result.x,
+                kde_result.y,
+                kde_result.density;
+                xlabel="X",
+                ylabel="Y",
+                zlabel="Index",
+            )
+            plot!(
+                x_axis,
+                y_axis,
+                f;
+                title="Contour Plot of 2D Gaussian Distribution",
+                xlabel="X-axis",
+                ylabel="Y-axis",
+                st=:surface,
+                alpha=0.8
+            )
+        end
+
+        @test_experiments "N(0,1) to N(23,1)" begin
+            gen = Chain(Dense(2, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 2))
+
+            # Define the custom distribution type
+            struct BivariatedUniform <: ContinuousMultivariateDistribution
+                a_min::Float64
+                a_max::Float64
+                b_min::Float64
+                b_max::Float64
+            end
+
+            Distributions.dim(::BivariatedUniform) = 2
+
+            Base.length(::BivariatedUniform) = 2
+
+            function Distributions.pdf(d::BivariatedUniform, x::AbstractArray{Float64})
+                x_val, y_val = x[1], x[2]
+                return pdf(Uniform(d.a_min, d.a_max), x_val) *
+                       pdf(Uniform(d.b_min, d.b_max), y_val) # Example pdf
+            end
+
+            function Distributions.rand(rng::AbstractRNG, d::BivariatedUniform)
+                x = rand(rng, Uniform(d.a_min, d.a_max))
+                y = rand(rng, Uniform(d.b_min, d.b_max))
+                return [x, y]
+            end
+
+            function Distributions._rand!(
+                rng::AbstractRNG,
+                d::MultivariateDistribution,
+                x::AbstractArray{Float64}
+            )
+                # Check if the dimensions of x match the dimensions of the distribution
+                @assert size(x, 1) == length(d) "Dimension mismatch"
+
+                # Iterate over each column (sample) in x
+                for i in 1:size(x, 2)
+                    # Generate a sample from the distribution d using the rng
+                    sample = rand(rng, d)
+
+                    # Fill the ith column of x with this sample
+                    x[:, i] .= sample
+                end
+
+                return x
+            end
+
+            target_model = BivariatedUniform(-2, 2, -2, 2)
+
+            hparams = HyperParamsSlicedISL(;
+                K=10, samples=1000, epochs=100, η=1e-2, noise_model=noise_model, m=40
+            )
+
+            train_set = Float32.(rand(target_model, hparams.samples * hparams.epochs))
+            loader = Flux.DataLoader(
+                train_set; batchsize=hparams.samples, shuffle=true, partial=false
+            )
+
+            loss = sliced_invariant_statistical_loss(gen, loader, hparams)
+        end
+    end
+
+    @test_experiments "N(0,1) to N(23,1)" begin
+        gen = Chain(Dense(2, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 2))
+
+        # Define the custom distribution type
+        struct CoustomDistribution <: ContinuousMultivariateDistribution
+            a_min::Float64
+            a_max::Float64
+            μ::Float64
+            σ::Float64
+        end
+
+        Distributions.dim(::CoustomDistribution) = 2
+
+        Base.length(::CoustomDistribution) = 2
+
+        function Distributions.pdf(d::CoustomDistribution, x::AbstractArray{Float64})
+            x_val, y_val = x[1], x[2]
+            return pdf(Uniform(d.a_min, d.a_max), x_val) *
+                   pdf(Normal(d.μ, d.σ), y_val) # Example pdf
+        end
+
+        function Distributions.rand(rng::AbstractRNG, d::CoustomDistribution)
+            x = rand(rng, Uniform(d.a_min, d.a_max))
+            y = rand(rng, Normal(d.μ, d.σ))
+            return [x, y]
+        end
+
+        function Distributions._rand!(
+            rng::AbstractRNG,
+            d::CoustomDistribution,
+            x::AbstractArray{Float64}
+        )
+            # Ensure that the dimensions of x are compatible with the distribution
+            @assert size(x, 1) == 2 "Dimension mismatch"
+
+            # Iterate over each column (sample) in x
+            for i in 1:size(x, 2)
+                # Generate a sample for each dimension of the distribution
+                x[1, i] = rand(rng, Uniform(d.a_min, d.a_max)) # First dimension
+                x[2, i] = rand(rng, Normal(d.μ, d.σ)) # Second dimension
+            end
+
+            return x
+        end
+
+        μ = 2.0
+        σ = 1.0
+        target_model = CoustomDistribution(-2, 2, μ, σ)
+
+        hparams = HyperParamsSlicedISL(;
+            K=10, samples=1000, epochs=100, η=1e-2, noise_model=noise_model, m=40
+        )
+
+        train_set = Float32.(rand(target_model, hparams.samples * hparams.epochs))
+        loader = Flux.DataLoader(
+            train_set; batchsize=hparams.samples, shuffle=true, partial=false
+        )
+
+        loss = sliced_invariant_statistical_loss(gen, loader, hparams)
+    end
+
+    @test_experiments "N(0,1) to N(23,1)" begin
+        gen = Chain(Dense(2, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 2))
+
+        # Define the custom distribution type
+        struct CoustomDistribution <: ContinuousMultivariateDistribution
+            α::Float64
+            β::Float64
+            μ::Float64
+            σ::Float64
+        end
+
+        Distributions.dim(::CoustomDistribution) = 2
+
+        Base.length(::CoustomDistribution) = 2
+
+        function Distributions.pdf(d::CoustomDistribution, x::AbstractArray{Float64})
+            x_val, y_val = x[1], x[2]
+            return pdf(Cauchy(d.α, d.β), x_val) *
+                   pdf(Normal(d.μ, d.σ), y_val) # Example pdf
+        end
+
+        function Distributions.rand(rng::AbstractRNG, d::CoustomDistribution)
+            x = rand(rng, Cauchy(d.α, d.β))
+            y = rand(rng, Normal(d.μ, d.σ))
+            return [x, y]
+        end
+
+        function Distributions._rand!(
+            rng::AbstractRNG,
+            d::CoustomDistribution,
+            x::AbstractArray{Float64}
+        )
+            # Ensure that the dimensions of x are compatible with the distribution
+            @assert size(x, 1) == 2 "Dimension mismatch"
+
+            # Iterate over each column (sample) in x
+            for i in 1:size(x, 2)
+                # Generate a sample for each dimension of the distribution
+                x[1, i] = rand(rng, Cauchy(d.α, d.β)) # First dimension
+                x[2, i] = rand(rng, Normal(d.μ, d.σ)) # Second dimension
+            end
+
+            return x
+        end
+
+        target_model = CoustomDistribution(2.0, 2.0, 2.0, 1.0)
+
+        hparams = HyperParamsSlicedISL(;
+            K=10, samples=1000, epochs=100, η=1e-2, noise_model=noise_model, m=40
+        )
+
+        train_set = Float32.(rand(target_model, hparams.samples * hparams.epochs))
+        loader = Flux.DataLoader(
+            train_set; batchsize=hparams.samples, shuffle=true, partial=false
+        )
+
+        loss = sliced_invariant_statistical_loss(gen, loader, hparams)
+    end
+
+    @test_experiments "N(0,1) to N(23,1)" begin
+        gen = Chain(Dense(2, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 2))
+
+        # Define the custom distribution type
+        struct CoustomDistribution <: ContinuousMultivariateDistribution
+            α::Float32
+            β::Float32
+            a_min::Float32
+            a_max::Float32
+        end
+
+        Distributions.dim(::CoustomDistribution) = 2
+
+        Base.length(::CoustomDistribution) = 2
+
+        function Distributions.pdf(d::CoustomDistribution, x::AbstractArray{Float64})
+            x_val, y_val = x[1], x[2]
+            return pdf(Uniform(d.a_min, d.a_max), x_val) * pdf(Cauchy(d.α, d.β), y_val)
+        end
+
+        function Distributions.rand(rng::AbstractRNG, d::CoustomDistribution)
+            x = rand(rng, Uniform(d.a_min, d.a_max))
+            y = rand(rng, Cauchy(d.α, d.β))
+            return [x, y]
+        end
+
+        function Distributions._rand!(
+            rng::AbstractRNG,
+            d::CoustomDistribution,
+            x::AbstractArray{Float64}
+        )
+            # Ensure that the dimensions of x are compatible with the distribution
+            @assert size(x, 1) == 2 "Dimension mismatch"
+
+            # Iterate over each column (sample) in x
+            for i in 1:size(x, 2)
+                # Generate a sample for each dimension of the distribution
+                x[1, i] = rand(rng, Uniform(d.a_min, d.a_max)) # First dimension
+                x[2, i] = rand(rng, Cauchy(d.α, d.β)) # Second dimension
+            end
+
+            return x
+        end
+
+        target_model = CoustomDistribution(0.0, 10.0, -1.0, 1.0)
+
+        hparams = HyperParamsSlicedISL(;
+            K=10, samples=1000, epochs=100, η=1e-2, noise_model=noise_model, m=10
+        )
+
+        train_set = Float32.(rand(target_model, hparams.samples * hparams.epochs))
+        loader = Flux.DataLoader(
+            train_set; batchsize=hparams.samples, shuffle=true, partial=false
+        )
+
+        loss = sliced_invariant_statistical_loss(gen, loader, hparams)
+    end
+
+    @test_experiments "N(0,1) to N(23,1)" begin
+        gen = Chain(Dense(2, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 2))
+
+        # Define the custom distribution type
+        struct CoustomDistribution <: ContinuousMultivariateDistribution
+            α₁::Float32
+            β₁::Float32
+            α₂::Float32
+            β₂::Float32
+        end
+
+        Distributions.dim(::CoustomDistribution) = 2
+
+        Base.length(::CoustomDistribution) = 2
+
+        function Distributions.pdf(d::CoustomDistribution, x::AbstractArray{Float64})
+            x_val, y_val = x[1], x[2]
+            return pdf(Cauchy(d.α₁, d.β₁), x_val) * pdf(Cauchy(d.α₂, d.β₂), y_val)
+        end
+
+        function Distributions.rand(rng::AbstractRNG, d::CoustomDistribution)
+            x = rand(rng, Cauchy(d.α₁, d.β₁))
+            y = rand(rng, Cauchy(d.α₂, d.β₂))
+            return [x, y]
+        end
+
+        function Distributions._rand!(
+            rng::AbstractRNG,
+            d::CoustomDistribution,
+            x::AbstractArray{Float64}
+        )
+            # Ensure that the dimensions of x are compatible with the distribution
+            @assert size(x, 1) == 2 "Dimension mismatch"
+
+            # Iterate over each column (sample) in x
+            for i in 1:size(x, 2)
+                # Generate a sample for each dimension of the distribution
+                x[1, i] = rand(rng, Cauchy(d.α₁, d.β₁)) # First dimension
+                x[2, i] = rand(rng, Cauchy(d.α₂, d.β₂)) # Second dimension
+            end
+
+            return x
+        end
+
+        target_model = CoustomDistribution(1.0, 1.0, 1.0, 1.0)
+
+        hparams = HyperParamsSlicedISL(;
+            K=10, samples=1000, epochs=100, η=1e-2, noise_model=noise_model, m=20
+        )
+
+        train_set = Float32.(rand(target_model, hparams.samples * hparams.epochs))
+        loader = Flux.DataLoader(
+            train_set; batchsize=hparams.samples, shuffle=true, partial=false
+        )
+
+        loss = sliced_invariant_statistical_loss_distributed(gen, loader, hparams)
+    end
+end;
diff --git a/src/CustomLossFunction.jl b/src/CustomLossFunction.jl
index 0fa69b3..6ebb659 100644
--- a/src/CustomLossFunction.jl
+++ b/src/CustomLossFunction.jl
@@ -389,8 +389,6 @@ end
 
 
 # Train and output the model according to the chosen hyperparameters `hparams`
-
-
 function ts_invariant_statistical_loss_one_step_prediction(rec, gen, Xₜ, Xₜ₊₁, hparams)
     losses = []
     optim_rec = Flux.setup(Flux.Adam(hparams.η), rec)
@@ -416,7 +414,6 @@ function ts_invariant_statistical_loss_one_step_prediction(rec, gen, Xₜ, Xₜ
 
     return losses
 end
-
 """
     ts_invariant_statistical_loss(rec, gen, Xₜ, Xₜ₊₁, hparams)
 
@@ -466,3 +463,186 @@ function ts_invariant_statistical_loss(rec, gen, Xₜ, Xₜ₊₁, hparams)
     end
     return losses
 end
+
+Base.@kwdef mutable struct HyperParamsSlicedISL
+    seed::Int = 72                                           # Random seed
+    dev = cpu                                                # Device: cpu or gpu
+    η::Float64 = 1e-3                                        # Learning rate
+    epochs::Int = 100                                        # Number of epochs
+    noise_model = MvNormal([0.0, 0.0], [1.0 0.0; 0.0 1.0])   # Noise model used as generator input
+    samples::Int = 1000                                      # Window size for the histogram
+    K::Int = 10                                              # Number of simulated observations
+    m::Int = 10                                              # Number of random directions
+end
+
+function sample_random_direction(n::Int)::Vector{Float32}
+    # Generate a random vector where each component is from a standard normal distribution
+    random_vector = randn(Float32, n)
+
+    normalized_vector = random_vector / norm(random_vector)
+
+    return normalized_vector
+end
+
+function sliced_invariant_statistical_loss(nn_model, loader, hparams::HyperParamsSlicedISL)
+    @assert loader.batchsize == hparams.samples
+    @assert length(loader) == hparams.epochs
+    losses = Vector{Float32}()
+    optim = Flux.setup(Flux.Adam(hparams.η), nn_model)
+    @showprogress for data in loader
+        loss, grads = Flux.withgradient(nn_model) do nn
+            Ω = [sample_random_direction(size(data)[1]) for _ in 1:(hparams.m)]
+            total = 0.0f0
+            for ω in Ω
+                aₖ = zeros(hparams.K + 1)
+                for i in 1:(hparams.samples)
+                    x = Float32.(rand(hparams.noise_model, hparams.K))
+                    yₖ = nn(x)
+                    s = collect(reshape(ω' * yₖ, 1, hparams.K))
+                    aₖ += generate_aₖ(s, ω ⋅ data[:, i])
+                end
+                total += scalar_diff(aₖ ./ sum(aₖ))
+            end
+            total / hparams.m
+        end
+        Flux.update!(optim, nn_model, grads[1])
+        push!(losses, loss)
+    end
+    return losses
+end;
+
+function sliced_invariant_statistical_loss_2(
+    nn_model, loader, hparams::HyperParamsSlicedISL
+)
+    @assert loader.batchsize == hparams.samples
+    @assert length(loader) == hparams.epochs
+    losses = Vector{Float32}()
+    optim = Flux.setup(Flux.Adam(hparams.η), nn_model)
+
+    @showprogress for data in loader
+        loss, grads = Flux.withgradient(nn_model) do nn
+            Ω = [sample_random_direction(size(data)[1]) for _ in 1:(hparams.m)]
+            total = 0.0f0
+            # Vectorized operations
+            X = broadcast(vec -> Float32.(vec), rand(hparams.noise_model, hparams.K, hparams.samples)) # All random numbers at once
+            Yₖ = nn.(X) # Apply nn to the entire batch
+            for ω in Ω
+                S = broadcast(x -> dot(ω, x), Yₖ) # Vectorized computation
+                reshaped_S = [reshape(S[:, i], :, 1) for i in 1:size(S, 2)]
+                aₖ = sum(generate_aₖ.(reshaped_S, ω' * data)) # Sum over samples
+                total += scalar_diff(aₖ ./ sum(aₖ))
+            end
+
+            total / hparams.m
+        end
+
+        Flux.update!(optim, nn_model, grads[1])
+        push!(losses, loss)
+    end
+
+    return losses
+end;
+
+using Zygote
+using Zygote: bufferfrom
+using Base.Threads: @spawn
+
+using Base.Threads
+
+# Set batch size based on the number of threads
+const BATCH_SIZE = Threads.nthreads() # Or a multiple of Threads.nthreads()
+
+function compute_loss_for_single_ω(nn, ω, data, hparams, preallocated_aₖ)
+    # Clear the preallocated array
+    fill!(preallocated_aₖ, 0)
+
+    for i in 1:(hparams.samples)
+        x = Float32.(rand(hparams.noise_model, hparams.K))
+        yₖ = nn(x)
+        s = collect(reshape(ω' * yₖ, 1, hparams.K))
+        preallocated_aₖ += generate_aₖ(s, ω ⋅ data[:, i])
+    end
+    return scalar_diff(preallocated_aₖ ./ sum(preallocated_aₖ))
+end
+
+function process_batch(batch, nn, data, hparams, preallocated_aₖ)
+    batch_results = Float32[]
+    for ω in batch
+        result = compute_loss_for_single_ω(nn, ω, data, hparams, preallocated_aₖ)
+        push!(batch_results, result)
+    end
+    return batch_results
+end
+
+function sliced_invariant_statistical_loss_multithreaded(
+    nn_model, loader, hparams::HyperParamsSlicedISL
+)
+    @assert loader.batchsize == hparams.samples
+    @assert length(loader) == hparams.epochs
+    losses = Vector{Float32}()
+    optim = Flux.setup(Flux.Adam(hparams.η), nn_model)
+
+    @showprogress for data in loader
+        loss, grads = Flux.withgradient(nn_model) do nn
+            Ω = [sample_random_direction(size(data)[1]) for _ in 1:(hparams.m)]
+
+            # Split Ω into batches
+            batches = [Ω[i:min(i + BATCH_SIZE - 1, end)] for i in 1:BATCH_SIZE:length(Ω)]
+            preallocated_aₖ = zeros(hparams.K + 1)
+
+            # Process each batch in parallel
+            batch_tasks = [
+                Threads.@spawn process_batch(batch, nn, data, hparams, preallocated_aₖ)
+                for batch in batches
+            ]
+
+            # Collect and sum up the results
+            loss_components = vcat(fetch.(batch_tasks)...)
+            sum(loss_components) / hparams.m
+        end
+
+        Flux.update!(optim, nn_model, grads[1])
+        push!(losses, loss)
+    end
+
+    return losses
+end
+
+function compute_forward_pass(nn, ω, data, hparams)
+    aₖ = zeros(hparams.K + 1)
+    for i in 1:hparams.samples
+        x = Float32.(rand(hparams.noise_model, hparams.K))
+        yₖ = nn(x)
+        s = Matrix(reshape(ω' * yₖ, 1, hparams.K)) # Convert to Matrix
+        aₖ += generate_aₖ(s, ω ⋅ data[:, i])
+    end
+    return aₖ
+end
+
+function sliced_invariant_statistical_loss_multithreaded_2(nn_model, loader, hparams::HyperParamsSlicedISL)
+    @assert loader.batchsize == hparams.samples
+    @assert length(loader) == hparams.epochs
+    losses = Vector{Float32}()
+    optim = Flux.setup(Flux.Adam(hparams.η), nn_model)
+
+    @showprogress for data in loader
+        Ω = [sample_random_direction(size(data)[1]) for _ in 1:hparams.m]
+
+        # Perform the forward pass in parallel
+        forward_pass_results = [Threads.@spawn compute_forward_pass(nn_model, ω, data, hparams) for ω in Ω]
+        aₖ_results = fetch.(forward_pass_results)
+
+        # Compute gradients sequentially
+        loss, grads = Flux.withgradient(nn_model) do nn
+            total_loss = sum([scalar_diff(aₖ_result ./ sum(aₖ_result)) for aₖ_result in aₖ_results]) / hparams.m
+            total_loss
+        end
+
+        Flux.update!(optim, nn_model, grads[1])
+        push!(losses, loss)
+    end
+
+    return losses
+end
diff --git a/src/ISL.jl b/src/ISL.jl
index ae91a73..a7d24b1 100644
--- a/src/ISL.jl
+++ b/src/ISL.jl
@@ -6,6 +6,7 @@ using StatsBase
 using Distributions: Normal, rand
 using HypothesisTests: pvalue, ChisqTest
 using MLUtils
+using LinearAlgebra
 using Parameters: @with_kw
 using ProgressMeter
 
@@ -31,5 +32,10 @@ export _sigmoid,
     auto_invariant_statistical_loss_1,
     HyperParamsTS,
     ts_invariant_statistical_loss_one_step_prediction,
-    ts_invariant_statistical_loss
+    ts_invariant_statistical_loss,
+    HyperParamsSlicedISL,
+    sliced_invariant_statistical_loss,
+    sliced_invariant_statistical_loss_2,
+    sliced_invariant_statistical_loss_multithreaded,
+    sliced_invariant_statistical_loss_multithreaded_2
 end
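
Usage note (not part of the diff): a minimal sketch of how the sliced ISL API added in this patch is driven, distilled from examples/Sliced_ISL/benchmark_sliced.jl. The 2-D Gaussian target and the layer sizes below are illustrative choices, not requirements of the API.

using ISL, Flux, Distributions

# Generator maps 2-D noise to 2-D samples, as in the benchmark scripts above.
gen = Chain(Dense(2, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 2))

# Illustrative target; anything with a rand method producing 2-D column samples works.
target_model = MvNormal([2.0, 3.0], [1.0 0.5; 0.5 1.0])

hparams = HyperParamsSlicedISL(;
    K=10, samples=1000, epochs=100, η=1e-2,
    noise_model=MvNormal([0.0, 0.0], [1.0 0.0; 0.0 1.0]), m=10,
)

# The loss asserts batchsize == hparams.samples and length(loader) == hparams.epochs,
# so draw samples * epochs points and batch them accordingly.
train_set = Float32.(rand(target_model, hparams.samples * hparams.epochs))
loader = Flux.DataLoader(train_set; batchsize=hparams.samples, shuffle=true, partial=false)

losses = sliced_invariant_statistical_loss(gen, loader, hparams)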