refactor ts_invariant_statistical_loss_multivariate
josemanuel22 committed Feb 11, 2024
1 parent 4a9a765 commit 3d8a52e
Showing 2 changed files with 16 additions and 90 deletions.
examples/Learning1d_distribution/benchmark_multimodal.jl: 14 additions, 16 deletions
@@ -13,38 +13,39 @@ include("../utils.jl")
     dscr = Chain(
         Dense(1, 11), elu, Dense(11, 29), elu, Dense(29, 11), elu, Dense(11, 1, σ)
     )
-    target_model = Normal(4.0f0, 2.0f0)
+    target_model = MixtureModel([
+        Normal(5.0f0, 2.0f0), Normal(-1.0f0, 1.0f0), Normal(-10.0f0, 3.0f0)
+    ])
     hparams = HyperParamsVanillaGan(;
         data_size=100,
         batch_size=1,
-        epochs=1e3,
-        lr_dscr=1e-4,
+        epochs=1e4,
+        lr_dscr=1e-3,
         lr_gen=1e-4,
-        dscr_steps=0,
-        gen_steps=0,
+        dscr_steps=4,
+        gen_steps=1,
         noise_model=noise_model,
         target_model=target_model,
     )

     train_vanilla_gan(dscr, gen, hparams)

     hparams = AutoISLParams(;
-        max_k=20, samples=1000, epochs=2000, η=1e-2, transform=noise_model
+        max_k=10, samples=1000, epochs=1000, η=1e-2, transform=noise_model
     )
     train_set = Float32.(rand(target_model, hparams.samples))
     loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)

     auto_invariant_statistical_loss(gen, loader, hparams)

-    #save_gan_model(gen, dscr, hparams)
     plot_global(
         x -> quantile.(target_model, cdf(noise_model, x)),
         noise_model,
         target_model,
         gen,
         n_samples,
         (-3:0.1:3),
-        (0:0.02:10),
+        (0:0.2:10),
     )

     #@test js_divergence(hist1.weights, hist2.weights)/hparams.samples < 0.03
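The map passed to plot_global above, x -> quantile.(target_model, cdf(noise_model, x)), is the exact monotone transport map T = F_target⁻¹ ∘ F_noise, so it serves as the ground-truth generator that the trained gen is plotted against. A minimal self-contained sketch of that property, assuming the standard-normal noise_model these examples typically use (pure Distributions.jl, no repo internals):

    using Distributions, Statistics

    noise_model = Normal(0.0f0, 1.0f0)   # assumed noise law
    target_model = MixtureModel([
        Normal(5.0f0, 2.0f0), Normal(-1.0f0, 1.0f0), Normal(-10.0f0, 3.0f0),
    ])

    # T(x) = F_target⁻¹(F_noise(x)): pushing noise samples through T yields
    # samples distributed exactly as target_model.
    T(x) = quantile(target_model, cdf(noise_model, x))

    xs = rand(noise_model, 10_000)
    isapprox(mean(T.(xs)), mean(target_model); atol=0.3)   # true up to sampling noise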
@@ -56,16 +57,14 @@ include("../utils.jl")
     dscr = Chain(
         Dense(1, 11), elu, Dense(11, 29), elu, Dense(29, 11), elu, Dense(11, 1, σ)
     )
-    target_model = MixtureModel([
-        Normal(5.0f0, 2.0f0), Pareto(5.0f0,1.0f0),
-    ])
+    target_model = MixtureModel([Normal(5.0f0, 2.0f0), Pareto(5.0f0, 1.0f0)])
     hparams = HyperParamsVanillaGan(;
         data_size=100,
         batch_size=1,
-        epochs=1000,
+        epochs=1e4,
         lr_dscr=1e-4,
         lr_gen=1e-4,
-        dscr_steps=1,
+        dscr_steps=2,
         gen_steps=1,
         noise_model=noise_model,
         target_model=target_model,
@@ -74,7 +73,7 @@ include("../utils.jl")
     train_vanilla_gan(dscr, gen, hparams)

     hparams = AutoISLParams(;
-        max_k=20, samples=1000, epochs=1000, η=1e-2, transform=noise_model
+        max_k=10, samples=10000, epochs=1000, η=1e-3, transform=noise_model
     )
     train_set = Float32.(rand(target_model, hparams.samples))
     loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)
@@ -86,7 +85,7 @@ include("../utils.jl")
     mse = MSE(noise_model, x -> 2 * cdf(Normal(0, 1), x) + 22, n_sample)

     plot_global(
-        x -> -quantile.(-target_model, cdf(noise_model, x)),
+        x -> quantile.(target_model, cdf(noise_model, x)),
         noise_model,
         target_model,
         gen,
@@ -255,7 +254,6 @@ include("../utils.jl")
     loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)

     auto_invariant_statistical_loss(gen, loader, hparams)
-
 end

 @test_experiments "Uniform(-1,1) to Pareto(1,23)" begin
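A note on the Normal-plus-Pareto target introduced above: in Distributions.jl, Pareto(5.0f0, 1.0f0) has shape α = 5 and scale θ = 1, a heavy-tailed law supported on [1, ∞), and MixtureModel assigns equal weights when none are given. A quick sanity check of the resulting moments (pure Distributions.jl; the larger sample budget and smaller η in the AutoISL run above plausibly compensate for this tail, though the commit itself does not say so):

    using Distributions

    tm = MixtureModel([Normal(5.0f0, 2.0f0), Pareto(5.0f0, 1.0f0)])  # weights default to 1/2 each
    mean(tm)                 # 0.5 * 5 + 0.5 * (5 * 1 / (5 - 1)) = 3.125
    maximum(rand(tm, 10^6))  # occasionally lands far above the Normal's range: the Pareto tail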
examples/Learning1d_distribution/benchmark_unimodal.jl: 2 additions, 74 deletions
@@ -13,9 +13,7 @@ include("../utils.jl")
     dscr = Chain(
         Dense(1, 11), elu, Dense(11, 29), elu, Dense(29, 11), elu, Dense(11, 1, σ)
     )
-    target_model = MixtureModel([
-        Normal(5.0f0, 2.0f0), Pareto(5.0f0,1.0f0),
-    ])
+    target_model = MixtureModel([Normal(5.0f0, 2.0f0), Pareto(5.0f0, 1.0f0)])
     hparams = HyperParamsVanillaGan(;
         data_size=100,
         batch_size=1,
@@ -36,68 +34,9 @@ include("../utils.jl")

     train_set = Float32.(rand(target_model, hparams.samples))
     loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)
-
-    ksd1 = 0.0
-    ranges = (0:0.1:8)
-    @showprogress for i in 1:10
-        target_model = MixtureModel([
-            Normal(5.0f0, 2.0f0), Pareto(5.0f0,1.0f0),
-        ])
-        gen = Chain(Dense(1, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 1))
-        dscr = Chain(
-            Dense(1, 11), elu, Dense(11, 29), elu, Dense(29, 11), elu, Dense(11, 1, σ)
-        )
-        hparams = AutoISLParams(;
-            max_k=10, samples=1000, epochs=100, η=1e-2, transform=noise_model
-        )
-
-        train_set = Float32.(rand(target_model, hparams.samples))
-        loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)
-        loss = auto_invariant_statistical_loss(gen, loader, hparams)
-        ksd1 += KSD(noise_model, target_model, gen, n_samples, ranges)
-    end
-
-    ksd2 = 0.0
-    @showprogress for i in 1:10
-        target_model = MixtureModel([
-            Normal(5.0f0, 2.0f0), Pareto(5.0f0,1.0f0),
-        ])
-        gen = Chain(Dense(1, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 1))
-        dscr = Chain(
-            Dense(1, 11), elu, Dense(11, 29), elu, Dense(29, 11), elu, Dense(11, 1, σ)
-        )
-        hparams = AutoISLParams(;
-            max_k=10, samples=1000, epochs=100, η=1e-2, transform=noise_model
-        )
-
-        train_set = Float32.(rand(target_model, hparams.samples))
-        loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)
-        loss = auto_invariant_statistical_loss_2(gen, loader, hparams)
-        ksd2 += KSD(noise_model, target_model, gen, n_samples, ranges)
-    end
-
-    loss = auto_invariant_statistical_loss(gen, loader, hparams)
-
-    K = 10
-    data = loader
-    aₖ = zeros(K + 1)
-    for i in 1:(hparams.samples)
-        x = rand(hparams.transform, K)
-        yₖ = gen(x')
-        aₖ += generate_aₖ(yₖ, data.data[i])
-    end
-
-    function loss_2(x)
-        return scalar_diff(Float32.(x) ./ sum(Float32.(x)))
+    auto_invariant_statistical_loss(gen, loader, hparams)
 end
-
-plot(moving_average([sqrt(l[1]) for l in loss], 10), label="ISL Theoretical Loss", xtickfontsize=10, ytickfontsize=10, legendfontsize=10, color=:red, linewidth=3, yaxis=:log10)
-plot!(moving_average([loss_2(l[2]) for l in loss], 10), label="ISL Surrogate Loss", color=:blue, linewidth=3, yaxis=:log10)
-xlabel!("Epochs")
-ylabel!("Loss")
-ylims!((0.0, 0.4)) # Setting the y-axis limits

 @test_experiments "N(0,1) to Uniform(22,24)" begin
     gen = Chain(Dense(1, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 1))
     dscr = Chain(
@@ -167,14 +106,6 @@

     auto_invariant_statistical_loss(gen, loader, hparams)

-    # ksd = KSD(noise_model, target_model, n_samples, 18:0.1:25)
-    # mae = MAE(
-    #     noise_model, x -> quantile.(target_model, cdf(noise_model, x)), n_samples
-    # )
-    # mse = MSE(
-    #     noise_model, x -> quantile.(target_model, cdf(noise_model, x)), n_samples
-    # )
-
     plot_global(
         x -> quantile.(target_model, cdf(noise_model, x)),
         noise_model,
@@ -371,9 +302,6 @@ end
     loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)

     auto_invariant_statistical_loss(gen, loader, hparams)
-
-    #save_gan_model(gen, dscr, hparams)
-
 end

 @test_experiments "N(0,1) to Uniform(22,24)" begin
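For context on the large block deleted from benchmark_unimodal.jl above: the aₖ loop accumulated, for each real observation, its rank among K generator draws (via generate_aₖ), and loss_2 compared the normalized counts against a uniform vector. The invariant behind that benchmark is standard: if the generator matches the data law, the rank of a real draw among K i.i.d. synthetic draws is uniform on {0, ..., K}. A self-contained sketch of the invariant, using only Distributions.jl and a Normal as a stand-in for a perfectly trained generator (no repo internals):

    using Distributions

    K, N = 10, 100_000
    d = Normal(0, 1)            # stand-in for a generator matching the data law
    counts = zeros(Int, K + 1)
    for _ in 1:N
        y  = rand(d)            # one "real" observation
        yₖ = rand(d, K)         # K synthetic draws
        counts[count(<(y), yₖ) + 1] += 1   # rank of y among the yₖ
    end
    counts ./ N                 # each entry ≈ 1/(K + 1); deviations drive the ISL loss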

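One last closed form worth recording: the reference map 2 * cdf(Normal(0, 1), x) + 22 kept in the MSE call of benchmark_multimodal.jl is the generic quantile transform specialized to the Uniform(22, 24) experiments, since quantile(Uniform(22, 24), p) = 22 + 2p. A one-line check (pure Distributions.jl):

    using Distributions

    T(x) = 2 * cdf(Normal(0, 1), x) + 22
    T(0.7) ≈ quantile(Uniform(22, 24), cdf(Normal(0, 1), 0.7))   # true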