add sliced_invariant_statistical_loss_optimized
josemanuel22 committed Jan 13, 2024
1 parent 821c037 commit 12b095a
Showing 5 changed files with 144 additions and 25 deletions.
3 changes: 3 additions & 0 deletions Project.toml
@@ -5,6 +5,7 @@ version = "0.1.0"

[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
@@ -24,10 +25,12 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
Unzip = "41fe7b60-77ed-43a1-b4f0-825fd5a5650d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extras]
30 changes: 21 additions & 9 deletions examples/Sliced_ISL/MNIST_sliced.jl
@@ -5,6 +5,14 @@ using Images
using ImageTransformations # For resizing images if necessary
using LinearAlgebra

function load_mnist()
# Load MNIST data
train_x, train_y = MLDatasets.MNIST.traindata()
test_x, test_y = MLDatasets.MNIST.testdata()

return (reshape(Float32.(train_x), 28 * 28, :), train_y)#, (test_x, test_y)
end

function load_mnist(digit::Int)
# Load MNIST data
train_x, train_y = MLDatasets.MNIST.traindata()
@@ -41,14 +49,14 @@ function load_mnist_normalized(digit::Int, max::Int)

image_tensor = reshape(@.(2.0f0 * selected_images - 1.0f0), 28, 28, :)

train_data = reshape(image_tensor, 28 * 28, :)

return (train_data, train_y)
return (reshape(Float32.(image_tensor), 28 * 28, :), train_y)
end

(train_x, train_y) = load_mnist()
(train_x, train_y) = (train_x[:, 1:5000], train_y[1:5000])
(train_x, train_y) = load_mnist(0)
(train_x, train_y) = load_mnist(9, 100)
(train_x, train_y) = load_mnist_normalized(9, 100)
(train_x, train_y) = load_mnist_normalized(8, 100)

# Dimension
dims = 100
@@ -103,6 +111,9 @@ function Discriminator()
)
end

latent_dim = 100
# weight initialization as given in the paper https://arxiv.org/abs/1511.06434
dcgan_init(shape...) = randn(Float32, shape...) * 0.02f0
function Generator(latent_dim::Int)
return Chain(
Dense(latent_dim, 7 * 7 * 256),
@@ -113,11 +124,12 @@ function Generator(latent_dim::Int)
ConvTranspose((4, 4), 128 => 64; stride=2, pad=1, init=dcgan_init),
BatchNorm(64, relu),
ConvTranspose((4, 4), 64 => 1; stride=2, pad=1, init=dcgan_init),
Flux.flatten,
x -> tanh.(x),
)
end

model = Generator(dims)
model = Generator(latent_dim)
#model = Chain( ConvTranspose((7, 7), 100 => 256, stride=1, padding=0), BatchNorm(256, relu), ConvTranspose((4, 4), 256 => 128, stride=2, padding=1), BatchNorm(128, relu), ConvTranspose((4, 4), 128 => 1, stride=2, padding=1), tanh ))

# Mean vector (zero vector of length dim)
@@ -132,16 +144,16 @@ noise_model = MvNormal(mean_vector, cov_matrix)
n_samples = 10000

hparams = HyperParamsSlicedISL(;
K=10, samples=100, epochs=1, η=1e-2, noise_model=noise_model, m=200
K=10, samples=100, epochs=1, η=1e-2, noise_model=noise_model, m=10
)

# Create a data loader for training
batch_size = 100
train_loader = DataLoader(train_x; batchsize=batch_size, shuffle=false, partial=false)
train_loader = DataLoader(train_x; batchsize=batch_size, shuffle=true, partial=false)

total_loss = []
@showprogress for _ in 1:20
append!(total_loss, sliced_invariant_statistical_loss(model, train_loader, hparams))
@showprogress for _ in 1:200
append!(total_loss, optimized_loss(model, train_loader, hparams))
end

img = model(Float32.(rand(hparams.noise_model, 1)))
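To eyeball the sample generated above, one could reshape it into a 28×28 image and display it, rescaling the tanh output back to [0, 1]. This is a sketch only, assuming `Images` is loaded as at the top of this file; it is not part of the diff shown:

```julia
# Illustrative sketch, not part of this commit: visualize one generated sample.
img2 = reshape(img, 28, 28)                 # 784×1 column -> 28×28 image
display(Gray.((img2 .+ 1.0f0) ./ 2.0f0))    # rescale tanh output from [-1, 1] to [0, 1]
```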
47 changes: 47 additions & 0 deletions examples/Sliced_ISL/MNIST_sliced2.jl
@@ -0,0 +1,47 @@
using ISL
using Flux
using MLDatasets
using Images
using ImageTransformations # For resizing images if necessary
using LinearAlgebra

function load_mnist()
# Load MNIST data
train_x, train_y = MLDatasets.MNIST.traindata()
test_x, test_y = MLDatasets.MNIST.testdata()
return (reshape(Float32.(train_x), 28 * 28, :), train_y)#, (test_x, test_y)
end

(images, labels) = load_mnist()

n_outputs = length(unique(labels))

ys = [Flux.onehot(labels, 0:9) for labels in labels]

n_inputs, n_latent, n_outputs = 28 * 28, 50, 10
model = Chain(
Dense(n_inputs, n_latent, identity),
Dense(n_latent, n_latent, identity),
Dense(n_latent, n_outputs, identity),
softmax,
)
loss(x, y) = Flux.crossentropy(model(x), y)

function create_batch(r)
xs = images[:, r]
ys = [Flux.onehot(labels, 0:9) for labels in labels[r]]
return (xs, Flux.batch(ys))
end

trainbatch = create_batch(1:5000)

opt = Flux.setup(Flux.Adam(hparams.η), model)
opt = ADAM()

@showprogress for _ in 1:1000
Flux.train!(loss, Flux.params(model), [trainbatch], opt)
end

model(images[:, 1])
img2 = reshape(images[:, 1], 28, 28)
display(Gray.(img2))
85 changes: 70 additions & 15 deletions src/CustomLossFunction.jl
@@ -86,7 +86,7 @@ The contribution is computed according to the formula:
```
"""
function γ(yₖ::Matrix{T}, yₙ::T, m::Int64) where {T<:AbstractFloat}
eₘ(m) = [j == m ? 1.0 : 0.0 for j in 0:length(yₖ)]
eₘ(m) = [j == m ? T(1.0) : T(0.0) for j in 0:length(yₖ)]
return eₘ(m) * ψₘ(ϕ(yₖ, yₙ), m)
end;
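The collapsed hunk above hides the docstring's formula; reading it off the function body, the contribution of one data point appears to be:

```
γ(yₖ, yₙ, m) = eₘ(m) · ψₘ(ϕ(yₖ, yₙ), m)
```

where eₘ(m) is the one-hot vector of length K + 1 selecting component m.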

@@ -140,7 +140,7 @@ The formula for generating `aₖ` is as follows:
aₖ = ∑_{k=0}^K γ(ŷ, y, k) = ∑_{k=0}^K ∑_{i=1}^N ψₖ(ŷ, yᵢ)
```
"""
function generate_aₖ(ŷ::Matrix{T}, y::T) where {T<:AbstractFloat}
@inline function generate_aₖ(ŷ::Matrix{T}, y::T) where {T<:AbstractFloat}
return sum([γ(ŷ, y, k) for k in 0:length(ŷ)])
end

@@ -153,7 +153,8 @@ Scalar difference between the vector representing our surrogate histogram and th
loss = ||q - 1/(K+1)||₂² = ∑_{k=0}^K (qₖ - 1/(K+1))²
```
"""
scalar_diff(q::Vector{T}) where {T<:AbstractFloat} = sum((q .- (1 ./ length(q))) .^ 2)
@inline scalar_diff(q::Vector{T}) where {T<:AbstractFloat} =
sum((q .- (T(1.0f0) ./ T(length(q)))) .^ 2)
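Taken together, `generate_aₖ` and `scalar_diff` give the per-direction penalty used by the sliced losses below. A minimal sketch, assuming the ISL package is loaded and that these helpers are reachable (qualify them with `ISL.` if they are not exported):

```julia
using ISL

K = 10
ŷ = randn(Float32, 1, K)              # K generator samples projected onto one direction
y = 0.0f0                             # one projected data point
aₖ = generate_aₖ(ŷ, y)                # histogram contribution, a vector of length K + 1
penalty = scalar_diff(aₖ ./ sum(aₖ))  # squared distance to the uniform histogram 1/(K+1)
```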

"""
`jensen_shannon_∇(aₖ)``
@@ -277,6 +278,13 @@ function get_window_of_Aₖ(transform, model, data, K::Int64)
return [count(x -> x == i, window) for i in 0:K]
end;

@inline function get_window_of_Aₖ(transform, model, ω, data, K::Int64)
ŷₖ = model(Float32.(rand(transform, K)))
ŷₖ_proj = [dot(ω, ŷₖ[:, i]) for i in 1:size(ŷₖ, 2)]
window = count.([ŷₖ_proj .< d for d in data])
return [count(x -> x == i, window) for i in 0:K]
end;

"""
`convergence_to_uniform(aₖ)``
@@ -475,7 +483,7 @@ Base.@kwdef mutable struct HyperParamsSlicedISL
m::Int = 10 # Number of random directions
end

function sample_random_direction(n::Int)::Vector{Float32}
@inline function sample_random_direction(n::Int)::Vector{Float32}
# Generate a random vector where each component is from a standard normal distribution
random_vector = randn(Float32, n)

@@ -492,21 +500,22 @@ function sample_ornormal_random_direction(n::Int, m::Int)::Vector{Vector{Float32
return [matrix[:, i] for i in 1:m]
end

using ThreadsX
function sliced_invariant_statistical_loss(nn_model, loader, hparams::HyperParamsSlicedISL)
@assert loader.batchsize == hparams.samples
@assert length(loader) == hparams.epochs
losses = Vector{Float32}()
optim = Flux.setup(Flux.Adam(hparams.η), nn_model)
@showprogress for data in loader
Ω = ThreadsX.map(_ -> sample_random_direction(size(data)[1]), 1:(hparams.m))
loss, grads = Flux.withgradient(nn_model) do nn
Ω = [sample_random_direction(size(data)[1]) for _ in 1:(hparams.m)]
total = 0.0f0
for ω in Ω
aₖ = zeros(hparams.K + 1)
for i in 1:(hparams.samples)
x = Float32.(rand(hparams.noise_model, hparams.K))
yₖ = nn(x)
s = collect(reshape(ω' * yₖ, 1, hparams.K))
s = Matrix(ω' * yₖ)
aₖ += generate_aₖ(s, ω ⋅ data[:, i])
end
total += scalar_diff(aₖ ./ sum(aₖ))
@@ -519,6 +528,46 @@ function sliced_invariant_statistical_loss(nn_model, loader, hparams::HyperParam
return losses
end;

function sliced_invariant_statistical_loss_optimized(nn_model, loader, hparams)
@assert loader.batchsize == hparams.samples
@assert length(loader) == hparams.epochs
losses = Vector{Float32}()
optim = Flux.setup(Flux.Adam(hparams.η), nn_model)

@showprogress for data in loader
Ω = ThreadsX.map(_ -> sample_random_direction(size(data)[1]), 1:(hparams.m))
loss, grads = Flux.withgradient(nn_model) do nn
total = 0.0f0
for ω in Ω
aₖ = zeros(Float32, hparams.K + 1) # Reset aₖ for each new ω

# Generate all random numbers in one go
x_batch = rand(hparams.noise_model, hparams.samples * hparams.K)

# Process batch through nn_model
yₖ_batch = nn(Float32.(x_batch))

s = Matrix(ω' * yₖ_batch)

@inbounds for i in 2:(hparams.samples)
start_col = hparams.K * (i - 1)
end_col = hparams.K * i

aₖ_slice = s[:, start_col:(end_col - 1)]
ω_data_dot_product = ω ⋅ data[:, i]

aₖ += generate_aₖ(aₖ_slice, ω_data_dot_product)
end
total += scalar_diff(aₖ ./ sum(aₖ))
end
total / hparams.m
end
Flux.update!(optim, nn_model, grads[1])
push!(losses, loss)
end
return losses
end
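A minimal usage sketch for the new function, mirroring the MNIST example above but on toy data. The toy generator, noise model and data below are assumptions for illustration, not part of this commit:

```julia
using ISL, Flux, Distributions, MLUtils, LinearAlgebra

dims = 2
model = Chain(Dense(dims, 16, relu), Dense(16, dims))               # toy generator
noise_model = MvNormal(zeros(dims), Matrix{Float64}(I, dims, dims)) # standard normal noise
hparams = HyperParamsSlicedISL(;
    K=10, samples=100, epochs=1, η=1e-2, noise_model=noise_model, m=10
)

# batchsize must equal hparams.samples and the number of batches must equal hparams.epochs
train_x = randn(Float32, dims, hparams.samples * hparams.epochs)
train_loader = DataLoader(train_x; batchsize=hparams.samples, shuffle=true, partial=false)

losses = sliced_invariant_statistical_loss_optimized(model, train_loader, hparams)
```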

function sliced_ortonormal_invariant_statistical_loss(
nn_model, loader, hparams::HyperParamsSlicedISL
)
@@ -559,17 +608,23 @@ function sliced_invariant_statistical_loss_selected_directions(
@assert length(loader) == hparams.epochs
losses = Vector{Float32}()
optim = Flux.setup(Flux.Adam(hparams.η), nn_model)

function compute_p_value(nn, data, hparams)
ω = sample_random_direction(size(data)[1])
return (
ω,
convergence_to_uniform(
get_window_of_Aₖ(hparams.noise_model, nn, ω, data, hparams.K)
),
)
end

@showprogress for data in loader
values = ThreadsX.map(_ -> compute_p_value(nn_model, data, hparams), 1:1000)

sorted_Ω = [direction for (direction, _) in sort(values; by=x -> x[2], rev=true)][1:(hparams.m)]

loss, grads = Flux.withgradient(nn_model) do nn
Ω = [sample_random_direction(size(data)[1]) for _ in 1:1000]
p_values = [
convergence_to_uniform(
get_window_of_Aₖ(hparams.noise_model, nn, ω .⋅ data, hparams.K)
) for ω in Ω
]
direction_pvalues = zip(Ω, p_values)
sorted_directions = sort(direction_pvalues; by=x -> x[2])
sorted_Ω = [direction for (direction, pvalue) in sorted_directions]
total = 0.0f0
for ω in sorted_Ω
aₖ = zeros(hparams.K + 1)
4 changes: 3 additions & 1 deletion src/ISL.jl
@@ -9,6 +9,7 @@ using MLUtils
using LinearAlgebra
using Parameters: @with_kw
using ProgressMeter
using Random

using StaticArrays

@@ -39,5 +40,6 @@ export _sigmoid,
sliced_invariant_statistical_loss_multithreaded,
sliced_invariant_statistical_loss_multithreaded_2,
sliced_invariant_statistical_loss_selected_directions,
sliced_ortonormal_invariant_statistical_loss
sliced_ortonormal_invariant_statistical_loss,
sliced_invariant_statistical_loss_optimized
end
