Commit
add sliced_invariant_statistical_loss_optimized
josemanuel22 committed Jan 14, 2024
1 parent 12b095a · commit 1db54a7
Showing 4 changed files with 116 additions and 5 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -22,6 +22,7 @@ KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+PProf = "e4faabce-9ead-11e9-39d9-4379958e3056"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
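PProf is presumably added here to profile the new loss implementations. A minimal sketch of standard PProf.jl usage, with the profiled call purely illustrative:

using Profile, PProf

# Collect a CPU profile of one training pass.
Profile.clear()
@profile sliced_invariant_statistical_loss_optimized_3(model, train_loader, hparams)

# Export to the pprof format and open the interactive web UI.
pprof()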
11 changes: 7 additions & 4 deletions examples/Sliced_ISL/MNIST_sliced.jl
@@ -144,16 +144,19 @@ noise_model = MvNormal(mean_vector, cov_matrix)
n_samples = 10000

hparams = HyperParamsSlicedISL(;
-    K=10, samples=100, epochs=1, η=1e-2, noise_model=noise_model, m=10
+    K=10, samples=100, epochs=1, η=1e-2, noise_model=noise_model, m=100
)

# Create a data loader for training
batch_size = 100
-train_loader = DataLoader(train_x; batchsize=batch_size, shuffle=true, partial=false)
+train_loader = DataLoader(train_x; batchsize=batch_size, shuffle=false, partial=false)

total_loss = []
-@showprogress for _ in 1:200
-    append!(total_loss, optimized_loss(model, train_loader, hparams))
+@showprogress for _ in 1:10
+    append!(
+        total_loss,
+        sliced_invariant_statistical_loss_optimized_3(model, train_loader, hparams),
+    )
end

img = model(Float32.(rand(hparams.noise_model, 1)))
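A quick way to eyeball the generated sample img, assuming the generator returns a flattened 28×28 MNIST image (Plots is already a project dependency):

using Plots

# Reshape the flat 784-pixel vector into a 28×28 image and display it.
heatmap(reshape(img, 28, 28)'; color=:grays, yflip=true, framestyle=:none)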
105 changes: 105 additions & 0 deletions src/CustomLossFunction.jl
@@ -568,6 +568,111 @@ function sliced_invariant_statistical_loss_optimized(nn_model, loader, hparams)
    return losses
end

function sliced_invariant_statistical_loss_optimized_3(nn_model, loader, hparams)
    @inline function compute_for_ω(ω, nn, data, hparams)
        aₖ = zeros(Float32, hparams.K + 1)

        # Generate all random numbers in one go
        x_batch = rand(hparams.noise_model, hparams.samples * hparams.K)

        # Process batch through nn_model
        yₖ_batch = nn(Float32.(x_batch))

        s = Matrix(ω' * yₖ_batch)

        # Pre-compute column indices for slicing
        start_cols = hparams.K * (1:(hparams.samples - 1))
        end_cols = hparams.K * (2:(hparams.samples)) .- 1

        # Create slices of 's' for all 'aₖ_slice'
        aₖ_slices = [
            s[:, start_col:(end_col - 1)] for
            (start_col, end_col) in zip(start_cols, end_cols)
        ]

        # Compute the dot products for all iterations at once
        ω_data_dot_products = [dot(ω, data[:, i]) for i in 2:(hparams.samples)]

        # Apply 'generate_aₖ' for each pair and sum the results
        aₖ = sum([
            generate_aₖ(aₖ_slice, ω_data_dot_product) for
            (aₖ_slice, ω_data_dot_product) in zip(aₖ_slices, ω_data_dot_products)
        ])
        return scalar_diff(aₖ ./ sum(aₖ))
    end

    @assert loader.batchsize == hparams.samples
    @assert length(loader) == hparams.epochs
    losses = Vector{Float32}()
    optim = Flux.setup(Flux.Adam(hparams.η), nn_model)

    @showprogress for data in loader
        Ω = ThreadsX.map(_ -> sample_random_direction(size(data)[1]), 1:(hparams.m))
        total = 0.0f0
        loss, grads = Flux.withgradient(nn_model) do nn
            # Average compute_for_ω over all ω in Ω (serial comprehension;
            # a ThreadsX variant is kept below, commented out)
            total = sum([compute_for_ω(ω, nn, data, hparams) for ω in Ω]) / hparams.m
            #for ω in Ω
            #    total += compute_for_ω(ω, nn, data, hparams)
            #end
            #total = ThreadsX.sum(ω -> compute_for_ω(ω, nn, data, hparams), Ω)
            total
        end
        Flux.update!(optim, nn_model, grads[1])
        push!(losses, loss)
    end
    return losses
end

function sliced_invariant_statistical_loss_optimized_2(nn_model, loader, hparams)
    @assert loader.batchsize == hparams.samples
    @assert length(loader) == hparams.epochs
    losses = Vector{Float32}()
    optim = Flux.setup(Flux.Adam(hparams.η), nn_model)

    @showprogress for data in loader
        Ω = ThreadsX.map(_ -> sample_random_direction(size(data)[1]), 1:(hparams.m))
        loss, grads = Flux.withgradient(nn_model) do nn
            total = 0.0f0
            for ω in Ω
                aₖ = zeros(Float32, hparams.K + 1) # Reset aₖ for each new ω

                # Generate all random numbers in one go
                x_batch = rand(hparams.noise_model, hparams.samples * hparams.K)

                # Process batch through nn_model
                yₖ_batch = nn(Float32.(x_batch))

                s = Matrix(ω' * yₖ_batch)

                # Pre-compute column indices for slicing
                start_cols = hparams.K * (1:(hparams.samples - 1))
                end_cols = hparams.K * (2:(hparams.samples)) .- 1

                # Create slices of 's' for all 'aₖ_slice'
                aₖ_slices = [
                    s[:, start_col:(end_col - 1)] for
                    (start_col, end_col) in zip(start_cols, end_cols)
                ]

                # Compute the dot products for all iterations at once
                ω_data_dot_products = [dot(ω, data[:, i]) for i in 2:(hparams.samples)]

                # Apply 'generate_aₖ' for each pair and sum the results
                aₖ = sum([
                    generate_aₖ(aₖ_slice, ω_data_dot_product) for
                    (aₖ_slice, ω_data_dot_product) in zip(aₖ_slices, ω_data_dot_products)
                ])
                total += scalar_diff(aₖ ./ sum(aₖ))
            end
            total / hparams.m
        end
        Flux.update!(optim, nn_model, grads[1])
        push!(losses, loss)
    end
    return losses
end

function sliced_ortonormal_invariant_statistical_loss(
    nn_model, loader, hparams::HyperParamsSlicedISL
)
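For reference, a minimal call sketch for the two new variants, mirroring the MNIST example above (model, noise_model, and train_x as set up there); note that both functions assert the loader's batchsize equals hparams.samples:

hparams = HyperParamsSlicedISL(;
    K=10, samples=100, epochs=1, η=1e-2, noise_model=noise_model, m=100
)
loader = DataLoader(train_x; batchsize=hparams.samples, shuffle=false, partial=false)

# Each variant returns a Vector{Float32} with one loss value per batch.
losses = sliced_invariant_statistical_loss_optimized_3(model, loader, hparams)
# The _2 variant shares the same signature:
# losses = sliced_invariant_statistical_loss_optimized_2(model, loader, hparams)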
4 changes: 3 additions & 1 deletion src/ISL.jl
@@ -41,5 +41,7 @@ export _sigmoid,
    sliced_invariant_statistical_loss_multithreaded_2,
    sliced_invariant_statistical_loss_selected_directions,
    sliced_ortonormal_invariant_statistical_loss,
-    sliced_invariant_statistical_loss_optimized
+    sliced_invariant_statistical_loss_optimized,
+    sliced_invariant_statistical_loss_optimized_2,
+    sliced_invariant_statistical_loss_optimized_3
end
