diff --git a/src/nnmf.jl b/src/nnmf.jl index d17b953..ca85762 100644 --- a/src/nnmf.jl +++ b/src/nnmf.jl @@ -32,8 +32,8 @@ function nnmf_model(m::Int = 100, n::Int = 50, k::Int = 10, T::DataType = Float6 selected = (m * k + 1):((m + n) * k) function resid!(r, x) - W = reshape_array(x[1:(m * k)], (m, k)) - H = reshape_array(x[(m * k + 1):end], (k, n)) + W = reshape_array(view(x, 1:(m * k)), (m, k)) + H = reshape_array(view(x, (m * k + 1):((m + n) * k)), (k, n)) mul!(WH, W, H) for i ∈ eachindex(r) r[i] = A[i] - WH[i] @@ -50,8 +50,8 @@ function nnmf_model(m::Int = 100, n::Int = 50, k::Int = 10, T::DataType = Float6 resid!(r, x) minusR = reshape_array(r, (m, n)) minusR .*= -1 - W_T = reshape_array(x[1:(m * k)], (m, k))' - H_T = reshape_array(x[(m * k + 1):end], (k, n))' + W_T = reshape_array(view(x, 1:(m * k)), (m, k))' + H_T = reshape_array(view(x, (m * k + 1):((m + n) * k)), (k, n))' mul!(gw, minusR, H_T) mul!(gh, W_T, minusR) for i ∈ eachindex(gw) @@ -63,14 +63,56 @@ function nnmf_model(m::Int = 100, n::Int = 50, k::Int = 10, T::DataType = Float6 return g end + function jacv!(Jv, x, v) + W = reshape_array(view(x, 1:(m * k)), (m, k)) + H = reshape_array(view(x, (m * k + 1):((m + n) * k)), (k, n)) + W_v = reshape_array(view(v, 1:(m * k)), (m, k)) + H_v = reshape_array(view(v, (m * k + 1):((m + n) * k)), (k, n)) + mul!(WH, W_v, H) + for i ∈ eachindex(WH) + Jv[i] = -WH[i] + end + mul!(WH, W, H_v) + for i ∈ eachindex(WH) + Jv[i] -= WH[i] + end + return Jv + end + + function jactv!(Jtv, x, w) + W_T = reshape_array(view(x, 1:(m * k)), (m, k))' + H_T = reshape_array(view(x, (m * k + 1):((m + n) * k)), (k, n))' + X_v = reshape_array(w, (m, n)) + mul!(gw, X_v, H_T) + mul!(gh, W_T, X_v) + for i ∈ eachindex(gw) + Jtv[i] = -gw[i] + end + for i ∈ eachindex(gh) + Jtv[i + m * k] = -gh[i] + end + return Jtv + end + x0 = 3 * rand(eltype(A), k * (m + n)) + FirstOrderModel( obj, grad!, - 3 * rand(eltype(A), k * (m + n)), + x0, name = "NNMF", lvar = zeros(eltype(A), k * (m + n)), uvar = fill!(zeros(eltype(A), k * (m + n)), Inf), ), + FirstOrderNLSModel( + resid!, + jacv!, + jactv!, + m * n, + x0, + name = "NNMF-LS", + lvar = zeros(eltype(A), k * (m + n)), + uvar = fill!(zeros(eltype(A), k * (m + n)), Inf), + ), A[:], selected end diff --git a/test/runtests.jl b/test/runtests.jl index 9ae7bc0..bde6589 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,6 +18,9 @@ function test_objectives(model, nls_model, x = model.meta.x0) g = grad(model, x) JtF = jtprod_residual(nls_model, x, F) @test all(g .≈ JtF) + + JF = jprod_residual(nls_model, x, x) + @test JF' * F ≈ JtF' * x end @testset "BPDN" begin @@ -137,12 +140,14 @@ end end @testset "NNMF" begin - # TODO: complete tests after NLS model has been implemented m, n, k = 100, 50, 10 - model, sol, selected = nnmf_model(m, n, k) + model, nls_model, sol, selected = nnmf_model(m, n, k) @test selected == (m * k + 1):((m + n) * k) - @test typeof(model) <: FirstOrderModel - @test typeof(sol) == typeof(model.meta.x0) + test_well_defined(model, nls_model, sol) + @test nls_model.nls_meta.nequ == m * n @test all(model.meta.lvar .== 0) @test all(model.meta.uvar .== Inf) + @test all(nls_model.meta.lvar .== 0) + @test all(nls_model.meta.uvar .== Inf) + test_objectives(model, nls_model) end