From cd03019e51978d396beb5dd3cd0c241fc713e654 Mon Sep 17 00:00:00 2001 From: turquoisedragon2926 Date: Mon, 15 Apr 2024 09:12:02 -0400 Subject: [PATCH] test --- scaling.jl | 6 +++--- src/models/DFNO_3D/model.jl | 20 +++++++------------- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/scaling.jl b/scaling.jl index 2d1e9a2..4e1328c 100644 --- a/scaling.jl +++ b/scaling.jl @@ -30,9 +30,9 @@ config = ARGS[8] # For scaling tests, use 4 modes, training use 25% modes -modesx = 4 # max(dimx÷32, 4) -modesy = 4 # max(dimy÷32, 4) -modesz = 4 # max(dimz÷32, 4) +modesx = 8 # max(dimx÷32, 4) +modesy = 8 # max(dimy÷32, 4) +modesz = 8 # max(dimz÷32, 4) modest = 4 # max(dimt÷32, 4) (gpus > 64) && (modesy = modesy * 2) diff --git a/src/models/DFNO_3D/model.jl b/src/models/DFNO_3D/model.jl index 1080954..7e190d0 100644 --- a/src/models/DFNO_3D/model.jl +++ b/src/models/DFNO_3D/model.jl @@ -83,19 +83,13 @@ mutable struct Model fourier_t = ParDFT(T, config.nt) # Build restrictions to low-frequency modes - restrict_x = ParRestriction(Complex{T}, Range(fourier_x), [1:config.mx, config.nx-config.mx+1:config.nx]) - restrict_y = ParRestriction(Complex{T}, Range(fourier_y), [1:config.my, config.ny-config.my+1:config.ny]) - restrict_z = ParRestriction(Complex{T}, Range(fourier_z), [1:config.mz, config.nz-config.mz+1:config.nz]) - restrict_t = ParRestriction(Complex{T}, Range(fourier_t), [1:config.mt]) - # restrict_x = ParRestriction(Complex{T}, Range(fourier_x), unique_range([1:mx, config.nx-mx+1:config.nx])) - # restrict_y = ParRestriction(Complex{T}, Range(fourier_y), unique_range([1:my, config.ny-my+1:config.ny])) - # restrict_z = ParRestriction(Complex{T}, Range(fourier_z), unique_range([1:mz, config.nz-mz+1:config.nz])) - # restrict_t = ParRestriction(Complex{T}, Range(fourier_t), unique_range([1:mt])) - - input_shape = (config.nc_lift, config.mt*(2*config.mx), (2*config.my)*(2*config.mz)) - weight_shape = (config.nc_lift, config.nc_lift, config.mt*(2*config.mx), (2*config.my)*(2*config.mz)) - # input_shape = (config.nc_lift, config.mt*config.mx, config.my*config.mz) - # weight_shape = (config.nc_lift, config.nc_lift, config.mt*config.mx, config.my*config.mz) + restrict_x = ParRestriction(Complex{T}, Range(fourier_x), [1:mx, config.nx-mx+1:config.nx]) + restrict_y = ParRestriction(Complex{T}, Range(fourier_y), [1:my, config.ny-my+1:config.ny]) + restrict_z = ParRestriction(Complex{T}, Range(fourier_z), [1:mz, config.nz-mz+1:config.nz]) + restrict_t = ParRestriction(Complex{T}, Range(fourier_t), [1:mt]) + + input_shape = (config.nc_lift, config.mt*config.mx, config.my*config.mz) + weight_shape = (config.nc_lift, config.nc_lift, config.mt*config.mx, config.my*config.mz) input_order = (1, 2, 3) weight_order = (1, 4, 2, 3)