diff --git a/src/rotate.jl b/src/rotate.jl index 6824c37..6137ee3 100644 --- a/src/rotate.jl +++ b/src/rotate.jl @@ -75,6 +75,7 @@ Perform a rotation of the factor loading matrix `Λ` using a rotation `method`. starting matrices. - `reflect`: Switch signs of the columns of the rotated loading matrix such that the sum of loadings is non-negative for all columns (default: true) +- `use_threads`: Parallelize random starts using threads (default: false) - `verbose`: Print logging statements (default: true) - `logperiod`: How frequently to report the optimization state (default: 100). @@ -109,6 +110,7 @@ function rotate( reflect = true, f_atol = 1e-6, g_atol = 1e-6, + use_threads::Bool = false, kwargs..., ) loglevel = verbose ? Logging.Info : Logging.Debug @@ -130,20 +132,24 @@ function rotate( @warn "Requested random starts but keyword argument `init` was provided. Ignoring initial starting values in `init`." end + Q_lock = ReentrantLock() Q_min = Inf rotation = initialize(rotation_type(method), nothing, L; loglevel = Logging.Debug) - n_diverged = 0 + n_diverged = Threads.Atomic{Int}(0) n_at_Q_min = 0 - for _ in 1:starts + start_chan = Channel{Int}(1) do ch + foreach(Base.Fix1(put!, ch), 1:starts) + end + (use_threads ? Threads.foreach : foreach)(start_chan) do _ init = random_orthogonal_matrix(size(L, 2)) random_rotation = try _rotate(L, method; g_atol, loglevel, kwargs..., init) catch err if err isa ConvergenceError @logmsg loglevel err.msg - n_diverged += 1 - continue + Threads.atomic_add!(n_diverged, 1) + return nothing else rethrow() end @@ -151,23 +157,26 @@ function rotate( Q_current = minimumQ(random_rotation) - if isapprox(Q_current, Q_min, atol = f_atol) - n_at_Q_min += 1 - elseif Q_current < Q_min - @logmsg loglevel "Found new minimum at Q = $(Q_current)" - n_at_Q_min = 1 - Q_min = Q_current - rotation = random_rotation + lock(Q_lock) do + if isapprox(Q_current, Q_min, atol = f_atol) + n_at_Q_min += 1 + elseif Q_current < Q_min + @logmsg loglevel "Found new minimum at Q = $(Q_current)" + n_at_Q_min = 1 + Q_min = Q_current + rotation = random_rotation + end end + return nothing end @logmsg loglevel "Finished $(starts) rotations with random starts." - if n_diverged == starts + if n_diverged[] == starts msg = "All $(starts) rotations did not converge. Please check the provided rotation method and/or loading matrix." throw(ConvergenceError(msg)) - elseif n_diverged > 0 - @warn "There were $(n_diverged) rotations that did not converge. Please check the provided rotation method and/or loading matrix." + elseif n_diverged[] > 0 + @warn "There were $(n_diverged[]) rotations that did not converge. Please check the provided rotation method and/or loading matrix." else @logmsg loglevel "There were 0 rotations that did not converge." end