From 6b171f81a99c7ed262292c06775e4f5d5b3185a0 Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Fri, 12 May 2023 15:48:00 +0900 Subject: [PATCH 1/2] Hotfix for NormalisedApproximator --- src/approximators/normalized_approximators.jl | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/src/approximators/normalized_approximators.jl b/src/approximators/normalized_approximators.jl index fe5900e..1ffaa2e 100644 --- a/src/approximators/normalized_approximators.jl +++ b/src/approximators/normalized_approximators.jl @@ -28,34 +28,37 @@ function MaxAbsNormalisedApproximator( dataset::DecisionMakingDataset, ) (; conditions, decisions, costs) = dataset - condition_max_abs = maximum(abs.(hcat(conditions...)), dims=length(size(conditions))) - decision_max_abs = maximum(abs.(hcat(decisions...)), dims=length(size(decisions))) - cost_max_abs = maximum(abs.(hcat(costs...)), dims=length(size(costs))) + c = hcat(conditions...) + d = hcat(decisions...) + J = hcat(costs...) + condition_max_abs = maximum(abs.(c), dims=length(size(c))) + decision_max_abs = maximum(abs.(d), dims=length(size(d))) + cost_max_abs = maximum(abs.(J), dims=length(size(J))) MaxAbsNormalisedApproximator(network, condition_max_abs, decision_max_abs, cost_max_abs) end function (nn::NormalisedApproximator)(x, u) (; network, condition_max_abs, decision_max_abs, cost_max_abs) = nn - x = normalise(nn, x, :condition) - u = normalise(nn, u, :decision) - f = network(x, u) - f = unnormalise(nn, f, :cost) - return f + x_new = normalise(nn, x, :condition) + u_new = normalise(nn, u, :decision) + f = network(x_new, u_new) + f_new = unnormalise(nn, f, :cost) + return f_new end function normalise(nn::MaxAbsNormalisedApproximator, z, which::Symbol) @assert which in (:condition, :decision, :cost) factor = getproperty(nn, Symbol(String(which) * "_max_abs")) - z = factor != nothing ? z ./ factor : z - return z + z_new = factor != nothing ? z ./ factor : z + return z_new end function unnormalise(nn::MaxAbsNormalisedApproximator, z, which::Symbol) @assert which in (:condition, :decision, :cost) factor = getproperty(nn, Symbol(String(which) * "_max_abs")) - z = factor != nothing ? z .* factor : z - return z + z_new = factor != nothing ? z .* factor : z + return z_new end From 925233976f0c9e17392ed7814f0a55ac4f9e7672 Mon Sep 17 00:00:00 2001 From: JinraeKim Date: Fri, 12 May 2023 15:49:21 +0900 Subject: [PATCH 2/2] wip --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 16f99a9..a878f31 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ParametrisedConvexApproximators" uuid = "668502ff-1e8f-42bf-95c7-24f1e819f537" authors = ["JinraeKim and contributors"] -version = "0.2.1" +version = "0.2.2" [deps] Convex = "f65535da-76fb-5f13-bab9-19810c17039a"