Skip to content

Commit

Permalink
rotate(): multi-thread random starts
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexey Stukalov authored and alyst committed Apr 25, 2024
1 parent 5e03f0a commit f2a1235
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions src/rotate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ Perform a rotation of the factor loading matrix `Λ` using a rotation `method`.
starting matrices.
- `reflect`: Switch signs of the columns of the rotated loading matrix such that the sum of
loadings is non-negative for all columns (default: true)
- `use_threads`: Parallelize random starts using threads (default: false)
- `verbose`: Print logging statements (default: true)
- `logperiod`: How frequently to report the optimization state (default: 100).
Expand Down Expand Up @@ -109,6 +110,7 @@ function rotate(
reflect = true,
f_atol = 1e-6,
g_atol = 1e-6,
use_threads::Bool = false,
kwargs...,
)
loglevel = verbose ? Logging.Info : Logging.Debug
Expand All @@ -130,44 +132,51 @@ function rotate(
@warn "Requested random starts but keyword argument `init` was provided. Ignoring initial starting values in `init`."
end

Q_lock = ReentrantLock()
Q_min = Inf
rotation = initialize(rotation_type(method), nothing, L; loglevel = Logging.Debug)
n_diverged = 0
n_diverged = Threads.Atomic{Int}(0)
n_at_Q_min = 0

for _ in 1:starts
start_chan = Channel{Int}(1) do ch
foreach(Base.Fix1(put!, ch), 1:starts)
end
(use_threads ? Threads.foreach : foreach)(start_chan) do _
init = random_orthogonal_matrix(size(L, 2))
random_rotation = try
_rotate(L, method; g_atol, loglevel, kwargs..., init)
catch err
if err isa ConvergenceError
@logmsg loglevel err.msg
n_diverged += 1
continue
Threads.atomic_add!(n_diverged, 1)
return nothing
else
rethrow()
end
end

Q_current = minimumQ(random_rotation)

if isapprox(Q_current, Q_min, atol = f_atol)
n_at_Q_min += 1
elseif Q_current < Q_min
@logmsg loglevel "Found new minimum at Q = $(Q_current)"
n_at_Q_min = 1
Q_min = Q_current
rotation = random_rotation
lock(Q_lock) do
if isapprox(Q_current, Q_min, atol = f_atol)
n_at_Q_min += 1
elseif Q_current < Q_min
@logmsg loglevel "Found new minimum at Q = $(Q_current)"
n_at_Q_min = 1
Q_min = Q_current
rotation = random_rotation
end
end
return nothing
end

@logmsg loglevel "Finished $(starts) rotations with random starts."

if n_diverged == starts
if n_diverged[] == starts
msg = "All $(starts) rotations did not converge. Please check the provided rotation method and/or loading matrix."
throw(ConvergenceError(msg))
elseif n_diverged > 0
@warn "There were $(n_diverged) rotations that did not converge. Please check the provided rotation method and/or loading matrix."
elseif n_diverged[] > 0
@warn "There were $(n_diverged[]) rotations that did not converge. Please check the provided rotation method and/or loading matrix."
else
@logmsg loglevel "There were 0 rotations that did not converge."
end
Expand Down

0 comments on commit f2a1235

Please sign in to comment.