Skip to content

Commit

Permalink
add data/model_precon support to lsrtm/fwi objectives
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Mar 22, 2024
1 parent 940ae55 commit 3a9eb54
Show file tree
Hide file tree
Showing 14 changed files with 245 additions and 74 deletions.
36 changes: 33 additions & 3 deletions docs/src/preconditioners.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Seismic Preconditioners

JUDI provides a selected number of preconditioners known to be beneficial to FWI and RTM. We welcome additional preconditionners from the community. Additionnaly, any JOLI operator can be used as a preconditiner in conbination with JUDI operator thanks to the fundamental interface between JUDI and JOLI.
JUDI provides a selected number of preconditioners known to be beneficial to FWI and RTM. We welcome additional preconditioners from the community. Additionally, any JOLI operator can be used as a preconditioner in combination with JUDI operators thanks to the fundamental interface between JUDI and JOLI.

```@contents
Pages = ["preconditioners.md"]
Expand Down Expand Up @@ -74,7 +74,7 @@ m_mute = I'*vec(m)

## Data preconditioners

These preconditioners are design to act on the shot records (data). These preconditioners are indexable by source number so that working with a subset of shot is trivial to implement.
These preconditioners are designed to act on the shot records (data). These preconditioners are indexable by source number so that working with a subset of shots is trivial to implement. Additionally, all [DataPreconditioner](@ref) are compatible with out-of-core JUDI objects such as `judiVector{SeisCon}` so that the preconditioner is only applied to single shot data at propagation time.


### Data topmute
Expand All @@ -100,4 +100,34 @@ A `TimeDifferential{K}` is a linear operator that implements a time derivative (

```@docs
TimeDifferential
```
```

## Inversion wrappers

For large scale and practical cases, the inversion wrappers [fwi_objective](@ref) and [lsrtm_objective](@ref) are used to minimize the number of PDE solves. These wrappers also support the use of preconditioners for better results.

**Usage:**

For fwi, you can use the `data_precon` keyword argument to be applied to the residual (the preconditioner is applied to both the field and synthetic data to ensure better misfit):

```julia
fwi_objective(model, q, dobs; data_precon=precon)
```

where `precon` can be:

- A single [DataPreconditioner](@ref)
- A list/tuple of [DataPreconditioner](@ref)
- A product of [DataPreconditioner](@ref)

Similarly, for LSRTM, you can use the `model_precon` keyword argument to be applied to the perturbation `dm` and the `data_precon` keyword argument to be applied to the residual:

```julia
lsrtm_objective(model, q, dobs, dm; model_precon=dmPrec, data_precon=dPrec)
```

where `dPrec` and `dmPrec` can be:

- A single preconditioner ([DataPreconditioner](@ref) for `data_precon` and [ModelPreconditioner](@ref) for `model_precon`)
- A list/tuple of preconditioners ([DataPreconditioner](@ref) for `data_precon` and [ModelPreconditioner](@ref) for `model_precon`)
- A product of preconditioners ([DataPreconditioner](@ref) for `data_precon` and [ModelPreconditioner](@ref) for `model_precon`)
5 changes: 4 additions & 1 deletion examples/scripts/lsrtm_2D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ lsqr_sol = zeros(Float32, prod(model0.n))
dinv = d_lin[indsrc]
Jinv = J[indsrc]

lsqr!(lsqr_sol, Ml[indsrc]*Jinv*Mr, Ml[indsrc]*dinv; maxiter=niter)
Jp = Ml[indsrc]*Jinv*Mr
dinvp = Ml[indsrc]*dinv

lsqr!(lsqr_sol, Jp, dinvp; maxiter=niter)

# Save final velocity model, function value and history
h5open("lsrtm_marmousi_lsqr_result.h5", "w") do file
Expand Down
2 changes: 1 addition & 1 deletion src/TimeModeling/LinearOperators/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,4 @@ time_space_src(nsrc::Integer, nt) = AbstractSize((:src, :time, :x, :y, :z), (nsr

space_src(nsrc::Integer) = AbstractSize((:src, :x, :y, :z), (nsrc, 0, 0, 0))

time_src(nsrc::Integer, nt) = AbstractSize((:src, :time), (nsrc, nt))
time_src(nsrc::Integer, nt) = AbstractSize((:src, :time), (nsrc, nt))
31 changes: 17 additions & 14 deletions src/TimeModeling/Modeling/misfit_fg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ MTypes = Union{<:AbstractModel, NTuple{N, <:AbstractModel} where N, Vector{<:Abs
dmTypes = Union{dmType, NTuple{N, dmType} where N, Vector{dmType}}


function multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions,
nlind::Bool, lin::Bool, misfit::Function, illum::Bool)
function multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions;
nlind::Bool=false, lin::Bool=false, misfit::Function=mse, illum::Bool=false,
data_precon=nothing, model_precon=LinearAlgebra.I)
GC.gc(true)
devito.clear_cache()
# assert this is for single source LSRTM
Expand All @@ -18,6 +19,9 @@ function multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, d
# Load full geometry for out-of-core geometry containers
d_geometry = Geometry(dObs.geometry)
s_geometry = Geometry(source.geometry)

# If model preconditioner is provided, apply it
dm = isnothing(dm) ? dm : model_precon * dm

# Limit model to area with sources/receivers
if options.limit_m == true
Expand Down Expand Up @@ -46,7 +50,17 @@ function multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, d
rec_coords = setup_grid(d_geometry, size(model)) # shifts rec coordinates by origin
end

mfunc = pyfunction(misfit, Matrix{Float32}, Matrix{Float32})
# Setup misfit function
if !isnothing(data_precon)
# resample
new_t = range(0, step=dtComp, length=size(dObserved, 1))
Pcomp = time_resample(data_precon, new_t)
runtime_misfit = (x, y) -> misfit(Pcomp*x, Pcomp*y)
else
runtime_misfit = misfit
end

mfunc = pyfunction(runtime_misfit, Matrix{Float32}, Matrix{Float32})

length(options.frequencies) == 0 ? freqs = nothing : freqs = options.frequencies
IT = illum ? (PyArray, PyArray) : (PyObject, PyObject)
Expand Down Expand Up @@ -78,17 +92,6 @@ function multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, d
return fval, grad
end


####### Defaults
multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool) =
multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, mse, false)

multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, illum::Bool) =
multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, mse, illum)

multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, phi::Function) =
multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, phi, false)

# Find number of experiments
"""
get_nexp(x)
Expand Down
36 changes: 19 additions & 17 deletions src/TimeModeling/Modeling/propagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,29 @@ Runs the function `func` for indices `1:nsrc` within arguments `func(arg_func(i)
the pool is empty, a standard loop and accumulation is run. If the pool is a julia WorkerPool or
any custom Distributed pool, the loop is distributed via `remotecall` followed by a binary tree remote reduction.
"""
function run_and_reduce(func, pool, nsrc, arg_func::Function)
function run_and_reduce(func, pool, nsrc, arg_func::Function; kw=nothing)
# Allocate devices
_set_devices!()
# Run distributed loop
res = Vector{_TFuture}(undef, nsrc)
for i = 1:nsrc
args_loc = arg_func(i)
res[i] = remotecall(func, pool, args_loc...)
kw_loc = isnothing(kw) ? Dict() : kw(i)
res[i] = remotecall(func, pool, args_loc...; kw_loc...)
end
res = reduce!(res)
return res
end

function run_and_reduce(func, ::Nothing, nsrc, arg_func::Function)
function run_and_reduce(func, ::Nothing, nsrc, arg_func::Function; kw=nothing)
@juditime "Running $(func) for first src" begin
out = func(arg_func(1)...)
kw_loc = isnothing(kw) ? Dict() : kw(1)
out = func(arg_func(1)...; kw_loc...)
end
for i=2:nsrc
@juditime "Running $(func) for src $(i)" begin
next = func(arg_func(i)...)
kw_loc = isnothing(kw) ? Dict() : kw(i)
next = func(arg_func(i)...; kw_loc...)
end
single_reduce!(out, next)
end
Expand Down Expand Up @@ -100,27 +103,26 @@ Computes the misifit and gradient (LSRTM if `lin` else FWI) for the given `q` so
perturbation `dm`.
"""
function multi_src_fg!(G, model, q, dobs, dm; options=Options(), kw...)
check_non_indexable(kw)
# Number of sources and init result
nsrc = try q.nsrc catch; dobs.nsrc end
pool = _worker_pool()
illum = compute_illum(model, :adjoint_born)
# Distribute source
arg_func = i -> (model, q[i], dobs[i], dm, options[i], values(kw)..., illum)
arg_func = i -> (model, q[i], dobs[i], dm, options[i])
kw_func = i -> Dict(:illum=> illum, Dict(k => kw_i(v, i) for (k, v) in kw)...)
# Distribute source
res = run_and_reduce(multi_src_fg, pool, nsrc, arg_func)
res = run_and_reduce(multi_src_fg, pool, nsrc, arg_func; kw=kw_func)
f, g = update_illum(res, model, :adjoint_born)
f, g = as_vec(res, Val(options.return_array))
G .+= g
return f
end

check_non_indexable(d::Dict) = for (k, v) in d check_non_indexable(v) end
check_non_indexable(x) = false
check_non_indexable(::Bool) = true
check_non_indexable(::Number) = true
check_non_indexable(::judiMultiSourceVector) = throw(ArgumentError("Keyword arguments must not be source dependent"))
function check_non_indexable(::AbstractArray)
@warn "keyword arguement Array considered source independent and copied to all workers"
true
end
# Per-source slicing of the keyword arguments forwarded to `multi_src_fg`:
# `kw_i(v, i)` returns the part of the keyword value `v` that is sent along with source `i`.
# Flags and model preconditioners are passed through unchanged; multi-source vectors and
# data preconditioners are indexed by the source number.
kw_i(b::Bool, ::Integer) = b
kw_i(msv::judiMultiSourceVector, i::Integer) = msv[i]
kw_i(P::DataPreconditioner, i::Integer) = P[i]
kw_i(P::ModelPreconditioner, ::Integer) = P
kw_i(P::MultiPreconditioner{TP, T}, i::Integer) where {TP, T} = MultiPreconditioner{TP, T}([kw_i(Pi, i) for Pi in P.precs])
# Generic tuple: slice every element and keep a tuple of the same length.
# `map` over a tuple returns a tuple, whereas `tuple(generator)` would wrap the
# un-evaluated generator itself in a 1-tuple.
kw_i(t::Tuple, i::Integer) = map(ti -> kw_i(ti, i), t)
# Lists/tuples of preconditioners are sliced element-wise and collapsed into one product
kw_i(d::Vector{<:Preconditioner}, i::Integer) = foldr(*, [kw_i(di, i) for di in d])
kw_i(d::NTuple{N, <:Preconditioner}, i::Integer) where N = foldr(*, [kw_i(di, i) for di in d])
50 changes: 44 additions & 6 deletions src/TimeModeling/Preconditioners/DataPreconditioners.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
export DataMute, FrequencyFilter, judiTimeDerivative, judiTimeIntegration, TimeDifferential
export judiFilter, filter_data, judiDataMute, muteshot, judiTimeGain


# Resample every data preconditioner of a tuple onto the time axis `newt` and
# combine the resampled operators into a single product.
function time_resample(t::NTuple{N, <:DataPreconditioner}, newt) where N
    resampled = (time_resample(ti, newt) for ti in t)
    return prod(resampled)
end

# Same operation for a vector of data preconditioners.
function time_resample(t::Vector{<:DataPreconditioner}, newt)
    resampled = (time_resample(ti, newt) for ti in t)
    return prod(resampled)
end

############################################ Data mute ###############################################
"""
struct DataMute{T, mode} <: DataPreconditioner{T, T}
Expand Down Expand Up @@ -60,8 +64,8 @@ end
judiDataMute(q::judiVector, d::judiVector; kw...) = judiDataMute(q.geometry, d.geometry; kw...)

# Implementation
matvec_T(D::DataMute{T, mode} , x::AbstractVector{T}) where {T, mode} = matvec(D, x) # Basically a mask so symmetric and self adjoint
matvec(D::DataMute{T, mode} , x::AbstractVector{T}) where {T, mode} = muteshot(x, D.srcGeom, D.recGeom; vp=D.vp, t0=D.t0, mode=mode, taperwidth=D.taperwidth)
matvec_T(D::DataMute{T, mode} , x::AbstractVecOrMat{T}) where {T, mode} = matvec(D, x) # Basically a mask so symmetric and self adjoint
matvec(D::DataMute{T, mode} , x::AbstractVecOrMat{T}) where {T, mode} = muteshot(x, D.srcGeom, D.recGeom; vp=D.vp, t0=D.t0, mode=mode, taperwidth=D.taperwidth)

# Real diagonal operator
conj(I::DataMute{T, mode}) where {T, mode} = I
Expand All @@ -76,6 +80,16 @@ function getindex(P::DataMute{T, mode}, i) where {T, mode}
DataMute{T, mode}(m, P.srcGeom[inds], geomi, P.vp[inds], P.t0[inds], P.taperwidth[inds])
end

"""
    time_resample(d::DataMute, taxis::AbstractRange)

Return a copy of the single-source mute `d` whose source and receiver geometries use
`taxis` as their time axis. The taper width is rescaled by the ratio of the current
receiver sampling (`d.recGeom.dt`) to the step of `taxis` and truncated to an integer.
"""
function time_resample(d::DataMute{T, mode}, taxis::AbstractRange) where {T, mode}
    @assert get_nsrc(d.recGeom) == 1
    # Work on copies so the geometries held by the input preconditioner are left untouched
    new_rgeom = deepcopy(d.recGeom)
    new_sgeom = deepcopy(d.srcGeom)
    new_rgeom.taxis[1] = taxis
    new_sgeom.taxis[1] = taxis
    # `step(taxis)` is the public accessor and is defined for every `AbstractRange`;
    # the raw `.step` field does not exist on all range types (e.g. `UnitRange`,
    # `LinRange`) and is stored as a `TwicePrecision` for `Float64` ranges.
    taperwidth = trunc.(Int64, d.taperwidth .* (d.recGeom.dt ./ step(taxis)))
    return DataMute{T, mode}(d.m, new_sgeom, new_rgeom, d.vp, d.t0, taperwidth)
end

function _mutetrace!(t::AbstractVector{T}, taper::AbstractVector{T}, i::Integer, taperwidth::Integer, ::Val{:reflection}) where T
t[1:i-taperwidth] .= 0f0
t[i-taperwidth+1:i] .*= taper
Expand Down Expand Up @@ -104,14 +118,14 @@ function muteshot!(shot::AbstractMatrix{T}, rGeom::Geometry, srcGeom::Geometry;
end
end

function muteshot(shot::Vector{T}, srcGeom::Geometry, recGeom::Geometry;
function muteshot(shot::VecOrMat{T}, srcGeom::Geometry, recGeom::Geometry;
vp=1500, t0=.1, mode=:reflection, taperwidth=floor(Int, 2/t0)) where {T<:Number}
sr = reshape(shot, recGeom)
for s=1:get_nsrc(recGeom)
sri = view(sr, :, :, s)
muteshot!(sri, recGeom[s], srcGeom[s]; vp=vp[s], t0=t0[s], mode=mode, taperwidth=taperwidth[s])
end
return vec(vcat(sr...))
return reshape(sr, size(shot))
end

function muteshot(shot::judiVector, srcGeom::Geometry; vp=1500, t0=.1, mode=:reflection, taperwidth=20)
Expand Down Expand Up @@ -144,7 +158,7 @@ judiFilter(geometry::Geometry, fmin::T, fmax::T) where T = judiFilter(geometry,
judiFilter(geometry::Geometry, fmin::Float32, fmax::Float32) = FrequencyFilter{Float32, fmin, fmax}(n_samples(geometry), geometry)
judiFilter(v::judiVector, fmin, fmax) = judiFilter(v.geometry, fmin, fmax)

function matvec(D::FrequencyFilter{T, fm, FM}, x::Vector{T}) where {T, fm, FM}
function matvec(D::FrequencyFilter{T, fm, FM}, x::VecOrMat{T}) where {T, fm, FM}
dr = reshape(x, D.recGeom)
for j=1:get_nsrc(D.recGeom)
dri = view(dr, :, :, j)
Expand All @@ -167,6 +181,13 @@ function getindex(P::FrequencyFilter{T, fm, FM}, i) where {T, fm, FM}
return FrequencyFilter{T, fm, FM}(n_samples(geomi), geomi)
end

"""
    time_resample(d::FrequencyFilter, taxis::AbstractRange)

Return a copy of the single-source filter `d` (same `fm`/`fM` type parameters) whose
receiver geometry uses `taxis` as its time axis.
"""
function time_resample(d::FrequencyFilter{T, fm, fM}, taxis::AbstractRange) where {T, fm, fM}
    @assert get_nsrc(d.recGeom) == 1
    # Copy so the geometry held by the input filter is left untouched
    new_geom = deepcopy(d.recGeom)
    new_geom.taxis[1] = taxis
    # Recompute the operator size from the resampled geometry, as the constructor and
    # `getindex` do, instead of carrying over the size tied to the previous time axis
    return FrequencyFilter{T, fm, fM}(n_samples(new_geom), new_geom)
end

# filtering is self-adjoint (diagonal in fourier domain)
matvec_T(D::FrequencyFilter{T, fm, FM}, x) where {T, fm, FM} = matvec(D, x)

Expand Down Expand Up @@ -263,6 +284,14 @@ function getindex(P::TimeDifferential{T, K}, i) where {T, K}
TimeDifferential{T, K}(m, geomi)
end

# Resampled copy of a single-source TimeDifferential: the receiver geometry is
# duplicated, its first time axis is replaced by `taxis`, and the operator is
# rebuilt around it with the same type parameters and the same `m`.
function time_resample(d::TimeDifferential{T, K}, taxis::AbstractRange) where {T, K}
    @assert get_nsrc(d.recGeom) == 1
    resampled_geom = deepcopy(d.recGeom)
    time_axes = resampled_geom.taxis
    time_axes[1] = taxis
    return TimeDifferential{T, K}(d.m, resampled_geom)
end


function matvec(D::TimeDifferential{T, K}, x::judiVector{T, AT}) where {T, AT, K}
out = similar(x)
for s=1:out.nsrc
Expand All @@ -277,11 +306,12 @@ end

# Apply the TimeDifferential to a plain array: scale the spectrum along the first
# dimension of the reshaped data by |ω|^K and transform back.
function matvec(D::TimeDifferential{T, K}, x::Array{T}) where {T, K}
    xr = reshape(x, D.recGeom)
    # Pre-allocate the result with the element type of the input so the in-place
    # assignment below does not change the output type
    out = similar(xr)
    # make omega^K
    ω = 2 .* pi .* fftfreq(get_nt(D.recGeom, 1), 1/get_dt(D.recGeom, 1))
    # The zero frequency is replaced by 1 before taking the power
    # (presumably to avoid 0^K blowing up for negative K -- TODO confirm)
    ω[ω.==0] .= 1f0
    ω .= abs.(ω).^K
    out .= real.(ifft(ω .* fft(xr, 1), 1))
    return reshape(out, size(x))
end

Expand Down Expand Up @@ -323,6 +353,14 @@ matvec_T(D::TimeGain{T, K}, x) where {T, K} = matvec(D, x)
# getindex for source subsampling
getindex(D::TimeGain{T, K}, i) where {T, K} = TimeGain{T, K}(D.recGeom[i])


# Resampled copy of a single-source TimeGain: duplicate the receiver geometry,
# replace its first time axis by `taxis` and rebuild the gain with the same `pow`.
function time_resample(d::TimeGain, taxis::AbstractRange)
    @assert get_nsrc(d.recGeom) == 1
    resampled_geom = deepcopy(d.recGeom)
    time_axes = resampled_geom.taxis
    time_axes[1] = taxis
    return TimeGain(resampled_geom, d.pow)
end

function matvec(D::TimeGain{T, K}, x::judiVector{T, AT}) where {T, AT, K}
out = similar(x)
for s=1:out.nsrc
Expand Down
46 changes: 44 additions & 2 deletions src/TimeModeling/Preconditioners/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,51 @@ getproperty(J::Preconditioner, s::Symbol) = _get_property(J, Val{s}())
# Base compat
*(J::Preconditioner, ms::judiMultiSourceVector) = matvec(J, ms)
*(J::Preconditioner, ms::PhysicalParameter) = matvec(J, ms)
*(J::Preconditioner, v::Vector{T}) where T = matvec(J, v)
*(J::Preconditioner, v::VecOrMat{T}) where T = matvec(J, v)

mul!(out::judiMultiSourceVector, J::Preconditioner, ms::judiMultiSourceVector) = copyto!(out, matvec(J, ms))
mul!(out::PhysicalParameter, J::Preconditioner, ms::PhysicalParameter) = copyto!(out, matvec(J, ms))

# OOC judiVector
*(J::DataPreconditioner, v::judiVector{T, SegyIO.SeisCon}) where T = LazyMul(v.nsrc, J, v)
*(J::DataPreconditioner, v::judiVector{T, SegyIO.SeisCon}) where T = LazyMul(v.nsrc, J, v)
*(J::Preconditioner, v::LazyMul) = LazyMul(v.nsrc, J*v.P, v)

"""
    MultiPreconditioner{TP, T}

Type for the combination of preconditioners. It is a linear operator that applies the preconditioners in sequence.

The factors are stored in the field `precs::Vector{TP}`.
"""
struct MultiPreconditioner{TP, T} <: Preconditioner{T, T}
    precs::Vector{TP}
end

# Apply the product stored in `J.precs` to `x`.
# `precs = [P1, ..., Pn]` stands for `P1 * ... * Pn`, so the factors must be applied
# from the last one to the first one: `P1 * (P2 * (... * (Pn * x)))`.
# Iterating `precs[1:end-1]` in forward order after applying `precs[end]` would give
# the wrong order as soon as there are three or more factors.
function matvec(J::MultiPreconditioner, x)
    y = x
    for P in Iterators.reverse(J.precs)
        y = P * y
    end
    return y
end

# Adjoint application of the product: the adjoint of each factor is applied to `x`
# starting with the first stored factor and ending with the last one.
function matvec_T(J::MultiPreconditioner, x)
    out = J.precs[1]' * x
    for k in 2:length(J.precs)
        out = J.precs[k]' * out
    end
    return out
end

# Conjugate every factor, keeping the stored order
function conj(M::MultiPreconditioner{TP, T}) where {TP, T}
    return MultiPreconditioner{TP, T}(map(conj, M.precs))
end

# Adjoint of a product: reverse the factors and take the adjoint of each one
function adjoint(M::MultiPreconditioner{TP, T}) where {TP, T}
    return MultiPreconditioner{TP, T}(map(adjoint, reverse(M.precs)))
end

# Transpose of a product: reverse the factors and transpose each one
function transpose(M::MultiPreconditioner{TP, T}) where {TP, T}
    return MultiPreconditioner{TP, T}(map(transpose, reverse(M.precs)))
end

# Indexing is forwarded to every factor
function getindex(M::MultiPreconditioner{TP, T}, i) where {TP, T}
    return MultiPreconditioner{TP, T}(map(P -> getindex(P, i), M.precs))
end

# Products of preconditioners of the same family (data or model) and the same `DT`
# are collected into a flat MultiPreconditioner that keeps the factors in written order.
for T in [DataPreconditioner, ModelPreconditioner]
    # Two single preconditioners
    @eval *(P1::$(T){DT}, P2::$(T){DT}) where DT = MultiPreconditioner{$(T), DT}([P1, P2])
    # Prepend / append a single preconditioner to an existing product
    @eval *(P::$(T){DT}, P2::MultiPreconditioner{$(T), DT}) where DT = MultiPreconditioner{$(T), DT}([P, P2.precs...])
    @eval *(P2::MultiPreconditioner{$(T), DT}, P::$(T){DT}) where DT = MultiPreconditioner{$(T), DT}([P2.precs..., P])
    # Two existing products: concatenate their factors so that e.g. (P1*P2)*(P3*P4)
    # also yields a flat MultiPreconditioner
    @eval *(P1::MultiPreconditioner{$(T), DT}, P2::MultiPreconditioner{$(T), DT}) where DT = MultiPreconditioner{$(T), DT}([P1.precs..., P2.precs...])
end

# Resample every factor of the product onto `newt` and rebuild the product
# with the same type parameters.
function time_resample(x::MultiPreconditioner{TP, T}, newt) where {TP, T}
    resampled = map(P -> time_resample(P, newt), x.precs)
    return MultiPreconditioner{TP, T}(resampled)
end
Loading

0 comments on commit 3a9eb54

Please sign in to comment.