From b1362a7fdad32969582254e7a6c2d2148397db7d Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 21 Mar 2024 22:46:29 -0400 Subject: [PATCH] add data/model_precon support to lsrtm/fwi objectives --- examples/scripts/lsrtm_2D.jl | 5 +- src/TimeModeling/LinearOperators/basics.jl | 2 +- src/TimeModeling/Modeling/misfit_fg.jl | 31 ++++++------ src/TimeModeling/Modeling/propagation.jl | 36 ++++++------- .../Preconditioners/DataPreconditioners.jl | 50 ++++++++++++++++--- src/TimeModeling/Preconditioners/base.jl | 46 ++++++++++++++++- src/TimeModeling/TimeModeling.jl | 16 +++--- src/TimeModeling/Types/judiVector.jl | 44 +++++++++------- src/TimeModeling/Types/lazy_msv.jl | 6 ++- src/rrules.jl | 1 + test/test_gradients.jl | 38 ++++++++++++++ test/test_issues.jl | 4 +- test/test_preconditioners.jl | 2 +- 13 files changed, 210 insertions(+), 71 deletions(-) diff --git a/examples/scripts/lsrtm_2D.jl b/examples/scripts/lsrtm_2D.jl index d4f41ce67..4f85c5fcc 100644 --- a/examples/scripts/lsrtm_2D.jl +++ b/examples/scripts/lsrtm_2D.jl @@ -65,7 +65,10 @@ lsqr_sol = zeros(Float32, prod(model0.n)) dinv = d_lin[indsrc] Jinv = J[indsrc] -lsqr!(lsqr_sol, Ml[indsrc]*Jinv*Mr, Ml[indsrc]*dinv; maxiter=niter) +Jp = Ml[indsrc]*Jinv*Mr +dinvp = Ml[indsrc]*dinv + +lsqr!(lsqr_sol, Jp, dinvp; maxiter=niter) # Save final velocity model, function value and history h5open("lsrtm_marmousi_lsqr_result.h5", "w") do file diff --git a/src/TimeModeling/LinearOperators/basics.jl b/src/TimeModeling/LinearOperators/basics.jl index 98cac5be4..defea439b 100644 --- a/src/TimeModeling/LinearOperators/basics.jl +++ b/src/TimeModeling/LinearOperators/basics.jl @@ -66,4 +66,4 @@ time_space_src(nsrc::Integer, nt) = AbstractSize((:src, :time, :x, :y, :z), (nsr space_src(nsrc::Integer) = AbstractSize((:src, :x, :y, :z), (nsrc, 0, 0, 0)) -time_src(nsrc::Integer, nt) = AbstractSize((:src, :time), (nsrc, nt)) +time_src(nsrc::Integer, nt) = AbstractSize((:src, :time), (nsrc, nt)) \ No newline at end of file diff --git a/src/TimeModeling/Modeling/misfit_fg.jl b/src/TimeModeling/Modeling/misfit_fg.jl index 46c87dc5f..60ae61558 100644 --- a/src/TimeModeling/Modeling/misfit_fg.jl +++ b/src/TimeModeling/Modeling/misfit_fg.jl @@ -7,8 +7,9 @@ MTypes = Union{<:AbstractModel, NTuple{N, <:AbstractModel} where N, Vector{<:Abs dmTypes = Union{dmType, NTuple{N, dmType} where N, Vector{dmType}} -function multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, - nlind::Bool, lin::Bool, misfit::Function, illum::Bool) +function multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions; + nlind::Bool=false, lin::Bool=false, misfit::Function=mse, illum::Bool=false, + data_precon=nothing, model_precon=LinearAlgebra.I) GC.gc(true) devito.clear_cache() # assert this is for single source LSRTM @@ -18,6 +19,9 @@ function multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, d # Load full geometry for out-of-core geometry containers d_geometry = Geometry(dObs.geometry) s_geometry = Geometry(source.geometry) + + # If model preconditioner is provided, apply it + dm = isnothing(dm) ? dm : model_precon * dm # Limit model to area with sources/receivers if options.limit_m == true @@ -46,7 +50,17 @@ function multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, d rec_coords = setup_grid(d_geometry, size(model)) # shifts rec coordinates by origin end - mfunc = pyfunction(misfit, Matrix{Float32}, Matrix{Float32}) + # Setup misfit function + if !isnothing(data_precon) + # resample + new_t = range(start=0, step=dtComp, length=size(dObserved, 1)) + Pcomp = time_resample(data_precon, new_t) + runtime_misfit = (x, y) -> misfit(Pcomp*x, Pcomp*y) + else + runtime_misfit = misfit + end + + mfunc = pyfunction(runtime_misfit, Matrix{Float32}, Matrix{Float32}) length(options.frequencies) == 0 ? freqs = nothing : freqs = options.frequencies IT = illum ? (PyArray, PyArray) : (PyObject, PyObject) @@ -78,17 +92,6 @@ function multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, d return fval, grad end - -####### Defaults -multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool) = - multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, mse, false) - -multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, illum::Bool) = - multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, mse, illum) - -multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, phi::Function) = - multi_src_fg(model_full::AbstractModel, source::Dtypes, dObs::Dtypes, dm, options::JUDIOptions, nlind::Bool, lin::Bool, phi, false) - # Find number of experiments """ get_nexp(x) diff --git a/src/TimeModeling/Modeling/propagation.jl b/src/TimeModeling/Modeling/propagation.jl index b68e04b53..eda0a948d 100644 --- a/src/TimeModeling/Modeling/propagation.jl +++ b/src/TimeModeling/Modeling/propagation.jl @@ -20,26 +20,29 @@ Runs the function `func` for indices `1:nsrc` within arguments `func(arg_func(i) the pool is empty, a standard loop and accumulation is ran. If the pool is a julia WorkerPool or any custom Distributed pool, the loop is distributed via `remotecall` followed by are binary tree remote reduction. """ -function run_and_reduce(func, pool, nsrc, arg_func::Function) +function run_and_reduce(func, pool, nsrc, arg_func::Function; kw=nothing) # Allocate devices _set_devices!() # Run distributed loop res = Vector{_TFuture}(undef, nsrc) for i = 1:nsrc args_loc = arg_func(i) - res[i] = remotecall(func, pool, args_loc...) + kw_loc = isnothing(kw) ? Dict() : kw(i) + res[i] = remotecall(func, pool, args_loc...; kw_loc...) end res = reduce!(res) return res end -function run_and_reduce(func, ::Nothing, nsrc, arg_func::Function) +function run_and_reduce(func, ::Nothing, nsrc, arg_func::Function; kw=nothing) @juditime "Running $(func) for first src" begin - out = func(arg_func(1)...) + kw_loc = isnothing(kw) ? Dict() : kw(1) + out = func(arg_func(1)...; kw_loc...) end for i=2:nsrc @juditime "Running $(func) for src $(i)" begin - next = func(arg_func(i)...) + kw_loc = isnothing(kw) ? Dict() : kw(i) + next = func(arg_func(i)...; kw_loc...) end single_reduce!(out, next) end @@ -100,27 +103,26 @@ Computes the misifit and gradient (LSRTM if `lin` else FWI) for the given `q` so perturbation `dm`. """ function multi_src_fg!(G, model, q, dobs, dm; options=Options(), kw...) - check_non_indexable(kw) # Number of sources and init result nsrc = try q.nsrc catch; dobs.nsrc end pool = _worker_pool() illum = compute_illum(model, :adjoint_born) # Distribute source - arg_func = i -> (model, q[i], dobs[i], dm, options[i], values(kw)..., illum) + arg_func = i -> (model, q[i], dobs[i], dm, options[i]) + kw_func = i -> Dict(:illum=> illum, Dict(k => kw_i(v, i) for (k, v) in kw)...) # Distribute source - res = run_and_reduce(multi_src_fg, pool, nsrc, arg_func) + res = run_and_reduce(multi_src_fg, pool, nsrc, arg_func; kw=kw_func) f, g = update_illum(res, model, :adjoint_born) f, g = as_vec(res, Val(options.return_array)) G .+= g return f end -check_non_indexable(d::Dict) = for (k, v) in d check_non_indexable(v) end -check_non_indexable(x) = false -check_non_indexable(::Bool) = true -check_non_indexable(::Number) = true -check_non_indexable(::judiMultiSourceVector) = throw(ArgumentError("Keyword arguments must not be source dependent")) -function check_non_indexable(::AbstractArray) - @warn "keyword arguement Array considered source independent and copied to all workers" - true -end +kw_i(b::Bool, ::Integer) = b +kw_i(msv::judiMultiSourceVector, i::Integer) = msv[i] +kw_i(P::DataPreconditioner, i::Integer) = P[i] +kw_i(P::ModelPreconditioner, ::Integer) = P +kw_i(P::MultiPreconditioner{TP, T}, i::Integer) where {TP, T} = MultiPreconditioner{TP, T}([kw_i(Pi, i) for Pi in P.precs]) +kw_i(t::Tuple, i::Integer) = tuple(kw_i(ti, i) for ti in t) +kw_i(d::Vector{<:Preconditioner}, i::Integer) = foldr(*, [kw_i(di, i) for di in d]) +kw_i(d::NTuple{N, <:Preconditioner}, i::Integer) where N = foldr(*, [kw_i(di, i) for di in d]) diff --git a/src/TimeModeling/Preconditioners/DataPreconditioners.jl b/src/TimeModeling/Preconditioners/DataPreconditioners.jl index 3f50600ec..ab0739743 100644 --- a/src/TimeModeling/Preconditioners/DataPreconditioners.jl +++ b/src/TimeModeling/Preconditioners/DataPreconditioners.jl @@ -1,6 +1,10 @@ export DataMute, FrequencyFilter, judiTimeDerivative, judiTimeIntegration, TimeDifferential export judiFilter, filter_data, judiDataMute, muteshot, judiTimeGain + +time_resample(t::NTuple{N, <:DataPreconditioner}, newt) where N = prod(time_resample(ti, newt) for ti in t) +time_resample(t::Vector{<:DataPreconditioner}, newt) = prod(time_resample(ti, newt) for ti in t) + ############################################ Data mute ############################################### """ struct DataMute{T, mode} <: DataPreconditioner{T, T} @@ -60,8 +64,8 @@ end judiDataMute(q::judiVector, d::judiVector; kw...) = judiDataMute(q.geometry, d.geometry; kw...) # Implementation -matvec_T(D::DataMute{T, mode} , x::AbstractVector{T}) where {T, mode} = matvec(D, x) # Basically a mask so symmetric and self adjoint -matvec(D::DataMute{T, mode} , x::AbstractVector{T}) where {T, mode} = muteshot(x, D.srcGeom, D.recGeom; vp=D.vp, t0=D.t0, mode=mode, taperwidth=D.taperwidth) +matvec_T(D::DataMute{T, mode} , x::AbstractVecOrMat{T}) where {T, mode} = matvec(D, x) # Basically a mask so symmetric and self adjoint +matvec(D::DataMute{T, mode} , x::AbstractVecOrMat{T}) where {T, mode} = muteshot(x, D.srcGeom, D.recGeom; vp=D.vp, t0=D.t0, mode=mode, taperwidth=D.taperwidth) # Real diagonal operator conj(I::DataMute{T, mode}) where {T, mode} = I @@ -76,6 +80,16 @@ function getindex(P::DataMute{T, mode}, i) where {T, mode} DataMute{T, mode}(m, P.srcGeom[inds], geomi, P.vp[inds], P.t0[inds], P.taperwidth[inds]) end +function time_resample(d::DataMute{T, mode}, taxis::AbstractRange) where {T, mode} + @assert get_nsrc(d.recGeom) == 1 + new_rgeom = deepcopy(d.recGeom) + new_sgeom = deepcopy(d.srcGeom) + new_rgeom.taxis[1] = taxis + new_sgeom.taxis[1] = taxis + taperwidth = trunc.(Int64, d.taperwidth .* (d.recGeom.dt ./ taxis.step)) + return DataMute{T, mode}(d.m, new_sgeom, new_rgeom, d.vp, d.t0, taperwidth) +end + function _mutetrace!(t::AbstractVector{T}, taper::AbstractVector{T}, i::Integer, taperwidth::Integer, ::Val{:reflection}) where T t[1:i-taperwidth] .= 0f0 t[i-taperwidth+1:i] .*= taper @@ -104,14 +118,14 @@ function muteshot!(shot::AbstractMatrix{T}, rGeom::Geometry, srcGeom::Geometry; end end -function muteshot(shot::Vector{T}, srcGeom::Geometry, recGeom::Geometry; +function muteshot(shot::VecOrMat{T}, srcGeom::Geometry, recGeom::Geometry; vp=1500, t0=.1, mode=:reflection, taperwidth=floor(Int, 2/t0)) where {T<:Number} sr = reshape(shot, recGeom) for s=1:get_nsrc(recGeom) sri = view(sr, :, :, s) muteshot!(sri, recGeom[s], srcGeom[s]; vp=vp[s], t0=t0[s], mode=mode, taperwidth=taperwidth[s]) end - return vec(vcat(sr...)) + return reshape(sr, size(shot)) end function muteshot(shot::judiVector, srcGeom::Geometry; vp=1500, t0=.1, mode=:reflection, taperwidth=20) @@ -144,7 +158,7 @@ judiFilter(geometry::Geometry, fmin::T, fmax::T) where T = judiFilter(geometry, judiFilter(geometry::Geometry, fmin::Float32, fmax::Float32) = FrequencyFilter{Float32, fmin, fmax}(n_samples(geometry), geometry) judiFilter(v::judiVector, fmin, fmax) = judiFilter(v.geometry, fmin, fmax) -function matvec(D::FrequencyFilter{T, fm, FM}, x::Vector{T}) where {T, fm, FM} +function matvec(D::FrequencyFilter{T, fm, FM}, x::VecOrMat{T}) where {T, fm, FM} dr = reshape(x, D.recGeom) for j=1:get_nsrc(D.recGeom) dri = view(dr, :, :, j) @@ -167,6 +181,13 @@ function getindex(P::FrequencyFilter{T, fm, FM}, i) where {T, fm, FM} return FrequencyFilter{T, fm, FM}(n_samples(geomi), geomi) end +function time_resample(d::FrequencyFilter{T, fm, fM}, taxis::AbstractRange) where {T, fm, fM} + @assert get_nsrc(d.recGeom) == 1 + new_geom = deepcopy(d.recGeom) + new_geom.taxis[1] = taxis + return FrequencyFilter{T, fm, fM}(d.m, new_geom) +end + # filtering is self-adjoint (diagonal in fourier domain) matvec_T(D::FrequencyFilter{T, fm, FM}, x) where {T, fm, FM} = matvec(D, x) @@ -263,6 +284,14 @@ function getindex(P::TimeDifferential{T, K}, i) where {T, K} TimeDifferential{T, K}(m, geomi) end +function time_resample(d::TimeDifferential{T, K}, taxis::AbstractRange) where {T, K} + @assert get_nsrc(d.recGeom) == 1 + new_geom = deepcopy(d.recGeom) + new_geom.taxis[1] = taxis + return TimeDifferential{T, K}(d.m, new_geom) +end + + function matvec(D::TimeDifferential{T, K}, x::judiVector{T, AT}) where {T, AT, K} out = similar(x) for s=1:out.nsrc @@ -277,11 +306,12 @@ end function matvec(D::TimeDifferential{T, K}, x::Array{T}) where {T, K} xr = reshape(x, D.recGeom) + out = similar(xr) # make omega^K ω = 2 .* pi .* fftfreq(get_nt(D.recGeom, 1), 1/get_dt(D.recGeom, 1)) ω[ω.==0] .= 1f0 ω .= abs.(ω).^K - out = real.(ifft(ω .* fft(xr, 1), 1)) + out .= real.(ifft(ω .* fft(xr, 1), 1)) return reshape(out, size(x)) end @@ -323,6 +353,14 @@ matvec_T(D::TimeGain{T, K}, x) where {T, K} = matvec(D, x) # getindex for source subsampling getindex(D::TimeGain{T, K}, i) where {T, K} = TimeGain{T, K}(D.recGeom[i]) + +function time_resample(d::TimeGain, taxis::AbstractRange) + @assert get_nsrc(d.recGeom) == 1 + new_geom = deepcopy(d.recGeom) + new_geom.taxis[1] = taxis + return TimeGain(new_geom, d.pow) +end + function matvec(D::TimeGain{T, K}, x::judiVector{T, AT}) where {T, AT, K} out = similar(x) for s=1:out.nsrc diff --git a/src/TimeModeling/Preconditioners/base.jl b/src/TimeModeling/Preconditioners/base.jl index e9dd8badf..8662f6f22 100644 --- a/src/TimeModeling/Preconditioners/base.jl +++ b/src/TimeModeling/Preconditioners/base.jl @@ -27,9 +27,51 @@ getproperty(J::Preconditioner, s::Symbol) = _get_property(J, Val{s}()) # Base compat *(J::Preconditioner, ms::judiMultiSourceVector) = matvec(J, ms) *(J::Preconditioner, ms::PhysicalParameter) = matvec(J, ms) -*(J::Preconditioner, v::Vector{T}) where T = matvec(J, v) +*(J::Preconditioner, v::VecOrMat{T}) where T = matvec(J, v) + mul!(out::judiMultiSourceVector, J::Preconditioner, ms::judiMultiSourceVector) = copyto!(out, matvec(J, ms)) mul!(out::PhysicalParameter, J::Preconditioner, ms::PhysicalParameter) = copyto!(out, matvec(J, ms)) # OOC judiVector -*(J::DataPreconditioner, v::judiVector{T, SegyIO.SeisCon}) where T = LazyMul(v.nsrc, J, v) \ No newline at end of file +*(J::DataPreconditioner, v::judiVector{T, SegyIO.SeisCon}) where T = LazyMul(v.nsrc, J, v) +*(J::Preconditioner, v::LazyMul) = LazyMul(v.nsrc, J*v.P, v) + +""" + MultiPrcontinioner{TP, T} + +Type for the combination of preconditioners. It is a linear operator that applies the preconditioners in sequence. +""" + +struct MultiPreconditioner{TP, T} <: Preconditioner{T, T} + precs::Vector{TP} +end + +function matvec(J::MultiPreconditioner, x) + y = J.precs[end] * x + for P in J.precs[1:end-1] + y = P * y + end + return y +end + +function matvec_T(J::MultiPreconditioner, x) + y = J.precs[1]' * x + for P in J.precs[2:end] + y = P' * y + end + return y +end + +conj(I::MultiPreconditioner{TP, T}) where {TP, T} = MultiPreconditioner{TP, T}(conj.(I.precs)) +adjoint(I::MultiPreconditioner{TP, T}) where {TP, T} = MultiPreconditioner{TP, T}(adjoint.(reverse(I.precs))) +transpose(I::MultiPreconditioner{TP, T}) where {TP, T} = MultiPreconditioner{TP, T}(transpose.(reverse(I.precs))) + +getindex(I::MultiPreconditioner{TP, T}, i) where {TP, T} = MultiPreconditioner{TP, T}([getindex(P, i) for P in I.precs]) + +for T in [DataPreconditioner, ModelPreconditioner] + @eval *(P1::$(T){DT}, P2::$(T){DT}) where DT = MultiPreconditioner{$(T), DT}([P1, P2]) + @eval *(P::$(T){DT}, P2::MultiPreconditioner{$(T), DT}) where DT = MultiPreconditioner{$(T), DT}([P, P2.precs...]) + @eval *(P2::MultiPreconditioner{$(T), DT}, P::$(T){DT}) where DT = MultiPreconditioner{$(T), DT}([P2.precs..., P]) +end + +time_resample(x::MultiPreconditioner{TP, T}, newt) where {TP, T} = MultiPreconditioner{TP, T}([time_resample(P, newt) for P in x.precs]) diff --git a/src/TimeModeling/TimeModeling.jl b/src/TimeModeling/TimeModeling.jl index bac10d533..38686c2be 100644 --- a/src/TimeModeling/TimeModeling.jl +++ b/src/TimeModeling/TimeModeling.jl @@ -39,14 +39,6 @@ include("LinearOperators/lazy.jl") include("LinearOperators/operators.jl") include("LinearOperators/callable.jl") -############################################################################# -# PDE solvers -include("Modeling/distributed.jl") # Modeling functions utilities -include("Modeling/python_interface.jl") # forward/adjoint linear/nonlinear modeling -include("Modeling/time_modeling_serial.jl") # forward/adjoint linear/nonlinear modeling -include("Modeling/misfit_fg.jl") # FWI/LSRTM objective function value and gradient -include("Modeling/twri_objective.jl") # TWRI objective function value and gradient -include("Modeling/propagation.jl") ############################################################################# # Preconditioners @@ -55,6 +47,14 @@ include("Preconditioners/utils.jl") include("Preconditioners/DataPreconditioners.jl") include("Preconditioners/ModelPreconditioners.jl") +############################################################################# +# PDE solvers +include("Modeling/distributed.jl") # Modeling functions utilities +include("Modeling/python_interface.jl") # forward/adjoint linear/nonlinear modeling +include("Modeling/time_modeling_serial.jl") # forward/adjoint linear/nonlinear modeling +include("Modeling/misfit_fg.jl") # FWI/LSRTM objective function value and gradient +include("Modeling/twri_objective.jl") # TWRI objective function value and gradient +include("Modeling/propagation.jl") ############################################################################# # Extra that need all imports diff --git a/src/TimeModeling/Types/judiVector.jl b/src/TimeModeling/Types/judiVector.jl index 749fe849a..215451c4f 100644 --- a/src/TimeModeling/Types/judiVector.jl +++ b/src/TimeModeling/Types/judiVector.jl @@ -39,14 +39,14 @@ can also be a single (non-cell) array, in which case the data is the same for al judiVector(geometry, data) -Construct vector for observed data from `SegyIO.SeisBlock`. `segy_depth_key` is the `SegyIO` keyword \\ +Construct vector for observed data from `SeisBlock`. `segy_depth_key` is the `SegyIO` keyword \\ that contains the receiver depth coordinate: - judiVector(SegyIO.SeisBlock; segy_depth_key="RecGroupElevation") + judiVector(SeisBlock; segy_depth_key="RecGroupElevation") -Construct vector for observed data from out-of-core data container of type `SegyIO.SeisCon`: +Construct vector for observed data from out-of-core data container of type `SeisCon`: - judiVector(SegyIO.SeisCon; segy_depth_key="RecGroupElevation") + judiVector(SeisCon; segy_depth_key="RecGroupElevation") Examples ======== @@ -60,13 +60,13 @@ wavelets or a single wavelet as an array): q = judiVector(src_geometry, wavelet) -(3) Construct data vector from `SegyIO.SeisBlock` object: +(3) Construct data vector from `SeisBlock` object: using SegyIO seis_block = segy_read("test_file.segy") dobs = judiVector(seis_block; segy_depth_key="RecGroupElevation") -(4) Construct out-of-core data vector from `SegyIO.SeisCon` object (for large SEG-Y files): +(4) Construct out-of-core data vector from `SeisCon` object (for large SEG-Y files): using SegyIO seis_container = segy_scan("/path/to/data/directory","filenames",["GroupX","GroupY","RecGroupElevation","SourceDepth","dt"]) @@ -96,7 +96,7 @@ function judiVector(geometry::Geometry, data::Vector{Array{T, N}}) where {T, N} end # contructor for in-core data container and given geometry -function judiVector(geometry::Geometry, data::SegyIO.SeisBlock) +function judiVector(geometry::Geometry, data::SeisBlock) check_geom(geometry, data) # length of data vector src = get_header(data,"FieldRecord") @@ -111,22 +111,22 @@ function judiVector(geometry::Geometry, data::SegyIO.SeisBlock) end # contructor for single out-of-core data container and given geometry -function judiVector(geometry::Geometry, data::SegyIO.SeisCon) +function judiVector(geometry::Geometry, data::SeisCon) check_geom(geometry, data) # length of data vector nsrc = length(data) # fill data vector with pointers to data location - dataCell = Vector{SegyIO.SeisCon}(undef, nsrc) + dataCell = Vector{SeisCon}(undef, nsrc) for j=1:nsrc dataCell[j] = split(data,j) end - return judiVector{Float32, SegyIO.SeisCon}(nsrc, geometry,dataCell) + return judiVector{Float32, SeisCon}(nsrc, geometry,dataCell) end -judiVector(data::SegyIO.SeisBlock; kw...) = judiVector(Geometry(data; key="receiver", kw...), data) -judiVector(data::SegyIO.SeisCon; kw...)= judiVector(Geometry(data; key="receiver", kw...), data) -judiVector(data::Vector{SegyIO.SeisCon}; kw...) = judiVector(Geometry(data; key="receiver", kw...), data) -judiVector(geometry::Geometry, data::Vector{SegyIO.SeisCon}) = judiVector{Float32, SegyIO.SeisCon}(length(data), geometry, data) +judiVector(data::SeisBlock; kw...) = judiVector(Geometry(data; key="receiver", kw...), data) +judiVector(data::SeisCon; kw...)= judiVector(Geometry(data; key="receiver", kw...), data) +judiVector(data::Vector{SeisCon}; kw...) = judiVector(Geometry(data; key="receiver", kw...), data) +judiVector(geometry::Geometry, data::Vector{SeisCon}) = judiVector{Float32, SeisCon}(length(data), geometry, data) ############################################################ ## overloaded multi_source functions @@ -136,7 +136,12 @@ time_sampling(jv::judiVector) = get_dt(jv.geometry) # JOLI conversion jo_convert(::Type{T}, jv::judiVector{T, Array{T, N}}, ::Bool) where {T<:AbstractFloat, N} = jv jo_convert(::Type{T}, jv::judiVector{vT, Array{vT, N}}, B::Bool) where {T<:AbstractFloat, vT, N} = judiVector{T, Array{T, N}}(jv.nsrc, jv.geometry, jo_convert.(T, jv.data, B)) -zero(::Type{T}, v::judiVector{vT, AT}; nsrc::Integer=v.nsrc) where {T, vT, AT} = judiVector{T, AT}(nsrc, deepcopy(v.geometry[1:nsrc]), T(0) .* v.data[1:nsrc]) + +function zero(::Type{T}, v::judiVector{vT, AT}; nsrc::Integer=v.nsrc) where {T, vT, AT} + zgeom = deepcopy(v.geometry[1:nsrc]) + zdata = [zeros(T, get_nt(v.geometry, i), v.geometry.nrec[i]) for i=1:nsrc] + return judiVector{T, Matrix{T}}(nsrc, zgeom, zdata) +end function copy!(jv::judiVector, jv2::judiVector) jv.geometry = deepcopy(jv2.geometry) @@ -219,12 +224,15 @@ end ########################################################## # Overload needed base function for SegyIO objects -vec(x::SegyIO.SeisCon) = vec(x[1].data) -dot(x::SegyIO.SeisCon, y::SegyIO.SeisCon) = dot(x[1].data, y[1].data) -norm(x::SegyIO.SeisCon, p::Real=2) = norm(x[1].data, p) +vec(x::SeisCon) = vec(x[1].data) +dot(x::SeisCon, y::SeisCon) = dot(x[1].data, y[1].data) +norm(x::SeisCon, p::Real=2) = norm(x[1].data, p) abs(x::SegyIO.IBMFloat32) = abs(Float32(x)) *(::Number, ::SeisCon) = throw(judiMultiSourceException("Cannot multiply out of core SeisCon byt scalar")) +length(jv::judiVector{T, SeisCon}) where T = n_samples(jv.geometry) + + # push! function push!(a::judiVector{T, mT}, b::judiVector{T, mT}) where {T, mT} typeof(a.geometry) == typeof(b.geometry) || throw(judiMultiSourceException("Geometry type mismatch")) diff --git a/src/TimeModeling/Types/lazy_msv.jl b/src/TimeModeling/Types/lazy_msv.jl index 1fd0e2866..15e6882da 100644 --- a/src/TimeModeling/Types/lazy_msv.jl +++ b/src/TimeModeling/Types/lazy_msv.jl @@ -73,6 +73,8 @@ struct LazyMul{D} <: judiMultiSourceVector{D} end getindex(la::LazyMul{D}, i::RangeOrVec) where D = LazyMul{D}(length(i), la.P[i], la.msv[i]) +length(lm::LazyMul) = length(lm.msv) +zero(::Type{T}, lm::LazyMul; nsrc::Integer=lm.nsrc) where T = zero(T, lm.msv; nsrc=nsrc) function make_input(lm::LazyMul{D}) where D @assert lm.nsrc == 1 @@ -80,6 +82,8 @@ function make_input(lm::LazyMul{D}) where D end get_data(lm::LazyMul{D}) where D = lm.P * get_data(lm.msv) +materialize(lm::LazyMul{D}) where D = get_data(lm) +deepcopy(lm::LazyMul{D}) where D = deepcopy(lm.msv) function getproperty(lm::LazyMul{D}, s::Symbol) where D if s == :data @@ -89,4 +93,4 @@ function getproperty(lm::LazyMul{D}, s::Symbol) where D else return getfield(lm, s) end -end \ No newline at end of file +end diff --git a/src/rrules.jl b/src/rrules.jl index 758ff949e..d123509ea 100644 --- a/src/rrules.jl +++ b/src/rrules.jl @@ -49,6 +49,7 @@ broadcasted(::typeof(^), y::LazyPropagation, p::Real) = eval_prop(y).^(p) *(F::judiPropagator, q::LazyPropagation) = F*eval_prop(q) *(M::Preconditioner, q::LazyPropagation) = M*eval_prop(q) matvec(M::Preconditioner, q::LazyPropagation) = matvec(M, eval_prop(q)) +matvec(M::MultiPreconditioner, q::LazyPropagation) = matvec(M, eval_prop(q)) reshape(F::LazyPropagation, dims...) = LazyPropagation(x->reshape(x, dims...), F.F, F.q) copyto!(x::AbstractArray, F::LazyPropagation) = copyto!(x, eval_prop(F)) diff --git a/test/test_gradients.jl b/test/test_gradients.jl index d5ca8141d..3e162baf7 100644 --- a/test/test_gradients.jl +++ b/test/test_gradients.jl @@ -37,6 +37,44 @@ dm1 = 2f0*circshift(dm, 10) end +################################################################################################### +@testset "FWI preconditionners test with $(nlayer) layers and tti $(tti) and viscoacoustic $(viscoacoustic) and freesurface $(fs)" begin + Ml = judiDataMute(q.geometry, dobs.geometry; t0=.2) + Ml2 = judiTimeDerivative(dobs.geometry, 1) + + + Jm0, grad = fwi_objective(model0, q, dobs; options=opt, data_precon=Ml) + ghand = J'*Ml*(F0*q - dobs) + @test isapprox(norm(grad - ghand)/norm(grad+ghand), 0f0; rtol=0, atol=1e-2) + + Jm0, grad = fwi_objective(model0, q, dobs; options=opt, data_precon=[Ml, Ml2]) + ghand = J'*Ml*Ml2*(F0*q - dobs) + @test isapprox(norm(grad - ghand)/norm(grad+ghand), 0f0; rtol=0, atol=1e-2) + + Jm0, grad = fwi_objective(model0, q, dobs; options=opt, data_precon=Ml*Ml2) + @test isapprox(norm(grad - ghand)/norm(grad+ghand), 0f0; rtol=0, atol=1e-2) +end + + +@testset "LSRTM preconditionners test with $(nlayer) layers and tti $(tti) and viscoacoustic $(viscoacoustic) and freesurface $(fs)" begin + Mr = judiTopmute(model0; taperwidth=10) + Ml = judiDataMute(q.geometry, dobs.geometry) + Ml2 = judiTimeDerivative(dobs.geometry, 1) + Mr2 = judiIllumination(J) + + Jm0, grad = lsrtm_objective(model0, q, dobs, dm; options=opt, data_precon=Ml, model_precon=Mr) + ghand = J'*Ml*(J*Mr*dm - dobs) + @test isapprox(norm(grad - ghand)/norm(grad+ghand), 0f0; rtol=0, atol=1e-2) + + Jm0, grad = lsrtm_objective(model0, q, dobs, dm; options=opt, data_precon=[Ml, Ml2], model_precon=[Mr, Mr2]) + ghand = J'*Ml*Ml2*(J*Mr2*Mr*dm - dobs) + @test isapprox(norm(grad - ghand)/norm(grad+ghand), 0f0; rtol=0, atol=1e-2) + + Jm0, grad = lsrtm_objective(model0, q, dobs, dm; options=opt, data_precon=Ml*Ml2, model_precon=Mr*Mr2) + @test isapprox(norm(grad - ghand)/norm(grad+ghand), 0f0; rtol=0, atol=1e-2) + +end + ################################################################################################### @testset "LSRTM gradient test with $(nlayer) layers, tti $(tti), viscoacoustic $(viscoacoustic). freesurface $(fs), nlind $(nlind)" for nlind=[true, false] diff --git a/test/test_issues.jl b/test/test_issues.jl index b2d78fabb..190fed5f2 100644 --- a/test/test_issues.jl +++ b/test/test_issues.jl @@ -86,7 +86,7 @@ end F0 = judiModeling(model0, srcGeometry, recGeometry; options=opt) # fwi wrapper - g_ap = JUDI.multi_src_fg(model0, q , dobs, nothing, opt, false, false, mse)[2] + g_ap = JUDI.multi_src_fg(model0, q , dobs, nothing, opt)[2] @test g_ap.n == (model.n .- (22, 0)) @test g_ap.o[1] == model.d[1]*11 @@ -96,7 +96,7 @@ end @test norm(g1.data[end-10:end, :]) == 0 # lsrtm wrapper - g_ap = JUDI.multi_src_fg(model0, q , dobs, dm, opt, false, true, mse)[2] + g_ap = JUDI.multi_src_fg(model0, q , dobs, dm, opt, lin=true)[2] @test g_ap.n == (model.n .- (22, 0)) @test g_ap.o[1] == model.d[1]*11 diff --git a/test/test_preconditioners.jl b/test/test_preconditioners.jl index 98a113708..77befb6a9 100644 --- a/test/test_preconditioners.jl +++ b/test/test_preconditioners.jl @@ -202,7 +202,7 @@ dm = model0.m - model.m @test "u" ∉ keys(Iv.illums) @test norm(Iv.illums["v"]) == norm(ones(Float32, model.n)) # Test Product - @test inv(I)*I*model0.m ≈ model0.m.data[:] rtol=ftol atol=0 + @test inv(I)*I*model0.m ≈ model0.m rtol=ftol atol=0 # Test in place ModelPrecon for Pc in [Ds, Mm, Mm2, I]