diff --git a/scaling.jl b/scaling.jl index 4e1328c..2d1e9a2 100644 --- a/scaling.jl +++ b/scaling.jl @@ -30,9 +30,9 @@ config = ARGS[8] # For scaling tests, use 4 modes, training use 25% modes -modesx = 8 # max(dimx÷32, 4) -modesy = 8 # max(dimy÷32, 4) -modesz = 8 # max(dimz÷32, 4) +modesx = 4 # max(dimx÷32, 4) +modesy = 4 # max(dimy÷32, 4) +modesz = 4 # max(dimz÷32, 4) modest = 4 # max(dimt÷32, 4) (gpus > 64) && (modesy = modesy * 2) diff --git a/src/models/DFNO_3D/model.jl b/src/models/DFNO_3D/model.jl index 7e190d0..030b8f0 100644 --- a/src/models/DFNO_3D/model.jl +++ b/src/models/DFNO_3D/model.jl @@ -82,14 +82,22 @@ mutable struct Model fourier_z = ParDFT(Complex{T}, config.nz) fourier_t = ParDFT(T, config.nt) - # Build restrictions to low-frequency modes - restrict_x = ParRestriction(Complex{T}, Range(fourier_x), [1:mx, config.nx-mx+1:config.nx]) - restrict_y = ParRestriction(Complex{T}, Range(fourier_y), [1:my, config.ny-my+1:config.ny]) - restrict_z = ParRestriction(Complex{T}, Range(fourier_z), [1:mz, config.nz-mz+1:config.nz]) - restrict_t = ParRestriction(Complex{T}, Range(fourier_t), [1:mt]) - - input_shape = (config.nc_lift, config.mt*config.mx, config.my*config.mz) - weight_shape = (config.nc_lift, config.nc_lift, config.mt*config.mx, config.my*config.mz) + restrict_x = ParRestriction(Complex{T}, Range(fourier_x), [1:config.mx, config.nx-config.mx+1:config.nx]) + restrict_y = ParRestriction(Complex{T}, Range(fourier_y), [1:config.my, config.ny-config.my+1:config.ny]) + restrict_z = ParRestriction(Complex{T}, Range(fourier_z), [1:config.mz, config.nz-config.mz+1:config.nz]) + restrict_t = ParRestriction(Complex{T}, Range(fourier_t), [1:config.mt]) + + input_shape = (config.nc_lift, config.mt*(2*config.mx), (2*config.my)*(2*config.mz)) + weight_shape = (config.nc_lift, config.nc_lift, config.mt*(2*config.mx), (2*config.my)*(2*config.mz)) + + # # Build restrictions to low-frequency modes + # restrict_x = ParRestriction(Complex{T}, Range(fourier_x), [1:mx, config.nx-mx+1:config.nx]) + # restrict_y = ParRestriction(Complex{T}, Range(fourier_y), [1:my, config.ny-my+1:config.ny]) + # restrict_z = ParRestriction(Complex{T}, Range(fourier_z), [1:mz, config.nz-mz+1:config.nz]) + # restrict_t = ParRestriction(Complex{T}, Range(fourier_t), [1:mt]) + + # input_shape = (config.nc_lift, config.mt*config.mx, config.my*config.mz) + # weight_shape = (config.nc_lift, config.nc_lift, config.mt*config.mx, config.my*config.mz) input_order = (1, 2, 3) weight_order = (1, 4, 2, 3)