Skip to content

Commit

Permalink
adding sliced algo
Browse files Browse the repository at this point in the history
  • Loading branch information
josemanuel22 committed Dec 30, 2023
1 parent ca7f2ec commit 1a4cd75
Show file tree
Hide file tree
Showing 6 changed files with 651 additions and 16 deletions.
8 changes: 8 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,25 @@ CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"
GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
25 changes: 13 additions & 12 deletions examples/Learning1d_distribution/benchmark_unimodal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ include("../utils.jl")
n_samples = 10000

@test_experiments "N(0,1) to N(23,1)" begin
gen = Chain(Dense(1, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 1))
gen = Chain(Dense(2, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 2))
dscr = Chain(
Dense(1, 11), elu, Dense(11, 29), elu, Dense(29, 11), elu, Dense(11, 1, σ)
)
target_model = Normal(4.0f0, 2.0f0)
target_model = MixtureModel([
Normal(5.0f0, 2.0f0), Pareto(5.0f0,1.0f0),
])
hparams = HyperParamsVanillaGan(;
data_size=100,
batch_size=1,
Expand All @@ -29,17 +31,16 @@ include("../utils.jl")
train_vanilla_gan(dscr, gen, hparams)

hparams = AutoISLParams(;
max_k=10, samples=1000, epochs=1000, η=1e-2, transform=noise_model
max_k=10, samples=1000, epochs=100, η=1e-2, transform=noise_model
)

train_set = Float32.(rand(target_model, hparams.samples))
loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)

train_set = Float32.(rand(target_model, hparams.samples * hparams.epochs))
loader = Flux.DataLoader(train_set; batchsize=hparams.samples, shuffle=true, partial=false)
auto_invariant_statistical_loss(gen, loader, hparams)
end

@test_experiments "N(0,1) to Uniform(22,24)" begin
gen = Chain(Dense(1, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 1))
gen = Chain(Dense(1, 7), elu, Dense(7, 13), elu, Dense(13, 7), elu, Dense(7, 2))
dscr = Chain(
Dense(1, 11), elu, Dense(11, 29), elu, Dense(29, 11), elu, Dense(11, 1, σ)
)
Expand All @@ -59,7 +60,7 @@ include("../utils.jl")
train_vanilla_gan(dscr, gen, hparams)

hparams = AutoISLParams(;
max_k=10, samples=1000, epochs=1000, η=1e-2, transform=noise_model
max_k=10, samples=1000, epochs=100, η=1e-2, transform=noise_model
)
train_set = Float32.(rand(target_model, hparams.samples))
loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)
Expand All @@ -73,7 +74,7 @@ include("../utils.jl")
gen,
n_samples,
(-3:0.1:3),
(0:0.1:10),
(0:0.1:30),
)
end

Expand All @@ -100,9 +101,9 @@ include("../utils.jl")
train_vanilla_gan(dscr, gen, hparams)

hparams = AutoISLParams(;
max_k=10, samples=1000, epochs=1000, η=1e-2, transform=noise_model
max_k=10, samples=1000, epochs=100, η=1e-2, transform=noise_model
)
train_set = Float32.(rand(target_model, hparams.samples))
train_set = Float32.(rand(target_model, hparams.samples * hparams.epochs))
loader = Flux.DataLoader(train_set; batchsize=-1, shuffle=true, partial=false)

auto_invariant_statistical_loss(gen, loader, hparams)
Expand All @@ -122,7 +123,7 @@ include("../utils.jl")
gen,
n_samples,
(-3:0.1:3),
(0:0.1:10),
(-20:0.1:20),
)
end

Expand Down
63 changes: 63 additions & 0 deletions examples/Sliced_ISL/MNIST_sliced.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using ISL
using Flux
using MLDatasets
using Images
using ImageTransformations # For resizing images if necessary

"""
    load_mnist(digit::Int)

Load the MNIST training split and keep only the images labelled `digit`.

Return a tuple `(images, labels)` where `images` is a `784 × N` `Float32`
matrix (each column one flattened 28×28 image) and `labels` is the matching
`N`-element label vector (every entry equal to `digit`).

Throw an `ArgumentError` if `digit` is outside `0:9`.
"""
function load_mnist(digit::Int)
    digit in 0:9 || throw(ArgumentError("digit must be in 0:9, got $digit"))

    # Only the training split is needed; the original code also loaded the
    # test split and discarded it.
    train_x, train_y = MLDatasets.MNIST.traindata()

    # Indices of all training samples carrying the requested label.
    selected_indices = findall(==(digit), train_y)

    selected_images = train_x[:, :, selected_indices]

    # Flatten each 28×28 image into a 784-element column. Return the labels
    # of the *selected* samples — the original returned the full `train_y`,
    # which did not correspond to the filtered images.
    return (reshape(Float32.(selected_images), 784, :), train_y[selected_indices])
end

# Keep only the images of digit 5 for training.
(train_x, train_y) = load_mnist(5)

# Generator: decode a 3-dimensional noise vector into a flattened 28×28
# image, with a small convolutional refinement stage in the middle.
# NOTE: the original file first built a Dense-only model and immediately
# overwrote it with this one; the dead first definition (and the
# commented-out BatchNorm lines) were removed.
model = Chain(
    Dense(3, 256, relu),
    Dense(256, 512, relu),
    Dense(512, 28 * 28, identity),
    # Reshape the flat 784-vector into a 28×28 single-channel image batch.
    x -> reshape(x, 28, 28, 1, :),
    Conv((3, 3), 1 => 16, relu),
    MaxPool((2, 2)),
    # Flatten the (13, 13, 16, batch) feature maps to (2704, batch).
    # The `Flux.flatten` that followed in the original was a no-op on an
    # already-2D array and has been dropped.
    x -> reshape(x, :, size(x, 4)),
    Dense(2704, 28 * 28),
)

# --- Hyperparameters -------------------------------------------------------
# Latent noise model: 3-D standard normal (zero mean, identity covariance).
# NOTE(review): `MvNormal` comes from Distributions, which is not imported
# in this file — presumably re-exported by ISL or Flux; confirm the import.
noise_model = MvNormal([0.0, 0.0, 0.0], [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0])
n_samples = 10000  # defined but not used anywhere below in this script

# Sliced-ISL training configuration. NOTE(review): field semantics are
# inferred from names (K projections, samples per epoch, learning rate η,
# m slices) — confirm against the ISL package documentation.
hparams = HyperParamsSlicedISL(;
    K=10, samples=1000, epochs=5, η=1e-2, noise_model=noise_model, m=20
)

# Mini-batches of 1000 flattened images; `partial=false` drops any final
# incomplete batch and `shuffle=false` keeps the dataset order.
batch_size = 1000
train_loader = DataLoader(train_x; batchsize=batch_size, shuffle=false, partial=false)

# Train in 10 rounds (5 epochs each per the hyperparameters above),
# accumulating the per-epoch losses for later inspection.
total_loss = []
for _ in 1:10
    append!(total_loss, sliced_invariant_statistical_loss(model, train_loader, hparams))
end

# Sample one latent vector, decode it to a 28×28 grayscale image and show
# it, then show a binarised version (pixel > 0.1 → 1, else 0).
img = model(Float32.(rand(hparams.noise_model, 1)))
img2 = reshape(img, 28, 28)
display(Gray.(img2))
transformed_matrix = Float32.(img2 .> 0.1)
display(Gray.(transformed_matrix))
Loading

0 comments on commit 1a4cd75

Please sign in to comment.