From acbc932efca036aedcf4b6be3dfc0528d345eafe Mon Sep 17 00:00:00 2001 From: turquoisedragon2926 Date: Mon, 15 Apr 2024 09:22:59 -0400 Subject: [PATCH] test --- src/models/DFNO_3D/forward.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/models/DFNO_3D/forward.jl b/src/models/DFNO_3D/forward.jl index e8d19c3..294d332 100644 --- a/src/models/DFNO_3D/forward.jl +++ b/src/models/DFNO_3D/forward.jl @@ -43,7 +43,8 @@ function forward(model::Model, θ, x::Any) x2 = (model.convs[i](θ) * x) + model.sconv_biases[i](θ) x = vec(x1) + vec(x2) - x = reshape(x, (model.config.nc_lift, :)) + x = reshape(x, (model.config.nc_lift, model.config.nt * model.config.nx ÷ model.config.partition[1], model.config.ny * model.config.nz ÷ model.config.partition[2], :)) + # x = reshape(x, (model.config.nc_lift, :)) N = ndims(x) ϵ = 1f-5 @@ -72,6 +73,7 @@ function forward(model::Model, θ, x::Any) x = relu.(x) end end + x = reshape(x, (model.config.nc_lift, :)) x = (model.projects[1](θ) * x) + model.biases[2](θ) x = relu.(x)