Skip to content

Commit

Permalink
Support complex in QuadOverLinAtom and improve sumsquares (#678)
Browse files Browse the repository at this point in the history
Co-authored-by: Eric Hanson <[email protected]>
  • Loading branch information
odow and ericphanson authored May 20, 2024
1 parent 627525f commit 2978c32
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 5 deletions.
1 change: 0 additions & 1 deletion docs/src/manual/complex-domain_optimization.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ operate on complex variables as well. Notable exceptions include:

> - `inverse`
> - `square`
> - `quadoverlin`
> - `sqrt`
> - `geomean`
> - `huber`
Expand Down
2 changes: 0 additions & 2 deletions src/atoms/QolElemAtom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ function square(x::AbstractExpr)
return QolElemAtom(x, constant(ones(x.size)))
end

sumsquares(x::AbstractExpr) = square(norm2(x))

invpos(x::AbstractExpr) = QolElemAtom(constant(ones(x.size)), x)

function Base.Broadcast.broadcasted(::typeof(/), x::Value, y::AbstractExpr)
Expand Down
24 changes: 22 additions & 2 deletions src/atoms/QuadOverLinAtom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ mutable struct QuadOverLinAtom <: AbstractExpr
"[QuadOverLinAtom] quadoverlin arguments must be a vector and a scalar",
)
end
if iscomplex(y)
error(
"[QuadOverLinAtom] the second argument to quadoverlin must be real, not complex.",
)
end
return new((x, y), (1, 1))
end
end
Expand All @@ -29,15 +34,30 @@ curvature(::QuadOverLinAtom) = ConvexVexity()

function evaluate(q::QuadOverLinAtom)
x = evaluate(q.children[1])
return x' * x / evaluate(q.children[2])
# `real` is only necessary to fix the type; `x'*x` will always be real-valued.
return real(output(x' * x)) / evaluate(q.children[2])
end

function new_conic_form!(context::Context{T}, q::QuadOverLinAtom) where {T}
t = Variable()
x, y = q.children
f = vcat(t, (1 / T(2)) * y, x)
if iscomplex(x)
# ||x||₂² = ∑ᵢ |xᵢ|^2 = ∑ᵢ [re(xᵢ)² + im(xᵢ)²]
# = ||re(x)||₂² + ||im(x)||₂²
# = || vcat(re(x), im(y)) ||₂²
f = vcat(t, (1 / T(2)) * y, vcat(real(x), imag(x)))
else
f = vcat(t, (1 / T(2)) * y, x)
end
add_constraint!(context, Constraint{MOI.RotatedSecondOrderCone}(f))
return conic_form!(context, t)
end

quadoverlin(x::AbstractExpr, y::AbstractExpr) = QuadOverLinAtom(x, y)

function sumsquares(x::AbstractExpr)
if size(x, 2) != 1
return QuadOverLinAtom(reshape(x, length(x), 1), constant(1))
end
return QuadOverLinAtom(x, constant(1))
end
26 changes: 26 additions & 0 deletions test/test_atoms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1958,16 +1958,42 @@ function test_QuadOverLinAtom()
_test_atom(target) do context
return quadoverlin(Variable(2), Variable())
end
target = """
variables: t, x1, x2
minobjective: 1.0 * t
[t, 0.5, x1, x2] in RotatedSecondOrderCone(4)
"""
_test_atom(target) do context
return sumsquares(Variable(1, 2))
end
target = """
variables: t, y, x1, x2, x3, x4
minobjective: 1.0 * t
[t, 0.5 * y, x1, x2, x3, x4] in RotatedSecondOrderCone(6)
"""
_test_atom(target) do context
return quadoverlin(ComplexVariable(2), Variable())
end
@test_throws(
ErrorException(
"[QuadOverLinAtom] quadoverlin arguments must be a vector and a scalar",
),
quadoverlin(Variable(2), Variable(2))
)
@test_throws(
ErrorException(
"[QuadOverLinAtom] the second argument to quadoverlin must be real, not complex.",
),
quadoverlin(Variable(2), ComplexVariable())
)
x = Variable(2)
x.value = [2.0, 3.0]
atom = quadoverlin(x, constant(2.0))
@test evaluate(atom) 13 / 2
x = ComplexVariable(2)
x.value = [2.0 + im, 3.0]
atom = quadoverlin(x, constant(2.0))
@test evaluate(atom) 14 / 2
return
end

Expand Down

0 comments on commit 2978c32

Please sign in to comment.