Skip to content

Commit

Permalink
Merge pull request #263 from slimgroup/probase
Browse files Browse the repository at this point in the history
Performance improvements
  • Loading branch information
mloubout authored Jul 24, 2024
2 parents adcb13a + 926cdda commit 9458f55
Show file tree
Hide file tree
Showing 25 changed files with 343 additions and 314 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci-judi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jobs:

- name: Set julia python
run: |
echo "PYTHON=$(which python3)" >> $GITHUB_ENV
PYTHON=$(which python3) julia -e 'using Pkg;Pkg.add("PyCall");Pkg.build("PyCall")'
- name: Build JUDI
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/ci-op.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ jobs:

- name: Set julia python
run: |
echo "PYTHON=$(which python3)" >> $GITHUB_ENV
PYTHON=$(which python3) julia -e 'using Pkg;Pkg.add("PyCall");Pkg.build("PyCall")'
- name: Build JUDI
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "JUDI"
uuid = "f3b833dc-6b2e-5b9c-b940-873ed6319979"
authors = ["Philipp Witte, Mathias Louboutin"]
version = "3.4.4"
version = "3.4.5"

This comment has been minimized.

Copy link
@mloubout

mloubout Jul 24, 2024

Author Member

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down Expand Up @@ -51,4 +51,4 @@ test = ["Aqua", "JLD2", "Printf", "Test", "TimerOutputs", "Flux"]
[weakdeps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
25 changes: 10 additions & 15 deletions deps/build.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,32 @@ struct DevitoException <: Exception
msg::String
end

python = PyCall.pyprogramname

try
pk = pyimport("pkg_resources")
pk = try
pyimport("pkg_resources")
catch e
Cmd([python, "-m", "pip", "install", "--user", "setuptools"])
run(cmd)
pk = pyimport("pkg_resources")
run(PyCall.python_cmd(`-m pip install --user setuptools`))
pyimport("pkg_resources")
end

################## Devito ##################
# pip command
cmd = Cmd([python, "-m", "pip", "install", "-U", "--user", "devito[extras,tests]>=4.4"])
dvver = "4.8.10"
cmd = PyCall.python_cmd(`-m pip install --user devito\[extras,tests\]\>\=$(dvver)`)

try
dv_ver = split(pk.get_distribution("devito").version, "+")[1]
if cmp(dv_ver, "4.8.7") < 0
@info "Devito version too low, updating to >=4.8.7"
dv_ver = VersionNumber(split(pk.get_distribution("devito").version, "+")[1])
if dv_ver < VersionNumber(dvver)
@info "Devito version too low, updating to >=$(dvver)"
run(cmd)
end
catch e
@info "Devito not installed, installing with PyCall python"
run(cmd)
end


################## Matplotlib ##################
# pip command
cmd = Cmd([python, "-m", "pip", "install", "--user", "matplotlib"])
try
mpl = pyimport("matplotlib")
catch e
run(cmd)
run(PyCall.python_cmd(`-m pip install --user matplotlib`))
end
2 changes: 0 additions & 2 deletions docs/src/helper.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,6 @@ remove_out_of_bounds_receivers
```@docs
devito_model
setup_grid
pad_sizes
pad_array
remove_padding
convertToCell
process_input_data
Expand Down
6 changes: 4 additions & 2 deletions examples/scripts/fwi_example_2D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#

using Statistics, Random, LinearAlgebra
using JUDI, SlimOptim, HDF5, SegyIO, PyPlot
using JUDI, HDF5, SegyIO, SlimOptim, SlimPlotting

# Load starting model
n,d,o,m0 = read(h5open("$(JUDI.JUDI_DATA)/overthrust_model.h5","r"), "n", "d", "o", "m0")
Expand Down Expand Up @@ -66,4 +66,6 @@ for j=1:niterations
model0.m .= proj(model0.m .+ step .* p)
end

figure(); imshow(sqrt.(1f0./adjoint(model0.m))); title("FWI with SGD")
figure()
plot_velocity(model0.m'.^(-.5))
title("FWI with SGD")
53 changes: 14 additions & 39 deletions examples/scripts/modeling_basic_2D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
#' This example is converted to a markdown file for the documentation.

#' # Import JUDI, Linear algebra utilities and Plotting
using JUDI, PyPlot, LinearAlgebra
using JUDI, LinearAlgebra, SlimPlotting

#+ echo = false; results = "hidden"
close("all")
imcmap = "cet_CET_L1"
dcmap = "PuOr"

#' # Create a JUDI model structure
#' In JUDI, a `Model` structure contains the grid information (origin, spacing, number of gridpoints)
Expand Down Expand Up @@ -91,7 +93,7 @@ q = judiVector(srcGeometry, wavelet)
#' condition for the propagation.

# Setup options
opt = Options(subsampling_factor=2, space_order=32)
opt = Options(subsampling_factor=2, space_order=16, free_surface=false)

#' Linear Operators
#' The core idea behind JUDI is to abstract seismic inverse problems in term of linear algebra. In its simplest form, seismic inversion can be formulated as
Expand Down Expand Up @@ -119,10 +121,7 @@ dobs = Pr*F*adjoint(Ps)*q

#' Plot the shot record
fig = figure()
imshow(dobs.data[1], vmin=-1, vmax=1, cmap="PuOr", extent=[xrec[1], xrec[end], timeD/1000, 0], aspect="auto")
xlabel("Receiver position (m)")
ylabel("Time (s)")
title("Synthetic data")
plot_sdata(dobs[1]; new_fig=false, name="Synthetic data", cmap=dcmap)
display(fig)

#' Because we have abstracted the linear algebra, we can solve the adjoint wave-equation as well
Expand Down Expand Up @@ -152,19 +151,13 @@ rtm = adjoint(J)*dD

#' We show the linearized data.
fig = figure()
imshow(dD.data[1], vmin=-1, vmax=1, cmap="PuOr", extent=[xrec[1], xrec[end], timeD/1000, 0], aspect="auto")
xlabel("Receiver position (m)")
ylabel("Time (s)")
title("Linearized data")
plot_sdata(dobs[1]; new_fig=false, name="Linearized data", cmap=dcmap)
display(fig)


#' And the RTM image
fig = figure()
imshow(rtm', vmin=-1e2, vmax=1e2, cmap="Greys", extent=[0, (n[1]-1)*d[1], (n[2]-1)*d[2], 0 ], aspect="auto")
xlabel("Lateral position(m)")
ylabel("Depth (m)")
title("RTM image")
plot_simage(rtm'; new_fig=false, name="RTM image", cmap=imcmap)
display(fig)

#' ## Inversion utility functions
Expand All @@ -185,10 +178,7 @@ f, g = fwi_objective(model0, q, dobs; options=opt)

#' Plot gradient
fig = figure()
imshow(g', vmin=-1e2, vmax=1e2, cmap="Greys", extent=[0, (n[1]-1)*d[1], (n[2]-1)*d[2], 0 ], aspect="auto")
xlabel("Lateral position(m)")
ylabel("Depth (m)")
title("FWI gradient")
plot_simage(g'; new_fig=false, name="FWI gradient", cmap=imcmap)
display(fig)


Expand All @@ -199,17 +189,11 @@ fjn, gjn = lsrtm_objective(model0, q, dobs, dm; nlind=true, options=opt)

#' Plot gradients
fig = figure()
imshow(gj', vmin=-1, vmax=1, cmap="Greys", extent=[0, (n[1]-1)*d[1], (n[2]-1)*d[2], 0 ], aspect="auto")
xlabel("Lateral position(m)")
ylabel("Depth (m)")
title("LSRTM gradient")
plot_simage(gj'; new_fig=false, name="LSRTM gradient", cmap=imcmap, cbar=true)
display(fig)

fig = figure()
imshow(gjn', vmin=-1, vmax=1, cmap="Greys", extent=[0, (n[1]-1)*d[1], (n[2]-1)*d[2], 0 ], aspect="auto")
xlabel("Lateral position(m)")
ylabel("Depth (m)")
title("LSRTM gradient with background data substracted")
plot_simage(gjn'; new_fig=false, name="LSRTM gradient with background data substracted", cmap=imcmap, cbar=true)
display(fig)

#' By extension, lsrtm_objective is the same as fwi_objective when `dm` is zero
Expand All @@ -218,13 +202,10 @@ display(fig)
#' OMP_NUM_THREADS=1 (no parallelism) produces the exact (difference == 0) same result
#' gjn2 == g
fjn2, gjn2 = lsrtm_objective(model0, q, dobs, 0f0.*dm; nlind=true, options=opt)
fig = figure()

#' Plot gradient
imshow(gjn2', vmin=-1e2, vmax=1e2, cmap="Greys", extent=[0, (n[1]-1)*d[1], (n[2]-1)*d[2], 0 ], aspect="auto")
xlabel("Lateral position(m)")
ylabel("Depth (m)")
title("LSRTM gradient with zero perturbation")
fig = figure()
plot_simage(gjn2'; new_fig=false, name="LSRTM gradient with zero perturbation", cmap=imcmap)
display(fig)


Expand All @@ -236,15 +217,9 @@ f, gmf = twri_objective(model0, q, dobs, nothing; options=Options(frequencies=[[

#' Plot gradients
fig = figure()
imshow(gm', vmin=-1, vmax=1, cmap="Greys", extent=[0, (n[1]-1)*d[1], (n[2]-1)*d[2], 0 ], aspect="auto")
xlabel("Lateral position(m)")
ylabel("Depth (m)")
title("TWRI gradient w.r.t m")
plot_simage(gm'; new_fig=false, name="TWRI gradient w.r.t m", cmap=imcmap)
display(fig)

fig = figure()
imshow(gy.data[1], vmin=-1e2, vmax=1e2, cmap="PuOr", extent=[xrec[1], xrec[end], timeD/1000, 0], aspect="auto")
xlabel("Receiver position (m)")
ylabel("Time (s)")
title("TWRI gradient w.r.t y")
plot_sdata(gy[1]; new_fig=false, name="TWRI gradient w.r.t y", cmap=dcmap)
display(fig)
8 changes: 3 additions & 5 deletions src/JUDI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ module JUDI
export JUDIPATH, set_verbosity, ftp_data, get_serial, set_serial, set_parallel
JUDIPATH = dirname(pathof(JUDI))


# Only needed if extension not available (julia < 1.9)
if !isdefined(Base, :get_extension)
using Requires
Expand Down Expand Up @@ -102,10 +101,12 @@ function _worker_pool()
return nothing
end
p = default_worker_pool()
pool = length(p) < 2 ? nothing : p
pool = nworkers(p) < 2 ? nothing : p
return pool
end

# Catch-all method: whatever the argument is, report the number of entries
# returned by `workers()`. `_worker_pool` and `reduce!` call `nworkers` on a pool object.
nworkers(::Any) = length(workers())

_TFuture = Future
_verbose = false
_devices = []
Expand Down Expand Up @@ -178,9 +179,6 @@ function __init__()
copy!(devito, pyimport("devito"))
# Initialize lock at session start
PYLOCK[] = ReentrantLock()

# Prevent autopadding to use external allocator
set_devito_config("autopadding", false)

# Make sure there is no conflict for the cuda init thread with CUDA.jl
if get(ENV, "DEVITO_PLATFORM", "") == "nvidiaX"
Expand Down
27 changes: 20 additions & 7 deletions src/TimeModeling/Modeling/distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@ end
x
end

"""
safe_gc()
Generic GC, compatible with different julia versions of it.
"""
safe_gc() = try Base.GC.gc(); catch; gc() end

"""
local_reduce!(future, other)
Expand Down Expand Up @@ -64,9 +58,28 @@ Adapted from `DistributedOperations.jl` (MIT license). Striped from custom types
with different reduction functions.
"""
function reduce!(futures::Vector{_TFuture})
    # No worker pool available: reduce every future in a single pass
    if isnothing(_worker_pool())
        return reduce_all_workers!(futures)
    end
    nfutures = length(futures)
    # Batch size: the worker count reported for the pool, capped by the number of futures.
    # Reducing in batches keeps finished tasks from sitting on their memory while
    # they wait for the binary tree reduction to reach their index.
    batch = min(nworkers(_worker_pool()), nfutures)
    # The first batch seeds the accumulated result
    acc = reduce_all_workers!(futures[1:batch])
    # Fold each remaining batch into the accumulated result
    for start = (batch + 1):batch:nfutures
        stop = min(nfutures, start + batch - 1)
        single_reduce!(acc, reduce_all_workers!(futures[start:stop]))
    end
    return acc
end


function reduce_all_workers!(futures::Vector{_TFuture})
# Largest power of two not exceeding the number of futures, for the binary reduction
M = length(futures)
L = round(Int,log2(prevpow(2,M)))
L = round(Int, log2(prevpow(2,M)))
m = 2^L
# remainder
R = M - m
Expand Down
2 changes: 2 additions & 0 deletions src/TimeModeling/Modeling/misfit_fg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ function _multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes,
data_precon=nothing, model_precon=LinearAlgebra.I)
GC.gc(true)
devito.clear_cache()

# assert this is for single source LSRTM
@assert source.nsrc == 1 "Multiple sources are used in a single-source fwi_objective"
@assert dObs.nsrc == 1 "Multiple-source data is used in a single-source fwi_objective"
Expand Down Expand Up @@ -63,6 +64,7 @@ function _multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes,

length(options.frequencies) == 0 ? freqs = nothing : freqs = options.frequencies
IT = illum ? (PyArray, PyArray) : (PyObject, PyObject)

@juditime "Python call to J_adjoint" begin
argout = rlock_pycall(ac."J_adjoint", Tuple{Float32, PyArray, IT...}, modelPy,
src_coords, qIn, rec_coords, dObserved, t_sub=options.subsampling_factor,
Expand Down
20 changes: 4 additions & 16 deletions src/TimeModeling/Modeling/propagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ the pool is empty, a standard loop and accumulation is ran. If the pool is a jul
any custom Distributed pool, the loop is distributed via `remotecall` followed by a binary tree remote reduction.
"""
function run_and_reduce(func, pool, nsrc, arg_func::Function; kw=nothing)
# Allocate devices
_set_devices!()
# Run distributed loop
res = Vector{_TFuture}(undef, nsrc)
for i = 1:nsrc
Expand All @@ -44,21 +42,11 @@ function run_and_reduce(func, ::Nothing, nsrc, arg_func::Function; kw=nothing)
kw_loc = isnothing(kw) ? Dict() : kw(i)
next = func(arg_func(i)...; kw_loc...)
end
single_reduce!(out, next)
end
out
end

function _set_devices!()
ndevices = length(_devices)
if ndevices < 2
return
end
asyncmap(enumerate(workers())) do (pi, p)
remotecall_wait(p) do
pyut.set_device_ids(_devices[pi % ndevices + 1])
@juditime "Reducting $(func) for src $(i)" begin
single_reduce!(out, next)
end
end
out
end

_prop_fw(::judiPropagator{T, O}) where {T, O} = true
Expand Down Expand Up @@ -112,7 +100,7 @@ function multi_src_fg!(G, model, q, dobs, dm; options=Options(), ms_func=multi_s
kw_func = i -> Dict(:illum=> illum, Dict(k => kw_i(v, i) for (k, v) in kw)...)
# Distribute source
res = run_and_reduce(ms_func, pool, nsrc, arg_func; kw=kw_func)
f, g = update_illum(res, model, :adjoint_born)
res = update_illum(res, model, :adjoint_born)
f, g = as_vec(res, Val(options.return_array))
G .+= g
return f
Expand Down
2 changes: 1 addition & 1 deletion src/TimeModeling/Modeling/twri_objective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ function _twri_objective(model_full::AbstractModel, source::judiVector, dObs::ju
dtComp = convert(Float32, modelPy."critical_dt")

# Extrapolate input data to computational grid
qIn = time_resample(source.data[1], source.geometry, dtComp)
qIn = time_resample(make_input(source), source.geometry, dtComp)
dObserved = time_resample(make_input(dObs), dObs.geometry, dtComp)

if isnothing(y)
Expand Down
Loading

1 comment on commit 9458f55

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/111687

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v3.4.5 -m "<description of version>" 9458f551e06ac43e580b08e088640916d441c046
git push origin v3.4.5

Please sign in to comment.