Skip to content

Commit bda70af

Browse files
Add possibility to pass vector of Ns to convergence_test (#113)
* add possibility to pass vector of Ns to convergence_test * format * fix test * fix eocs * sort Ns before * format * add possibility to pass io object * put io in docstring * format * remove file again * Update src/util.jl Co-authored-by: Hendrik Ranocha <[email protected]> --------- Co-authored-by: Hendrik Ranocha <[email protected]>
1 parent 3adee66 commit bda70af

File tree

3 files changed

+57
-39
lines changed

3 files changed

+57
-39
lines changed

src/callbacks_step/analysis.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function AnalysisCallback(mesh, equations::AbstractEquations, solver;
7373
extra_analysis_integrals = (),
7474
analysis_integrals = union(default_analysis_integrals(equations),
7575
extra_analysis_integrals),
76-
io = stdout)
76+
io::IO = stdout)
7777
# Decide when the callback is activated.
7878
# With error-based step size control, some steps can be rejected. Thus,
7979
# `integrator.iter >= integrator.stats.naccept`

src/util.jl

+49-38
Original file line numberDiff line numberDiff line change
@@ -48,31 +48,46 @@ function default_example()
4848
"bbm_bbm_variable_bathymetry_1d_basic.jl")
4949
end
5050

51+
function convergence_test(example::AbstractString, iterations_or_Ns; kwargs...)
52+
convergence_test(Main, example::AbstractString, iterations_or_Ns; kwargs...)
53+
end
54+
5155
"""
52-
convergence_test([mod::Module=Main,] example::AbstractString, iterations; kwargs...)
56+
convergence_test([mod::Module=Main,] example::AbstractString, iterations; io::IO = stdout, kwargs...)
57+
convergence_test([mod::Module=Main,] example::AbstractString, Ns::AbstractVector; io::IO = stdout, kwargs...)
5358
54-
Run `iterations` simulations using the setup given in `example` and compute
59+
Run multiple simulations using the setup given in `example` and compute
5560
the experimental order of convergence (EOC) in the ``L^2`` and ``L^\\infty`` norm.
56-
In each iteration, the resolution of the respective mesh will be doubled.
61+
If `iterations` is passed as integer, in each iteration, the resolution of the respective mesh
62+
will be doubled. If `Ns` is passed as vector, the simulations will be run for each value of `Ns`.
5763
Additional keyword arguments `kwargs...` and the optional module `mod` are passed directly
5864
to [`trixi_include`](@ref).
5965
6066
Adjusted from [Trixi.jl](https://github.com/trixi-framework/Trixi.jl).
6167
"""
62-
function convergence_test(mod::Module, example::AbstractString, iterations; kwargs...)
68+
function convergence_test(mod::Module, example::AbstractString, iterations; io::IO = stdout,
69+
kwargs...)
6370
@assert(iterations>1,
6471
"Number of iterations must be bigger than 1 for a convergence analysis")
6572

73+
initial_N = extract_initial_N(example, kwargs)
74+
Ns = initial_N * 2 .^ (0:(iterations - 1))
75+
convergence_test(mod, example, Ns; io = io, kwargs...)
76+
end
77+
78+
function convergence_test(mod::Module, example::AbstractString, Ns::AbstractVector;
79+
io::IO = stdout, kwargs...)
6680
# Types of errors to be calculated
6781
errors = Dict(:l2 => Float64[], :linf => Float64[])
6882

69-
initial_N = extract_initial_N(example, kwargs)
70-
83+
Base.require_one_based_indexing(Ns)
84+
sort!(Ns)
85+
iterations = length(Ns)
7186
# run simulations and extract errors
7287
for iter in 1:iterations
7388
println("Running convtest iteration ", iter, "/", iterations)
7489

75-
trixi_include(mod, example; kwargs..., N = initial_N * 2^(iter - 1))
90+
trixi_include(mod, example; kwargs..., N = Ns[iter])
7691

7792
l2_error, linf_error = mod.analysis_callback(mod.sol)
7893

@@ -85,20 +100,20 @@ function convergence_test(mod::Module, example::AbstractString, iterations; kwar
85100
end
86101

87102
# Use raw error values to compute EOC
88-
analyze_convergence(errors, iterations, mod.semi, initial_N)
103+
analyze_convergence(io, errors, iterations, mod.semi, Ns)
89104
end
90105

91106
# Analyze convergence for any semidiscretization
92107
# Note: this intermediate method is to allow dispatching on the semidiscretization
93-
function analyze_convergence(errors, iterations, semi::Semidiscretization, initial_N)
108+
function analyze_convergence(io, errors, iterations, semi::Semidiscretization, Ns)
94109
_, equations, _, _ = mesh_equations_solver_cache(semi)
95110
variablenames = varnames(prim2prim, equations)
96-
analyze_convergence(errors, iterations, variablenames, initial_N)
111+
analyze_convergence(io, errors, iterations, variablenames, Ns)
97112
end
98113

99114
# This method is called with the collected error values to actually compute and print the EOC
100-
function analyze_convergence(errors, iterations,
101-
variablenames::Union{Tuple, AbstractArray}, initial_N)
115+
function analyze_convergence(io, errors, iterations,
116+
variablenames::Union{Tuple, AbstractArray}, Ns)
102117
nvariables = length(variablenames)
103118

104119
# Reshape errors to get a matrix where the i-th row represents the i-th iteration
@@ -107,66 +122,62 @@ function analyze_convergence(errors, iterations,
107122
for (kind, error) in errors)
108123

109124
# Calculate EOCs where the columns represent the variables
110-
# As dx halves in every iteration the denominator needs to be log(1/2)
111-
eocs = Dict(kind => log.(error[2:end, :] ./ error[1:(end - 1), :]) ./ log(1 / 2)
125+
eocs = Dict(kind => log.(error[2:end, :] ./ error[1:(end - 1), :]) ./
126+
log.(Ns[1:(end - 1)] ./ Ns[2:end])
112127
for (kind, error) in errorsmatrix)
113128

114129
eoc_mean_values = Dict{Symbol, Any}()
115130
eoc_mean_values[:variables] = variablenames
116131

117132
for (kind, error) in errorsmatrix
118-
println(kind)
133+
println(io, kind)
119134

120135
for v in variablenames
121-
@printf("%-25s", v)
136+
@printf(io, "%-25s", v)
122137
end
123-
println("")
138+
println(io, "")
124139

125140
for k in 1:nvariables
126-
@printf("%-5s", "N")
127-
@printf("%-10s", "error")
128-
@printf("%-10s", "EOC")
141+
@printf(io, "%-5s", "N")
142+
@printf(io, "%-10s", "error")
143+
@printf(io, "%-10s", "EOC")
129144
end
130-
println("")
145+
println(io, "")
131146

132147
# Print errors for the first iteration
133148
for k in 1:nvariables
134-
@printf("%-5d", initial_N)
135-
@printf("%-10.2e", error[1, k])
136-
@printf("%-10s", "-")
149+
@printf(io, "%-5d", Ns[1])
150+
@printf(io, "%-10.2e", error[1, k])
151+
@printf(io, "%-10s", "-")
137152
end
138-
println("")
153+
println(io, "")
139154

140155
# For the following iterations print errors and EOCs
141156
for j in 2:iterations
142157
for k in 1:nvariables
143-
@printf("%-5d", initial_N*2^(j - 1))
144-
@printf("%-10.2e", error[j, k])
145-
@printf("%-10.2f", eocs[kind][j - 1, k])
158+
@printf(io, "%-5d", Ns[j])
159+
@printf(io, "%-10.2e", error[j, k])
160+
@printf(io, "%-10.2f", eocs[kind][j - 1, k])
146161
end
147-
println("")
162+
println(io, "")
148163
end
149-
println("")
164+
println(io, "")
150165

151166
# Print mean EOCs
152167
mean_values = zeros(nvariables)
153168
for v in 1:nvariables
154169
mean_values[v] = sum(eocs[kind][:, v]) ./ length(eocs[kind][:, v])
155-
@printf("%-15s", "mean")
156-
@printf("%-10.2f", mean_values[v])
170+
@printf(io, "%-15s", "mean")
171+
@printf(io, "%-10.2f", mean_values[v])
157172
end
158173
eoc_mean_values[kind] = mean_values
159-
println("")
160-
println("-"^100)
174+
println(io, "")
175+
println(io, "-"^100)
161176
end
162177

163178
return eoc_mean_values, errorsmatrix
164179
end
165180

166-
function convergence_test(example::AbstractString, iterations; kwargs...)
167-
convergence_test(Main, example::AbstractString, iterations; kwargs...)
168-
end
169-
170181
function extract_initial_N(example, kwargs)
171182
code = read(example, String)
172183
expr = Meta.parse("begin \n$code \nend")

test/test_unit.jl

+7
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,13 @@ using SparseArrays: sparse, SparseMatrixCSC
235235
@test isapprox(eoc_mean_values[:linf][2], accuracy_order, atol = 0.5)
236236
@test isapprox(eoc_mean_values[:l2][1], accuracy_order, atol = 0.5)
237237
@test isapprox(eoc_mean_values[:linf][2], accuracy_order, atol = 0.5)
238+
239+
eoc_mean_values2, _ = convergence_test(default_example(), [512, 1024],
240+
tspan = (0.0, 1.0),
241+
accuracy_order = accuracy_order)
242+
for kind in (:l2, :linf), variable in (1, 2)
243+
eoc_mean_values[kind][variable] == eoc_mean_values2[kind][variable]
244+
end
238245
end
239246
end
240247
end

0 commit comments

Comments
 (0)