Skip to content

Commit

Permalink
add data/model_precon support to lsrtm/fwi objectives
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Mar 22, 2024
1 parent 940ae55 commit 634bbf4
Show file tree
Hide file tree
Showing 13 changed files with 199 additions and 68 deletions.
5 changes: 4 additions & 1 deletion examples/scripts/lsrtm_2D.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ lsqr_sol = zeros(Float32, prod(model0.n))
dinv = d_lin[indsrc]
Jinv = J[indsrc]

lsqr!(lsqr_sol, Ml[indsrc]*Jinv*Mr, Ml[indsrc]*dinv; maxiter=niter)
Jp = Ml[indsrc]*Jinv*Mr
dinvp = Ml[indsrc]*dinv

lsqr!(lsqr_sol, Jp, dinvp; maxiter=niter)

# Save final velocity model, function value and history
h5open("lsrtm_marmousi_lsqr_result.h5", "w") do file
Expand Down
2 changes: 1 addition & 1 deletion src/TimeModeling/LinearOperators/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,4 @@ time_space_src(nsrc::Integer, nt) = AbstractSize((:src, :time, :x, :y, :z), (nsr

space_src(nsrc::Integer) = AbstractSize((:src, :x, :y, :z), (nsrc, 0, 0, 0))

time_src(nsrc::Integer, nt) = AbstractSize((:src, :time), (nsrc, nt))
time_src(nsrc::Integer, nt) = AbstractSize((:src, :time), (nsrc, nt))
31 changes: 17 additions & 14 deletions src/TimeModeling/Modeling/misfit_fg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ MTypes = Union{<:AbstractModel, NTuple{N, <:AbstractModel} where N, Vector{<:Abs
dmTypes = Union{dmType, NTuple{N, dmType} where N, Vector{dmType}}


function multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions,
nlind::Bool, lin::Bool, misfit::Function, illum::Bool)
function multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions;
nlind::Bool=false, lin::Bool=false, misfit::Function=mse, illum::Bool=false,
data_precon=nothing, model_precon=LinearAlgebra.I)
GC.gc(true)
devito.clear_cache()
# assert this is for single source LSRTM
Expand All @@ -18,6 +19,9 @@ function multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, d
# Load full geometry for out-of-core geometry containers
d_geometry = Geometry(dObs.geometry)
s_geometry = Geometry(source.geometry)

# If model preconditioner is provided, apply it
dm = isnothing(dm) ? dm : model_precon * dm

# Limit model to area with sources/receivers
if options.limit_m == true
Expand Down Expand Up @@ -46,7 +50,17 @@ function multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, d
rec_coords = setup_grid(d_geometry, size(model)) # shifts rec coordinates by origin
end

mfunc = pyfunction(misfit, Matrix{Float32}, Matrix{Float32})
# Setup misfit function
if !isnothing(data_precon)
# resample
new_t = range(start=0, step=dtComp, length=size(dObserved, 1))
Pcomp = time_resample(data_precon, new_t)
runtime_misfit = (x, y) -> misfit(Pcomp*x, Pcomp*y)
else
runtime_misfit = misfit
end

mfunc = pyfunction(runtime_misfit, Matrix{Float32}, Matrix{Float32})

length(options.frequencies) == 0 ? freqs = nothing : freqs = options.frequencies
IT = illum ? (PyArray, PyArray) : (PyObject, PyObject)
Expand Down Expand Up @@ -78,17 +92,6 @@ function multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, d
return fval, grad
end


####### Defaults
multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool) =
multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, mse, false)

multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, illum::Bool) =
multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, mse, illum)

multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, phi::Function) =
multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, phi, false)

# Find number of experiments
"""
get_nexp(x)
Expand Down
33 changes: 16 additions & 17 deletions src/TimeModeling/Modeling/propagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,29 @@ Runs the function `func` for indices `1:nsrc` within arguments `func(arg_func(i)
the pool is empty, a standard loop and accumulation is ran. If the pool is a julia WorkerPool or
any custom Distributed pool, the loop is distributed via `remotecall` followed by are binary tree remote reduction.
"""
function run_and_reduce(func, pool, nsrc, arg_func::Function)
function run_and_reduce(func, pool, nsrc, arg_func::Function; kw=nothing)
# Allocate devices
_set_devices!()
# Run distributed loop
res = Vector{_TFuture}(undef, nsrc)
for i = 1:nsrc
args_loc = arg_func(i)
res[i] = remotecall(func, pool, args_loc...)
kw_loc = isnothing(kw) ? Dict() : kw(i)
res[i] = remotecall(func, pool, args_loc...; kw_loc...)
end
res = reduce!(res)
return res
end

function run_and_reduce(func, ::Nothing, nsrc, arg_func::Function)
function run_and_reduce(func, ::Nothing, nsrc, arg_func::Function; kw=nothing)
@juditime "Running $(func) for first src" begin
out = func(arg_func(1)...)
kw_loc = isnothing(kw) ? Dict() : kw(1)
out = func(arg_func(1)...; kw_loc...)
end
for i=2:nsrc
@juditime "Running $(func) for src $(i)" begin
next = func(arg_func(i)...)
kw_loc = isnothing(kw) ? Dict() : kw(i)
next = func(arg_func(i)...; kw_loc...)
end
single_reduce!(out, next)
end
Expand Down Expand Up @@ -100,27 +103,23 @@ Computes the misifit and gradient (LSRTM if `lin` else FWI) for the given `q` so
perturbation `dm`.
"""
function multi_src_fg!(G, model, q, dobs, dm; options=Options(), kw...)
check_non_indexable(kw)
# Number of sources and init result
nsrc = try q.nsrc catch; dobs.nsrc end
pool = _worker_pool()
illum = compute_illum(model, :adjoint_born)
# Distribute source
arg_func = i -> (model, q[i], dobs[i], dm, options[i], values(kw)..., illum)
arg_func = i -> (model, q[i], dobs[i], dm, options[i])
kw_func = i -> Dict(:illum=> illum, Dict(k => kw_i(v, i) for (k, v) in kw)...)
# Distribute source
res = run_and_reduce(multi_src_fg, pool, nsrc, arg_func)
res = run_and_reduce(multi_src_fg, pool, nsrc, arg_func; kw=kw_func)
f, g = update_illum(res, model, :adjoint_born)
f, g = as_vec(res, Val(options.return_array))
G .+= g
return f
end

check_non_indexable(d::Dict) = for (k, v) in d check_non_indexable(v) end
check_non_indexable(x) = false
check_non_indexable(::Bool) = true
check_non_indexable(::Number) = true
check_non_indexable(::judiMultiSourceVector) = throw(ArgumentError("Keyword arguments must not be source dependent"))
function check_non_indexable(::AbstractArray)
@warn "keyword arguement Array considered source independent and copied to all workers"
true
end
kw_i(b::Bool, ::Integer) = b
kw_i(msv::judiMultiSourceVector, i::Integer) = msv[i]
kw_i(P::Preconditioner, i::Integer) = P[i]
kw_i(t::Tuple, i::Integer) = tuple(kw_i(ti, i) for ti in t)
kw_i(d::Vector, i::Integer) = [kw_i(di, i) for di in d]
44 changes: 40 additions & 4 deletions src/TimeModeling/Preconditioners/DataPreconditioners.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
export DataMute, FrequencyFilter, judiTimeDerivative, judiTimeIntegration, TimeDifferential
export judiFilter, filter_data, judiDataMute, muteshot, judiTimeGain


time_resample(t::NTuple{N, <:DataPreconditioner}) where N = prod(time_resample.(t))
time_resample(t::Vector{<:DataPreconditioner}) = prod(time_resample.(t))

############################################ Data mute ###############################################
"""
struct DataMute{T, mode} <: DataPreconditioner{T, T}
Expand Down Expand Up @@ -60,8 +64,8 @@ end
judiDataMute(q::judiVector, d::judiVector; kw...) = judiDataMute(q.geometry, d.geometry; kw...)

# Implementation
matvec_T(D::DataMute{T, mode} , x::AbstractVector{T}) where {T, mode} = matvec(D, x) # Basically a mask so symmetric and self adjoint
matvec(D::DataMute{T, mode} , x::AbstractVector{T}) where {T, mode} = muteshot(x, D.srcGeom, D.recGeom; vp=D.vp, t0=D.t0, mode=mode, taperwidth=D.taperwidth)
matvec_T(D::DataMute{T, mode} , x::AbstractVecOrMat{T}) where {T, mode} = matvec(D, x) # Basically a mask so symmetric and self adjoint
matvec(D::DataMute{T, mode} , x::AbstractVecOrMat{T}) where {T, mode} = muteshot(x, D.srcGeom, D.recGeom; vp=D.vp, t0=D.t0, mode=mode, taperwidth=D.taperwidth)

# Real diagonal operator
conj(I::DataMute{T, mode}) where {T, mode} = I
Expand All @@ -76,6 +80,15 @@ function getindex(P::DataMute{T, mode}, i) where {T, mode}
DataMute{T, mode}(m, P.srcGeom[inds], geomi, P.vp[inds], P.t0[inds], P.taperwidth[inds])
end

function time_resample(d::DataMute{T, mode}, taxis::AbstractRange) where {T, mode}
@assert get_nsrc(d.recGeom) == 1
new_rgeom = deepcopy(d.recGeom)
new_sgeom = deepcopy(d.srcGeom)
new_rgeom.taxis[1] = taxis
new_sgeom.taxis[1] = taxis
return DataMute{T, mode}(d.m, new_sgeom, new_rgeom, d.vp, d.t0, d.taperwidth)
end

function _mutetrace!(t::AbstractVector{T}, taper::AbstractVector{T}, i::Integer, taperwidth::Integer, ::Val{:reflection}) where T
t[1:i-taperwidth] .= 0f0
t[i-taperwidth+1:i] .*= taper
Expand Down Expand Up @@ -104,7 +117,7 @@ function muteshot!(shot::AbstractMatrix{T}, rGeom::Geometry, srcGeom::Geometry;
end
end

function muteshot(shot::Vector{T}, srcGeom::Geometry, recGeom::Geometry;
function muteshot(shot::VecOrMat{T}, srcGeom::Geometry, recGeom::Geometry;
vp=1500, t0=.1, mode=:reflection, taperwidth=floor(Int, 2/t0)) where {T<:Number}
sr = reshape(shot, recGeom)
for s=1:get_nsrc(recGeom)
Expand Down Expand Up @@ -144,7 +157,7 @@ judiFilter(geometry::Geometry, fmin::T, fmax::T) where T = judiFilter(geometry,
judiFilter(geometry::Geometry, fmin::Float32, fmax::Float32) = FrequencyFilter{Float32, fmin, fmax}(n_samples(geometry), geometry)
judiFilter(v::judiVector, fmin, fmax) = judiFilter(v.geometry, fmin, fmax)

function matvec(D::FrequencyFilter{T, fm, FM}, x::Vector{T}) where {T, fm, FM}
function matvec(D::FrequencyFilter{T, fm, FM}, x::VecOrMat{T}) where {T, fm, FM}
dr = reshape(x, D.recGeom)
for j=1:get_nsrc(D.recGeom)
dri = view(dr, :, :, j)
Expand All @@ -167,6 +180,13 @@ function getindex(P::FrequencyFilter{T, fm, FM}, i) where {T, fm, FM}
return FrequencyFilter{T, fm, FM}(n_samples(geomi), geomi)
end

function time_resample(d::FrequencyFilter{T, fm, fM}, taxis::AbstractRange) where {T, fm, fM}
@assert get_nsrc(d.recGeom) == 1
new_geom = deepcopy(d.recGeom)
new_geom.taxis[1] = taxis
return FrequencyFilter{T, fm, fM}(d.m, new_geom)
end

# filtering is self-adjoint (diagonal in fourier domain)
matvec_T(D::FrequencyFilter{T, fm, FM}, x) where {T, fm, FM} = matvec(D, x)

Expand Down Expand Up @@ -263,6 +283,14 @@ function getindex(P::TimeDifferential{T, K}, i) where {T, K}
TimeDifferential{T, K}(m, geomi)
end

function time_resample(d::TimeDifferential{T, K}, taxis::AbstractRange) where {T, K}
@assert get_nsrc(d.recGeom) == 1
new_geom = deepcopy(d.recGeom)
new_geom.taxis[1] = taxis
return TimeDifferential{T, K}(d.m, new_geom)
end


function matvec(D::TimeDifferential{T, K}, x::judiVector{T, AT}) where {T, AT, K}
out = similar(x)
for s=1:out.nsrc
Expand Down Expand Up @@ -323,6 +351,14 @@ matvec_T(D::TimeGain{T, K}, x) where {T, K} = matvec(D, x)
# getindex for source subsampling
getindex(D::TimeGain{T, K}, i) where {T, K} = TimeGain{T, K}(D.recGeom[i])


function time_resample(d::TimeGain, taxis::AbstractRange)
@assert get_nsrc(d.recGeom) == 1
new_geom = deepcopy(d.recGeom)
new_geom.taxis[1] = taxis
return TimeGain(new_geom, d.pow)
end

function matvec(D::TimeGain{T, K}, x::judiVector{T, AT}) where {T, AT, K}
out = similar(x)
for s=1:out.nsrc
Expand Down
41 changes: 40 additions & 1 deletion src/TimeModeling/Preconditioners/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,43 @@ mul!(out::judiMultiSourceVector, J::Preconditioner, ms::judiMultiSourceVector) =
mul!(out::PhysicalParameter, J::Preconditioner, ms::PhysicalParameter) = copyto!(out, matvec(J, ms))

# OOC judiVector
*(J::DataPreconditioner, v::judiVector{T, SegyIO.SeisCon}) where T = LazyMul(v.nsrc, J, v)
*(J::DataPreconditioner, v::judiVector{T, SegyIO.SeisCon}) where T = LazyMul(v.nsrc, J, v)
*(J::Preconditioner, v::LazyMul) = LazyMul(v.nsrc, J*v.P, v)

"""
MultiPrcontinioner{TP, T}
Type for the combination of preconditioners. It is a linear operator that applies the preconditioners in sequence.
"""

struct MultiPreconditioner{TP, T} <: Preconditioner{T, T}
precs::Vector{TP}
end

function matvec(J::MultiPreconditioner, x)
y = J.precs[end] * x
for P in J.precs[1:end-1]
y = P * y
end
return y
end

function matvec_T(J::MultiPreconditioner, x)
y = J.precs[1]' * x
for P in J.precs[2:end]
y = P' * y
end
return y
end

conj(I::MultiPreconditioner{TP, T}) where {TP, T} = MultiPreconditioner{TP, T}(conj.(I.precs))
adjoint(I::MultiPreconditioner{TP, T}) where {TP, T} = MultiPreconditioner{TP, T}(adjoint.(reverse(I.precs)))
transpose(I::MultiPreconditioner{TP, T}) where {TP, T} = MultiPreconditioner{TP, T}(transpose.(reverse(I.precs)))

getindex(I::MultiPreconditioner{TP, T}, i) where {TP, T} = MultiPreconditioner{TP, T}([getindex(P, i) for P in I.precs])

for T in [DataPreconditioner, ModelPreconditioner]
@eval *(P1::$(T){DT}, P2::$(T){DT}) where DT = MultiPreconditioner{$(T), DT}([P1, P2])
@eval *(P::$(T){DT}, P2::MultiPreconditioner{$(T), DT}) where DT = MultiPreconditioner{$(T), DT}([P, P2.precs...])
@eval *(P2::MultiPreconditioner{$(T), DT}, P::$(T){DT}) where DT = MultiPreconditioner{$(T), DT}([P2.precs..., P])
end
16 changes: 8 additions & 8 deletions src/TimeModeling/TimeModeling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,6 @@ include("LinearOperators/lazy.jl")
include("LinearOperators/operators.jl")
include("LinearOperators/callable.jl")

#############################################################################
# PDE solvers
include("Modeling/distributed.jl") # Modeling functions utilities
include("Modeling/python_interface.jl") # forward/adjoint linear/nonlinear modeling
include("Modeling/time_modeling_serial.jl") # forward/adjoint linear/nonlinear modeling
include("Modeling/misfit_fg.jl") # FWI/LSRTM objective function value and gradient
include("Modeling/twri_objective.jl") # TWRI objective function value and gradient
include("Modeling/propagation.jl")

#############################################################################
# Preconditioners
Expand All @@ -55,6 +47,14 @@ include("Preconditioners/utils.jl")
include("Preconditioners/DataPreconditioners.jl")
include("Preconditioners/ModelPreconditioners.jl")

#############################################################################
# PDE solvers
include("Modeling/distributed.jl") # Modeling functions utilities
include("Modeling/python_interface.jl") # forward/adjoint linear/nonlinear modeling
include("Modeling/time_modeling_serial.jl") # forward/adjoint linear/nonlinear modeling
include("Modeling/misfit_fg.jl") # FWI/LSRTM objective function value and gradient
include("Modeling/twri_objective.jl") # TWRI objective function value and gradient
include("Modeling/propagation.jl")

#############################################################################
# Extra that need all imports
Expand Down
Loading

0 comments on commit 634bbf4

Please sign in to comment.