Skip to content

Commit

Permalink
need to wrap everything
Browse files Browse the repository at this point in the history
  • Loading branch information
epolack committed Feb 26, 2024
1 parent 2353ffc commit db0367b
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 28 deletions.
7 changes: 0 additions & 7 deletions src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,6 @@ include("workarounds/dummy_inplace_fft.jl")
include("workarounds/forwarddiff_rules.jl")
include("workarounds/gpu_arrays.jl")

function __init__()
# We need to wait to have access to stdout.
# But now local to the REPL…
default_logger = DFTKLogger(; io=Base.stdout)
global_logger(default_logger)
end

# Precompilation block with a basic workflow
@setup_workload begin
# very artificial silicon ground state example
Expand Down
14 changes: 6 additions & 8 deletions src/common/logging.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
using Logging

# Removing most of format for `@info` in default logger.
function meta_formatter(level::LogLevel, args...)
color = Logging.default_logcolor(level)
Info == level && return color, "", ""
Logging.default_metafmt(level, args...)
end
#using Preferences

# Bypasses everything to ConsoleLogger but Info which just shows message without any
# formatting.
Base.@kwdef struct DFTKLogger <: AbstractLogger
io::IO
min_level::LogLevel = Info
fallback = ConsoleLogger(io, min_level; meta_formatter)
fallback = ConsoleLogger(io, min_level)
end
function Logging.handle_message(logger::DFTKLogger, level, msg, args...; kwargs...)
level == Info && return level < logger.min_level ? nothing : println(logger.io, msg)
Logging.handle_message(logger.fallback, level, msg, args...; kwargs...)
end
Logging.min_enabled_level(logger::DFTKLogger) = logger.min_level
Logging.shouldlog(::DFTKLogger, args...) = true

# Minimum log level is read from LocalPreferences.toml; defaults to Info.
#min_level = @load_preference("min_log_level"; default="0")
default_logger() = DFTKLogger(; io=stdout)
4 changes: 3 additions & 1 deletion src/eigen/diag_lobpcg_hyper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ function lobpcg_hyper(A, X0; maxiter=100, prec=nothing,
prec === nothing && (prec = I)

@assert !largest "Only seeking the smallest eigenpairs is implemented."
result = LOBPCG(A, X0, I, prec, tol, maxiter; n_conv_check, kwargs...)
result = with_logger(default_logger()) do
LOBPCG(A, X0, I, prec, tol, maxiter; n_conv_check, kwargs...)
end

n_conv_check === nothing && (n_conv_check = size(X0, 2))
converged = maximum(result.residual_norms[1:n_conv_check]) < tol
Expand Down
8 changes: 6 additions & 2 deletions src/scf/self_consistent_field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ Overview of parameters:

converged = is_converged(info)
converged = MPI.bcast(converged, 0, MPI.COMM_WORLD) # Ensure same converged
callback(merge(info, (; converged)))
with_logger(default_logger()) do
callback(merge(info, (; converged)))
end

ρin + T(damping) .* mix_density(mixing, basis, Δρ; info...)
end
Expand All @@ -218,6 +220,8 @@ Overview of parameters:
ρ=ρout, α=damping, eigenvalues, occupation, εF, info.n_bands_converge,
n_iter, ψ, info.diagonalization, stage=:finalize, history_Δρ, history_Etot,
runtime_ns=time_ns() - start_ns, algorithm="SCF")
callback(info)
with_logger(default_logger()) do
callback(info)
end
info
end
17 changes: 11 additions & 6 deletions test/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using DFTK
using ForwardDiff
using LinearAlgebra
using Logging
silicon = TestCases.silicon

function compute_force(ε1, ε2; metal=false, tol=1e-10)
Expand All @@ -17,7 +18,7 @@
end
basis = PlaneWaveBasis(model; Ecut=5, kgrid=[2, 2, 2], kshift=[0, 0, 0])

response = ResponseOptions(; verbose=true)
response = ResponseOptions(; verbose=DFTK.default_logger().min_level Info)
is_converged = DFTK.ScfConvergenceForce(tol)
scfres = self_consistent_field(basis; is_converged, response)
compute_forces_cart(scfres)
Expand Down Expand Up @@ -61,6 +62,7 @@ end
using ForwardDiff
using LinearAlgebra
using ComponentArrays
using Logging
aluminium = TestCases.aluminium

function compute_band_energies::T) where {T}
Expand All @@ -76,9 +78,10 @@ end
basis = PlaneWaveBasis(model; Ecut=5, kgrid=[2, 2, 2], kshift=[0, 0, 0])

is_converged = DFTK.ScfConvergenceDensity(1e-10)
response = ResponseOptions(; verbose=DFTK.default_logger().min_level Info)
scfres = self_consistent_field(basis; is_converged, mixing=KerkerMixing(),
nbandsalg=FixedBands(; n_bands_converge=10),
damping=0.6, response=ResponseOptions(; verbose=true))
damping=0.6, response)

ComponentArray(
eigenvalues=stack([ev[1:10] for ev in scfres.eigenvalues]),
Expand All @@ -103,6 +106,7 @@ end
using LinearAlgebra
using ComponentArrays
using DftFunctionals
using Logging
silicon = TestCases.silicon

function compute_force(ε1::T) where {T}
Expand All @@ -115,8 +119,8 @@ end
basis = PlaneWaveBasis(model; Ecut=5, kgrid=[2, 2, 2], kshift=[0, 0, 0])

is_converged = DFTK.ScfConvergenceDensity(1e-10)
scfres = self_consistent_field(basis; is_converged,
response=ResponseOptions(; verbose=true))
response = ResponseOptions(; verbose=DFTK.default_logger().min_level Info)
scfres = self_consistent_field(basis; is_converged, response)
compute_forces_cart(scfres)
end

Expand Down Expand Up @@ -146,6 +150,7 @@ end
using DFTK
using ForwardDiff
using LinearAlgebra
using Logging

function compute_force::T) where {T}
# solve the 1D Gross-Pitaevskii equation with ElementGaussian potential
Expand All @@ -160,8 +165,8 @@ end
basis = PlaneWaveBasis(model; Ecut=500, kgrid=(1, 1, 1))
ρ = zeros(Float64, basis.fft_size..., 1)
is_converged = DFTK.ScfConvergenceDensity(1e-10)
scfres = self_consistent_field(basis; ρ, is_converged,
response=ResponseOptions(; verbose=true))
response = ResponseOptions(; verbose=DFTK.default_logger().min_level Info)
scfres = self_consistent_field(basis; ρ, is_converged, response)
compute_forces_cart(scfres)
end
derivative_ε = let ε = 1e-5
Expand Down
4 changes: 2 additions & 2 deletions test/printing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
disable_electrostatics_check=true, modelargs...)
basis = PlaneWaveBasis(model; basisargs...)

io = current_logger().min_level > Info ? devnull : stdout
io = DFTK.default_logger().min_level > Info ? devnull : stdout

@info model
show(io, "text/plain", model)
Expand All @@ -32,7 +32,7 @@
scfres = self_consistent_field(basis; nbandsalg=FixedBands(; n_bands_converge=6),
tol=1e-3)

io = current_logger().min_level > Info ? devnull : stdout
io = DFTK.default_logger().min_level > Info ? devnull : stdout

@info scfres.energies
show(io, "text/plain", scfres.energies)
Expand Down
7 changes: 5 additions & 2 deletions test/runtests_runner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,11 @@ function dftk_testfilter(ti)
end

using Logging
using DFTK

# Don't print anything below or equal to warning level.
with_logger(ConsoleLogger(stdout, LogLevel(1001))) do
# Don't print anything below warning level.
DFTK.default_logger() = DFTK.DFTKLogger(; io=stdout, min_level=Warn)
#@set_preferences!("min_log_level" => "1001"; export_prefs=false)
with_logger(DFTK.default_logger()) do
@run_package_tests filter=dftk_testfilter verbose=true
end

0 comments on commit db0367b

Please sign in to comment.