From 3d8a52e77171d23690bbf74815cb121f779a6d53 Mon Sep 17 00:00:00 2001
From: josemanuel22
Date: Sun, 11 Feb 2024 09:49:36 +0100
Subject: [PATCH] refactor ts_invariant_statistical_loss_multivariate

---
 .../benchmark_multimodal.jl | 30 ++++----
 .../benchmark_unimodal.jl   | 76 +------------------
 2 files changed, 16 insertions(+), 90 deletions(-)

diff --git a/examples/Learning1d_distribution/benchmark_multimodal.jl b/examples/Learning1d_distribution/benchmark_multimodal.jl
index f178a07..9482296 100644
--- a/examples/Learning1d_distribution/benchmark_multimodal.jl
+++ b/examples/Learning1d_distribution/benchmark_multimodal.jl
@@ -13,15 +13,17 @@ include("../utils.jl")
     dscr = Chain(
         Dense(1, 11), elu, Dense(11, 29), elu, Dense(29, 11), elu, Dense(11, 1, σ)
     )
-    target_model = Normal(4.0f0, 2.0f0)
+    target_model = MixtureModel([
+        Normal(5.0f0, 2.0f0), Normal(-1.0f0, 1.0f0), Normal(-10.0f0, 3.0f0)
+    ])
     hparams = HyperParamsVanillaGan(;
         data_size=100,
         batch_size=1,
-        epochs=1e3,
-        lr_dscr=1e-4,
+        epochs=1e4,
+        lr_dscr=1e-3,
         lr_gen=1e-4,
-        dscr_steps=0,
-        gen_steps=0,
+        dscr_steps=4,
+        gen_steps=1,
         noise_model=noise_model,
         target_model=target_model,
     )
@@ -29,14 +31,13 @@ include("../utils.jl")
     train_vanilla_gan(dscr, gen, hparams)
 
     hparams = AutoISLParams(;
-        max_k=20, samples=1000, epochs=2000, η=1e-2, transform=noise_model
+        max_k=10, samples=1000, epochs=1000, η=1e-2, transform=noise_model
     )
     train_set = Float32.(rand(target_model, hparams.samples))
     loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)
 
     auto_invariant_statistical_loss(gen, loader, hparams)
 
-    #save_gan_model(gen, dscr, hparams)
     plot_global(
         x -> quantile.(target_model, cdf(noise_model, x)),
         noise_model,
@@ -44,7 +45,7 @@ include("../utils.jl")
         gen,
         n_samples,
         (-3:0.1:3),
-        (0:0.02:10),
+        (0:0.2:10),
     )
 
     #@test js_divergence(hist1.weights, hist2.weights)/hparams.samples < 0.03
@@ -56,16 +57,14 @@ include("../utils.jl")
     dscr = Chain(
         Dense(1, 11), elu, Dense(11, 29), elu, Dense(29, 11), elu, Dense(11, 1, σ)
     )
-    target_model = MixtureModel([
-        Normal(5.0f0, 2.0f0), Pareto(5.0f0,1.0f0),
-    ])
+    target_model = MixtureModel([Normal(5.0f0, 2.0f0), Pareto(5.0f0, 1.0f0)])
     hparams = HyperParamsVanillaGan(;
         data_size=100,
         batch_size=1,
-        epochs=1000,
+        epochs=1e4,
         lr_dscr=1e-4,
         lr_gen=1e-4,
-        dscr_steps=1,
+        dscr_steps=2,
         gen_steps=1,
         noise_model=noise_model,
         target_model=target_model,
     )
@@ -74,7 +73,7 @@ include("../utils.jl")
     train_vanilla_gan(dscr, gen, hparams)
 
     hparams = AutoISLParams(;
-        max_k=20, samples=1000, epochs=1000, η=1e-2, transform=noise_model
+        max_k=10, samples=10000, epochs=1000, η=1e-3, transform=noise_model
     )
     train_set = Float32.(rand(target_model, hparams.samples))
     loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)
@@ -86,7 +85,7 @@ include("../utils.jl")
     mse = MSE(noise_model, x -> 2 * cdf(Normal(0, 1), x) + 22, n_sample)
 
     plot_global(
-        x -> -quantile.(-target_model, cdf(noise_model, x)),
+        x -> quantile.(target_model, cdf(noise_model, x)),
         noise_model,
         target_model,
         gen,
@@ -255,7 +254,6 @@ include("../utils.jl")
     loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)
 
     auto_invariant_statistical_loss(gen, loader, hparams)
-
 end
 
 @test_experiments "Uniform(-1,1) to Pareto(1,23)" begin
diff --git a/examples/Learning1d_distribution/benchmark_unimodal.jl b/examples/Learning1d_distribution/benchmark_unimodal.jl
index e8063f1..20251b5 100644
--- a/examples/Learning1d_distribution/benchmark_unimodal.jl
+++ b/examples/Learning1d_distribution/benchmark_unimodal.jl
@@ -13,9 +13,7 @@ include("../utils.jl")
     dscr = Chain(
         Dense(1, 11), elu, Dense(11, 29), elu, Dense(29, 11), elu, Dense(11, 1, σ)
     )
-    target_model = MixtureModel([
-        Normal(5.0f0, 2.0f0), Pareto(5.0f0,1.0f0),
-    ])
+    target_model = MixtureModel([Normal(5.0f0, 2.0f0), Pareto(5.0f0, 1.0f0)])
     hparams = HyperParamsVanillaGan(;
         data_size=100,
         batch_size=1,
@@ -36,68 +34,9 @@ include("../utils.jl")
     train_set = Float32.(rand(target_model, hparams.samples))
     loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)
-
-    ksd1 = 0.0
-    ranges = (0:0.1:8)
-    @showprogress for i in 1:10
-        target_model = MixtureModel([
-            Normal(5.0f0, 2.0f0), Pareto(5.0f0,1.0f0),
-        ])
-        gen = Chain(Dense(1, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 1))
-        dscr = Chain(
-            Dense(1, 11), elu, Dense(11, 29), elu, Dense(29, 11), elu, Dense(11, 1, σ)
-        )
-        hparams = AutoISLParams(;
-            max_k=10, samples=1000, epochs=100, η=1e-2, transform=noise_model
-        )
-
-        train_set = Float32.(rand(target_model, hparams.samples))
-        loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)
-        loss = auto_invariant_statistical_loss(gen, loader, hparams)
-        ksd1 += KSD(noise_model, target_model, gen, n_samples, ranges)
-    end
-
-    ksd2 = 0.0
-    @showprogress for i in 1:10
-        target_model = MixtureModel([
-            Normal(5.0f0, 2.0f0), Pareto(5.0f0,1.0f0),
-        ])
-        gen = Chain(Dense(1, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 1))
-        dscr = Chain(
-            Dense(1, 11), elu, Dense(11, 29), elu, Dense(29, 11), elu, Dense(11, 1, σ)
-        )
-        hparams = AutoISLParams(;
-            max_k=10, samples=1000, epochs=100, η=1e-2, transform=noise_model
-        )
-
-        train_set = Float32.(rand(target_model, hparams.samples))
-        loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)
-        loss = auto_invariant_statistical_loss_2(gen, loader, hparams)
-        ksd2 += KSD(noise_model, target_model, gen, n_samples, ranges)
-    end
-
-    loss = auto_invariant_statistical_loss(gen, loader, hparams)
-
-    K= 10
-    data = loader
-    aₖ = zeros(K + 1)
-    for i in 1:(hparams.samples)
-        x = rand(hparams.transform, K)
-        yₖ = gen(x')
-        aₖ += generate_aₖ(yₖ, data.data[i])
-    end
-end
-
-function loss_2(x)
-    return scalar_diff(Float32.(x)./sum(Float32.(x)))
+    auto_invariant_statistical_loss(gen, loader, hparams)
 end
 
-plot(moving_average([sqrt(l[1]) for l in loss], 10), label="ISL Theoretical Loss", xtickfontsize=10, ytickfontsize=10, legendfontsize=10, color=:red, linewidth=3, yaxis=:log10)
-plot!(moving_average([loss_2(l[2]) for l in loss], 10), label="ISL Surrogate Loss", color=:blue, linewidth=3, yaxis=:log10)
-xlabel!("Epochs")
-ylabel!("Loss")
-ylims!((0.0,0.4)) # Setting the y-axis limits
-
 @test_experiments "N(0,1) to Uniform(22,24)" begin
     gen = Chain(Dense(1, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 1))
     dscr = Chain(
         Dense(1, 11), elu, Dense(11, 29), elu, Dense(29, 11), elu, Dense(11, 1, σ)
     )
@@ -167,14 +106,6 @@ include("../utils.jl")
 
     auto_invariant_statistical_loss(gen, loader, hparams)
 
-    # ksd = KSD(noise_model, target_model, n_samples, 18:0.1:25)
-    # mae = MAE(
-    #     noise_model, x -> quantile.(target_model, cdf(noise_model, x)), n_samples
-    # )
-    # mse = MSE(
-    #     noise_model, x -> quantile.(target_model, cdf(noise_model, x)), n_samples
-    # )
-
     plot_global(
         x -> quantile.(target_model, cdf(noise_model, x)),
         noise_model,
@@ -371,9 +302,6 @@ end
     loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)
 
     auto_invariant_statistical_loss(gen, loader, hparams)
-
-    #save_gan_model(gen, dscr, hparams)
-
 end
 
 @test_experiments "N(0,1) to Uniform(22,24)" begin