From 610e03f3fbd1605bf4a513ac5fced2b618fe7165 Mon Sep 17 00:00:00 2001
From: josemanuel22
Date: Mon, 12 Feb 2024 16:55:29 +0100
Subject: [PATCH] add CondTimeGenModel

---
 .../time_series_predictions/benchmark_ts.jl   | 15 ++--
 src/CustomLossFunction.jl                     | 79 ++++++++++++++-----
 src/ISL.jl                                    |  3 +-
 3 files changed, 70 insertions(+), 27 deletions(-)

diff --git a/examples/time_series_predictions/benchmark_ts.jl b/examples/time_series_predictions/benchmark_ts.jl
index 494ff37..fc926ae 100644
--- a/examples/time_series_predictions/benchmark_ts.jl
+++ b/examples/time_series_predictions/benchmark_ts.jl
@@ -767,7 +767,7 @@ end
 
     df1 = CSV.File(csv1; delim=',', header=true, decimal='.')
 
-    hparams = HyperParamsTS(; seed=1234, η=1e-2, epochs=2000, window_size=2000, K=40)
+    hparams = HyperParamsTS(; seed=1234, η=1e-2, epochs=2000, window_size=2000, K=20)
 
     rec = Chain(RNN(1 => 3, relu), LayerNorm(3))
     gen = Chain(Dense(4, 10, relu), Dropout(0.1), Dense(10, 1, identity))
@@ -844,11 +844,12 @@ end
     losses = []
     mses = []
     maes = []
+
+    model = CondTimeGenModel(rec, gen, nothing, Normal(0.0f0, 1.0f0))
     @showprogress for _ in 1:10
-        loss = ts_invariant_statistical_loss(
-            rec, gen, loaderXtrain, loaderYtrain, hparams, loaderXtest; cond=0.5
-        )
+        loss = ts_invariant_statistical_loss(model, loaderXtrain, loaderYtrain, hparams)
         append!(losses, loss)
+    end
     mse = 0.0
     mae = 0.0
     for ts in (1:(length(names(df)) - 1))
@@ -1097,7 +1098,7 @@ end
 end
 
 @test_experiments "ETDataset multivariated" begin
-    url = "https://github.com/zhouhaoyi/ETDataset/blob/main/ETT-small/ETTh1.csv?raw=true"
+    url = "https://github.com/zhouhaoyi/ETDataset/blob/main/ETT-small/ETTh2.csv?raw=true"
 
     # Download the CSV file
     csv1 = HTTP.download(url)
@@ -1119,7 +1120,7 @@ end
     dataY = [matrix[i, :] for i in 2:size(matrix, 1)]
 
     # Model hyperparameters and architecture
-    hparams = HyperParamsTS(; seed=1234, η=1e-2, epochs=2000, window_size=2000, K=40)
+    hparams = HyperParamsTS(; seed=1234, η=5e-4, epochs=2000, window_size=2000, K=25)
 
     rec = Chain(RNN(7 => 3, relu), LayerNorm(3))
     gen = Chain(Dense(4, 10, relu), Dropout(0.05), Dense(10, 7, identity))
@@ -1148,7 +1149,7 @@ end
     mse = 0.0
     mae = 0.0
     count = 0
-    τ = 720
+    τ = 336
     for ts in (1:length(collect(loaderXtrain)[1][1]))
         #τ = 96
         s = 0
diff --git a/src/CustomLossFunction.jl b/src/CustomLossFunction.jl
index fa73265..81aa3b2 100644
--- a/src/CustomLossFunction.jl
+++ b/src/CustomLossFunction.jl
@@ -22,7 +22,7 @@ function _leaky_relu(ŷ::Matrix{T}, y::T) where {T<:AbstractFloat}
 end;
 
 """
-`ψₘ(y::T, m::Int64) where {T<:AbstractFloat}``
+`ψₘ(y::T, m::Int64) where {T<:AbstractFloat}`
 
 Calculate the bump function centered at `m`, implemented as a Gaussian function.
 
@@ -193,7 +193,7 @@ end;
 end;
 
 """
-    `invariant_statistical_loss(model, data, hparams)``
+    `invariant_statistical_loss(model, data, hparams)`
 
-Custom loss function for the model. model is a Flux neuronal network model, data is a
+Custom loss function for the model. model is a Flux neural network model, data is a
 loader Flux object and hparams is a HyperParams object.
@@ -299,7 +299,7 @@ function get_better_K(nn_model, data, min_K, hparams)
 end;
 
 """
-    `auto_invariant_statistical_loss(model, data, hparams)``
+    `auto_invariant_statistical_loss(model, data, hparams)`
 
 Custom loss function for the model.
 
@@ -457,6 +457,46 @@ function MAE(y::Vector{<:Real}, ŷ::Vector{<:Real})::Real
     return mae
 end
 
+# Define a combined model that couples a recurrent network with a generative one:
+# the RNN summarises the observed past into a state, and the generator maps
+# (noise, state) to a sample of the next value.
+mutable struct CondTimeGenModel
+    seq_model::Chain
+    gen_model::Chain
+    state
+    noise
+end
+
+# Forward pass: update the hidden state and draw one sample per batch column.
+function (m::CondTimeGenModel)(x)
+    m.state = m.seq_model(x)
+    batch_size = size(x, 2)
+    noise = Float32.(rand(m.noise, (1, batch_size)))
+    gen_input = vcat(noise, m.state)
+    return m.gen_model(gen_input)
+end
+
+# Draw K fictitious samples per batch column, conditioned on a precomputed state s.
+function (m::CondTimeGenModel)(s, x, K)
+    batch_size = size(x, 2)
+    noise = Float32.(rand(m.noise, K, batch_size))
+    # Repeat each state column K times so every noise draw gets its own copy.
+    s_repeated = repeat(s; inner=(1, K))
+    # reshape is column-major, so the K draws for column j stay contiguous.
+    noise_reshaped = reshape(noise, (:, batch_size * K))
+    gen_input = vcat(noise_reshaped, s_repeated)
+    # One batched call; columns ((j - 1) * K + 1):(j * K) are the samples for column j.
+    return m.gen_model(gen_input)
+end
+
+# Run the recurrent model over a window, then sample K fictitious values per step.
+function generated_fictitious(m::CondTimeGenModel, x, K)
+    m.state = m.seq_model(x)
+    return m(m.state, x, K)
+end
+
+Flux.@functor CondTimeGenModel
+
 """
-    ts_invariant_statistical_loss(rec, gen, Xₜ, Xₜ₊₁, hparams)
+    ts_invariant_statistical_loss(model::CondTimeGenModel, Xₜ, Xₜ₊₁, hparams)
 
@@ -482,25 +522,26 @@ This function train a model for time series data with statistical invariance loss
 
 The function iterates through the provided time series data (`Xₜ` and `Xₜ₊₁`) in batches, with a sliding window of size `window_size`.
 """
-function ts_invariant_statistical_loss(rec, gen, Xₜ, Xₜ₊₁, hparams)
+function ts_invariant_statistical_loss(model::CondTimeGenModel, Xₜ, Xₜ₊₁, hparams)
     losses = []
-    optim_rec = Flux.setup(Flux.Adam(hparams.η), rec)
-    optim_gen = Flux.setup(Flux.Adam(hparams.η), gen)
+    optim = Flux.setup(Flux.Adam(hparams.η), model)
     for (batch_Xₜ, batch_Xₜ₊₁) in zip(Xₜ, Xₜ₊₁)
-        Flux.reset!(rec)
+        Flux.reset!(model)
         for j in (0:(hparams.window_size):(length(batch_Xₜ) - hparams.window_size))
-            loss, grads = Flux.withgradient(rec, gen) do rec, gen
-                aₖ = zeros(hparams.K + 1)
-                s = rec(batch_Xₜ[(j + 1):(j + hparams.window_size)]')
-                for i in 1:(hparams.window_size)
-                    xₖ = rand(hparams.noise_model, hparams.K)
-                    yₖ = hcat([gen(vcat(x, s[:, i])) for x in xₖ]...)
-                    aₖ += generate_aₖ(yₖ, batch_Xₜ₊₁[j + i])
-                end
+            loss, grads = Flux.withgradient(model) do model
+                yₖ = generated_fictitious(
+                    model, batch_Xₜ[(j + 1):(j + hparams.window_size)]', hparams.K
+                )
+                # Columns ((i - 1) * K + 1):(i * K) of yₖ hold the K samples for step i.
+                aₖ = sum([
+                    generate_aₖ(
+                        yₖ[:, ((i - 1) * hparams.K + 1):(i * hparams.K)],
+                        batch_Xₜ₊₁[j + i],
+                    ) for i in 1:(hparams.window_size)
+                ])
                 scalar_diff(aₖ ./ sum(aₖ))
             end
-            Flux.update!(optim_rec, rec, grads[1])
-            Flux.update!(optim_gen, gen, grads[2])
+            Flux.update!(optim, model, grads[1])
             push!(losses, loss)
         end
     end
diff --git a/src/ISL.jl b/src/ISL.jl
index 4affffe..e754a8c 100644
--- a/src/ISL.jl
+++ b/src/ISL.jl
@@ -33,5 +33,6 @@ export _sigmoid,
     HyperParamsTS,
     ts_invariant_statistical_loss_one_step_prediction,
     ts_invariant_statistical_loss,
-    ts_invariant_statistical_loss_multivariate
+    ts_invariant_statistical_loss_multivariate,
+    CondTimeGenModel
 end
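
A minimal end-to-end sketch of the new API follows. The toy sine series and
DataLoader settings are assumptions for illustration only; `rec`, `gen`, and
`hparams` mirror the univariate benchmark above.

using Flux, Distributions
using ISL

# Recurrent summariser: 1-d series -> 3-d hidden state.
rec = Chain(RNN(1 => 3, relu), LayerNorm(3))
# Generator: 1 noise dim + 3 state dims -> next value.
gen = Chain(Dense(4, 10, relu), Dropout(0.1), Dense(10, 1, identity))

# `state` starts as nothing; the first forward pass fills it in.
model = CondTimeGenModel(rec, gen, nothing, Normal(0.0f0, 1.0f0))

hparams = HyperParamsTS(; seed=1234, η=1e-2, epochs=2000, window_size=2000, K=20)

# Toy data (assumption): a sine wave and its one-step-ahead shift, batched so
# each batch is longer than the sliding window used by the loss.
series = Float32.(sin.(0.01f0 .* (1:20000)))
loaderXtrain = Flux.DataLoader(series[1:(end - 1)]; batchsize=10000, shuffle=false)
loaderYtrain = Flux.DataLoader(series[2:end]; batchsize=10000, shuffle=false)

# Returns the per-window losses recorded during training.
losses = ts_invariant_statistical_loss(model, loaderXtrain, loaderYtrain, hparams)

One call now updates recurrent and generative parameters through a single
optimiser state, which is what the switch from the (rec, gen) pair to the
functor'd CondTimeGenModel buys.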