Skip to content

Commit

Permalink
(almost) working
Browse files Browse the repository at this point in the history
  • Loading branch information
xtalax committed Aug 4, 2023
1 parent 72b7855 commit dfc7d94
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ end
# turn `f(x...)` into `term(f, x...)`
#
function call2term(expr, arrs=[])
(expr isa QuoteNode) && return expr
!(expr isa Expr) && return :($unwrap($expr))
if expr.head == :call
if expr.args[1] == :(:)
Expand Down
44 changes: 24 additions & 20 deletions src/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -787,20 +787,20 @@ function SymbolicUtils.substitute(op::Differential, dict; kwargs...)
@set! op.x = substitute(op.x, dict; kwargs...)
end


#######################################################################################################################
# Vector Calculus
#######################################################################################################################
abstract type ArrayOperator <: AbstractOperator end
abstract type ArrayOperator end

struct ArrayDifferentialOperator <: ArrayOperator
"""The variables to differentiate with respect to."""
"""The variables to differentiate with resp≈ect to."""
vars
differentials
name
ArrayDifferentialOperator(vars, differentials, name) = new(ArrayOp(vars), ArrayOp(differentials), name)
ArrayDifferentialOperator(vars::ArrayOp, diffs, name) = new(vars, ArrayOp(diffs), name)
ArrayDifferentialOperator(vars, differentials, name) = new(vars, differentials, name)
end
Nabla(vars) = ArrayDifferentialOperator(ArrayOp(value.(vars)), map(Differential, value.(vars)), "")
Nabla(vars) = ArrayDifferentialOperator(value.(vars), map(Differential, scalarize(value.(vars))), "")
Div(vars) = (x) -> Nabla(vars) x
Curl(vars) = (x) -> Nabla(vars) × x
Laplacian(vars) = Nabla(vars) Nabla(vars)
Expand All @@ -809,8 +809,7 @@ Laplacian(vars) = Nabla(vars) ⋅ Nabla(vars)

function (D::ArrayDifferentialOperator)(x::SymVec)
@assert length(D.vars) == length(x) "Vector must be same length as vars in Operator $(D.name)."
_call(d, x) = d(x)
map(_call, zip(D.differentials, x))
@arrayop (i,) unwrap((D.differentials)[i](x[i])) term=D(y) reduce=+
end
(D::ArrayDifferentialOperator)(x::Arr) = Arr(D(value(x)))

Expand All @@ -822,20 +821,20 @@ end

function LinearAlgebra.dot(D::ArrayDifferentialOperator, x::SymVec)
@assert length(D.vars) == length(x) "Vector must be same length as vars in Operator $(D.name)."
_call(d, x) = d(x)
sum(_call, zip(D.differentials, x))
sum(D(x))
end
LinearAlgebra.dot(D::ArrayDifferentialOperator, x::Arr) = Arr(D value(x))
LinearAlgebra.dot(D::ArrayDifferentialOperator, x::Arr) = D value(x)

function LinearAlgebra.dot(x::SymVec, D::ArrayDifferentialOperator)
@assert length(D.vars) == length(x) "Vector must be same length as vars in Operator $(D.name)."
(y) -> sum((X, D) -> X*D(y), zip(x, D.differentials))
(y) -> sum(@arrayop (i,) x[i]*D.differentials[i](y) term = (xD)(y))
end
LinearAlgebra.dot(x::Arr, D::ArrayDifferentialOperator) = value(x) D

function LinearAlgebra.dot(D1::ArrayDifferentialOperator, D2::ArrayDifferentialOperator)
@assert all(scalarize(isequal.(D1.vars, D2.vars))) "Operators have different variables and cannot be composed."
(x) -> sum(i -> (D1.differentials[i] D2.differentials[i])(x), eachindex(D1.vars))
lap = x -> sum((D1.differentials[i] D2.differentials[i])(x) for i in 1:length(D1.vars))
(x) -> @arrayop (i,) lap(x[i]) term=(D1D2)(x) reduce=+
end

function εijk_cond(i, j, k)
Expand All @@ -855,14 +854,18 @@ function crosscompose(a, b)
return [v1, v2, v3]
end

function crosscall(a, b)
v1 = a[2](b[3]) - a[3](b[2])
v2 = a[3](b[1]) - a[1](b[3])
v3 = a[1](b[2]) - a[2](b[1])
return [v1, v2, v3]
end
function LinearAlgebra.cross(D::ArrayDifferentialOperator, x::SymVec)
@assert length(D.vars) == length(x) == 3 "Cross product is only defined in 3 dimensions."
ε = [εijk_cond(i, j, k) for i in 1:3, j in 1:3, k in 1:3]
curl(i) = sum(j -> sum(k -> expand_derivatives(ε[i, j, k]*D.differentials[j](x[k])), 1:3), 1:3)

return map(curl, ArrayOp(1:3))
curl = crosscall(D.differentials, x)
@arrayop (i,) curl[i] term=D×y
end
LinearAlgebra.cross(D::ArrayDifferentialOperator, x::Arr) = Arr(D × x)
LinearAlgebra.cross(D::ArrayDifferentialOperator, x::Arr) = Arr(D × value(x))

function LinearAlgebra.cross(D1::ArrayDifferentialOperator, D2::ArrayDifferentialOperator)
@assert length(D1.vars) == length(D2.vars) == 3 "Cross product is only defined in 3 dimensions."
Expand All @@ -871,13 +874,14 @@ function LinearAlgebra.cross(D1::ArrayDifferentialOperator, D2::ArrayDifferentia
ArrayDifferentialOperator(D1.vars, crosscompose(D1.differentials, D2.differentials), "("*D1.name*"×"*D2.name*")")
end

SymbolicUtils.promote_symtype(::Nabla, x) = x
SymbolicUtils.promote_symtype(::ArrayDifferentialOperator, x) = x

is_derivative(x) = istree(x) ? operation(x) isa ArrayDifferentialOperator : false

Base.show(io::IO, D::ArrayDifferentialOperator) = print(io, "(D.name)(", scalarize(D.vars), ")")
Base.show(io::IO, D::ArrayDifferentialOperator) = print(io, D.name)
Base.nameof(D::ArrayDifferentialOperator) = Symbol(D.name)

function Base.:(==)(D1::ArrayDifferentialOperator, D2::ArrayDifferentialOperator)
@variables x[1:length(D1.vars)]
all(scalarize(isequal.(D1.vars, D2.vars))) && all(scalarize(isequal.(D1(x), D2(x))))
end
end

0 comments on commit dfc7d94

Please sign in to comment.