Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
Richard2926 committed Apr 15, 2024
1 parent 08d87ad commit cd03019
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 16 deletions.
6 changes: 3 additions & 3 deletions scaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ config = ARGS[8]

# For scaling tests, use 4 modes, training use 25% modes

modesx = 4 # max(dimx÷32, 4)
modesy = 4 # max(dimy÷32, 4)
modesz = 4 # max(dimz÷32, 4)
modesx = 8 # max(dimx÷32, 4)
modesy = 8 # max(dimy÷32, 4)
modesz = 8 # max(dimz÷32, 4)
modest = 4 # max(dimt÷32, 4)

(gpus > 64) && (modesy = modesy * 2)
Expand Down
20 changes: 7 additions & 13 deletions src/models/DFNO_3D/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,13 @@ mutable struct Model
fourier_t = ParDFT(T, config.nt)

# Build restrictions to low-frequency modes
restrict_x = ParRestriction(Complex{T}, Range(fourier_x), [1:config.mx, config.nx-config.mx+1:config.nx])
restrict_y = ParRestriction(Complex{T}, Range(fourier_y), [1:config.my, config.ny-config.my+1:config.ny])
restrict_z = ParRestriction(Complex{T}, Range(fourier_z), [1:config.mz, config.nz-config.mz+1:config.nz])
restrict_t = ParRestriction(Complex{T}, Range(fourier_t), [1:config.mt])
# restrict_x = ParRestriction(Complex{T}, Range(fourier_x), unique_range([1:mx, config.nx-mx+1:config.nx]))
# restrict_y = ParRestriction(Complex{T}, Range(fourier_y), unique_range([1:my, config.ny-my+1:config.ny]))
# restrict_z = ParRestriction(Complex{T}, Range(fourier_z), unique_range([1:mz, config.nz-mz+1:config.nz]))
# restrict_t = ParRestriction(Complex{T}, Range(fourier_t), unique_range([1:mt]))

input_shape = (config.nc_lift, config.mt*(2*config.mx), (2*config.my)*(2*config.mz))
weight_shape = (config.nc_lift, config.nc_lift, config.mt*(2*config.mx), (2*config.my)*(2*config.mz))
# input_shape = (config.nc_lift, config.mt*config.mx, config.my*config.mz)
# weight_shape = (config.nc_lift, config.nc_lift, config.mt*config.mx, config.my*config.mz)
restrict_x = ParRestriction(Complex{T}, Range(fourier_x), [1:mx, config.nx-mx+1:config.nx])
restrict_y = ParRestriction(Complex{T}, Range(fourier_y), [1:my, config.ny-my+1:config.ny])
restrict_z = ParRestriction(Complex{T}, Range(fourier_z), [1:mz, config.nz-mz+1:config.nz])
restrict_t = ParRestriction(Complex{T}, Range(fourier_t), [1:mt])

input_shape = (config.nc_lift, config.mt*config.mx, config.my*config.mz)
weight_shape = (config.nc_lift, config.nc_lift, config.mt*config.mx, config.my*config.mz)

input_order = (1, 2, 3)
weight_order = (1, 4, 2, 3)
Expand Down

0 comments on commit cd03019

Please sign in to comment.