diff --git a/Project.toml b/Project.toml
index dfe68d9..2a347d9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "ParametricOperators"
 uuid = "db9e0614-c73c-4112-a40c-114e5b366d0d"
-authors = ["Thomas Grady "]
+authors = ["Thomas Grady ", "Richard Rex "]
 version = "0.1.0"
 
 [deps]
@@ -9,11 +9,13 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
 JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
 LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
 LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
 MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
 Match = "7eb4fadd-790c-5f42-8a69-bfa0b872bfbf"
+OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
diff --git a/src/ParBroadcasted.jl b/src/ParBroadcasted.jl
index 74c0997..9de0ec7 100644
--- a/src/ParBroadcasted.jl
+++ b/src/ParBroadcasted.jl
@@ -4,7 +4,7 @@ struct ParBroadcasted{D,R,L,P,F} <: ParOperator{D,R,L,P,Internal}
     op::F
     comm::MPI.Comm
     root::Int
-    ParBroadcasted(op, comm, root::Int = 0) = new{DDT(op),RDT(op),linearity(op),parametricity(op),typeof(op)}(op, comm, root)
+    ParBroadcasted(op, comm::Any=MPI.COMM_WORLD, root::Int = 0) = new{DDT(op),RDT(op),linearity(op),parametricity(op),typeof(op)}(op, comm, root)
 end
 
 bcasted(A::ParOperator{D,R,L,P,External}, comm = MPI.COMM_WORLD, root = 0) where {D,R,L,P} =
@@ -32,4 +32,23 @@ end
 (A::ParBroadcasted{D,R,L,<:Applicable,F})(x::X) where {D,R,L,F,X<:AbstractVector{D}} = A.op(x)
 (A::ParBroadcasted{D,R,L,<:Applicable,F})(x::X) where {D,R,L,F,X<:AbstractMatrix{D}} = A.op(x)
 
-*(x::X, A::ParBroadcasted{D,R,Linear,<:Applicable,F}) where {D,R,F,X<:AbstractMatrix{D}} = x*A.op
\ No newline at end of file
+*(x::X, A::ParBroadcasted{D,R,Linear,<:Applicable,F}) where {D,R,F,X<:AbstractMatrix{D}} = x*A.op
++(x::X, A::ParBroadcasted{D,R,Linear,<:Applicable,F}) where {D,R,F,X<:AbstractMatrix{D}} = x+A.op
+
+function ChainRulesCore.rrule(A::ParBroadcasted{D,R,L,Parametric,F}, params) where {D,R,L,F}
+    op_out = A(params)
+    function pullback(op)
+        device = get_device(op.op.params)
+        θ_global = MPI.Reduce(op.op.params |> cpu, MPI.SUM, A.root, A.comm)
+
+        if MPI.Comm_rank(A.comm) == A.root
+            if device == "cpu"
+                return NoTangent(), Dict(A.op=>θ_global)
+            end
+            return NoTangent(), Dict(A.op=>(θ_global |> gpu))
+        else
+            return NoTangent(), NoTangent()
+        end
+    end
+    return op_out, pullback
+end
diff --git a/src/ParCommon.jl b/src/ParCommon.jl
index 869a7b7..fb0f155 100644
--- a/src/ParCommon.jl
+++ b/src/ParCommon.jl
@@ -66,6 +66,12 @@ end
 function rotate_dims_batched(x, rot)
     n = length(size(x))
     perm = [circshift(collect(1:n-1), rot)..., n]
+
+    device = get_device(x)
+    if device != "cpu"
+        0 in size(x) && return permutedims(x |> cpu, perm) |> gpu
+    end
+
     return permutedims(x, perm)
 end
@@ -83,4 +89,15 @@ zeros_like(::AbstractArray{T}, dims...) where {T} = zeros(T, dims...)
 if CUDA.functional()
     zeros_like(::CuArray{T}, dims) where {T} = CUDA.zeros(T, dims)
     zeros_like(::CuArray{T}, dims...) where {T} = CUDA.zeros(T, dims...)
-end
\ No newline at end of file
+end
+
+"""
+Returns whether the input is on an NVIDIA GPU
+"""
+function get_device(x::AbstractArray)
+    if isa(x, CUDA.CuArray)
+        return "gpu"
+    else
+        return "cpu"
+    end
+end
diff --git a/src/ParDFT.jl b/src/ParDFT.jl
index 6ca2ee0..0045c5d 100644
--- a/src/ParDFT.jl
+++ b/src/ParDFT.jl
@@ -23,12 +23,12 @@ Range(A::ParDFT) = A.m
 
 complexity(A::ParDFT{D,R}) where {D,R} = elementwise_multiplication_cost(R)*A.n*log2(A.n)
 
-(A::ParDFT{D,R})(x::X) where {D<:Complex,R,X<:AbstractMatrix{D}} = convert(Matrix{R}, fft(x, 1) ./ sqrt(A.n))
-(A::ParDFT{D,R})(x::X) where {D<:Real,R,X<:AbstractMatrix{D}} = convert(Matrix{R}, rfft(x, 1) ./ sqrt(A.n))
+(A::ParDFT{D,R})(x::X) where {D<:Complex,R,X<:AbstractMatrix{D}} = 0 in size(x) ? x : fft(x, 1)
+(A::ParDFT{D,R})(x::X) where {D<:Real,R,X<:AbstractMatrix{D}} = rfft(x, 1)
 (A::ParDFT{D,R})(x::X) where {D,R,X<:AbstractVector{D}} = vec(A(reshape(x, length(x), 1)))
 
-(A::ParAdjoint{D,R,NonParametric,ParDFT{D,R}})(x::X) where {D<:Complex,R,X<:AbstractMatrix{R}} = ifft(x, 1).*convert(real(D), sqrt(A.op.n))
-(A::ParAdjoint{D,R,NonParametric,ParDFT{D,R}})(x::X) where {D<:Real,R,X<:AbstractMatrix{R}} = irfft(x, A.op.n, 1).*convert(D, sqrt(A.op.n))
+(A::ParAdjoint{D,R,NonParametric,ParDFT{D,R}})(x::X) where {D<:Complex,R,X<:AbstractMatrix{R}} = 0 in size(x) ? x : ifft(x, 1)
+(A::ParAdjoint{D,R,NonParametric,ParDFT{D,R}})(x::X) where {D<:Real,R,X<:AbstractMatrix{R}} = irfft(x, A.op.n, 1)
 (A::ParAdjoint{D,R,NonParametric,ParDFT{D,R}})(x::X) where {D,R,X<:AbstractVector{R}} = vec(A(reshape(x, length(x), 1)))
 
 to_Dict(A::ParDFT{D,R}) where {D,R} = Dict{String, Any}("type" => "ParDFT", "T" => string(D), "n" => A.n, "m" => A.m)
diff --git a/src/ParDiagonal.jl b/src/ParDiagonal.jl
index 8797204..53aa5e8 100644
--- a/src/ParDiagonal.jl
+++ b/src/ParDiagonal.jl
@@ -5,8 +5,11 @@ Diagonal matrix (elementwise) operator.
""" struct ParDiagonal{T} <: ParLinearOperator{T,T,Parametric,External} n::Int - ParDiagonal(T, n) = new{T}(n) - ParDiagonal(n) = new{Float64}(n) + id::Any + ParDiagonal(T::DataType, n::Int) = new{T}(n, uuid4(Random.GLOBAL_RNG)) + ParDiagonal(n::Int) = new{Float64}(n, uuid4(Random.GLOBAL_RNG)) + ParDiagonal(T::DataType, n::Int, id) = new{T}(n, id) + ParDiagonal(n::Int, id) = new{Float64}(n, id) end Domain(A::ParDiagonal) = A.n @@ -27,7 +30,21 @@ end *(x::X, A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParDiagonal{T}},V}) where {T,V,X<:AbstractVector{T}} = x.*conj(A.params[A.op.op]) *(x::X, A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParDiagonal{T}},V}) where {T,V,X<:AbstractMatrix{T}} = x.*conj(A.params[A.op.op]) -to_Dict(A::ParDiagonal{T}) where {T} = Dict{String, Any}("type" => "ParDiagonal", "T" => string(T), "n" => A.n) +function to_Dict(A::ParDiagonal{T}) where {T} + rv = Dict{String, Any}( + "type" => "ParDiagonal", + "T" => string(T), + "n" => A.n + ) + if typeof(A.id) == String + rv["id"] = A.id + elseif typeof(A.id) == UUID + rv["id"] = "UUID:$(string(A.id))" + else + throw(ParException("I don't know how to encode id of type $(typeof(A.id))")) + end + rv +end function from_Dict(::Type{ParDiagonal}, d) ts = d["T"] @@ -35,5 +52,14 @@ function from_Dict(::Type{ParDiagonal}, d) throw(ParException("unknown data type `$ts`")) end dtype = Data_TYPES[ts] - ParDiagonal(dtype, d["n"]) + mid = d["id"] + if startswith(mid, "UUID:") + mid = UUID(mid[6:end]) + end + ParDiagonal(dtype, d["n"], mid) +end + +function distribute(A::ParDiagonal{T}, comm::MPI.Comm = MPI.COMM_WORLD) where {T} + local_n = local_size(A.n, MPI.Comm_rank(comm), MPI.Comm_size(comm)) + return ParDiagonal(T, local_n) end diff --git a/src/ParIdentity.jl b/src/ParIdentity.jl index e0d3367..d93f99d 100644 --- a/src/ParIdentity.jl +++ b/src/ParIdentity.jl @@ -31,3 +31,5 @@ function from_Dict(::Type{ParIdentity}, d) dtype = Data_TYPES[ts] ParIdentity(dtype, d["n"]) end + +kron(A::ParIdentity{T}, B::ParIdentity{T}) where {T} = ParIdentity(T,B.n*A.n) diff --git a/src/ParKron.jl b/src/ParKron.jl index 7d59d24..592c86e 100644 --- a/src/ParKron.jl +++ b/src/ParKron.jl @@ -71,7 +71,7 @@ end kron(A::ParLinearOperator, B::ParLinearOperator) = ParKron(A, B) kron(A::ParKron, B::ParLinearOperator) = ParKron(A.ops..., B) kron(A::ParLinearOperator, B::ParKron) = ParKron(A, B.ops...) -kron(A::ParKron, B::ParKron) = ParKron(A.ops..., B.ops...) +⊗(A::ParKron, B::ParKron) = ParKron(A.ops..., B.ops...) ⊗(A::ParLinearOperator, B::ParLinearOperator) = kron(A, B) Domain(A::ParSeparableOperator) = prod(map(Domain, children(A))) @@ -236,15 +236,26 @@ function latex_string(A::ParKron{D,R,P,F,N}) where {D,R,P,F,N} return out end +rebuild(A::ParBroadcasted{D,R,L,Parametric,F}, cs) where {D,R,L,F<:ParKron} = rebuild(A.op, collect(map(c -> parametricity(c) == Parametric ? 
 
+rebuild(A::ParBroadcasted{D,R,L,Parametric,F}, cs) where {D,R,L,F<:ParKron} = rebuild(A.op, collect(map(c -> parametricity(c) == Parametric ? ParBroadcasted(c, A.comm, A.root) : c, children(cs[1]))))
+
 """
-Distributes Kronecker product over the given communicator
+Distributes Kronecker product over the given dimensions
 """
-function distribute(A::ParKron, dims_in, dims_out=dims_in, parent_comm=MPI.COMM_WORLD)
-
+function distribute(A::ParKron, dims_in::Vector{Int64}, dims_out::Vector{Int64}=dims_in, parent_comm=MPI.COMM_WORLD)
     comm_in = MPI.Cart_create(parent_comm, dims_in)
     comm_out = MPI.Cart_create(parent_comm, dims_out)
+    return distribute(A, comm_in, comm_out, parent_comm)
+end
+
+"""
+Distributes Kronecker product over the given communicator
+"""
+function distribute(A::ParKron, comm_in::MPI.Comm, comm_out::MPI.Comm, parent_comm=MPI.COMM_WORLD)
+    dims, _, _ = MPI.Cart_get(comm_in)
+    dims_out, _, _ = MPI.Cart_get(comm_out)
+
     N = length(dims)
 
     @assert length(A.ops) == N
@@ -271,6 +282,7 @@ function distribute(A::ParKron, dims_in, dims_out=dims_in, parent_comm=MPI.COMM_
         coords_i = MPI.Cart_coords(comm_i)
 
         # Create repartition operator
+        !isequal(dims_prev, dims_i) && (MPI.Comm_rank(parent_comm) == 0) && println("Adding Repartition")
         !isequal(dims_prev, dims_i) && pushfirst!(ops, ParRepartition(DDT(Ai), comm_prev, comm_i, tuple(size_curr...)))
 
         # Create Kronecker w/ distributed identities
@@ -284,7 +296,7 @@ function distribute(A::ParKron, dims_in, dims_out=dims_in, parent_comm=MPI.COMM_
             pushfirst!(idents_dim_upper, ParDistributed(ParIdentity(DDT(Ai), size_curr[j]), coords_i[j], dims_i[j]))
         end
 
-        pushfirst!(ops, ParKron(idents_dim_lower..., ParBroadcasted(Ai, comm_i), idents_dim_upper...))
+        pushfirst!(ops, ParKron(idents_dim_lower..., rebuild(ParBroadcasted(Ai, comm_i), [Ai]), idents_dim_upper...))
 
         size_curr[d] = Range(Ai)
         comm_prev = comm_i
diff --git a/src/ParMatrix.jl b/src/ParMatrix.jl
index 3a19faf..07823d3 100644
--- a/src/ParMatrix.jl
+++ b/src/ParMatrix.jl
@@ -19,22 +19,41 @@ Range(A::ParMatrix) = A.m
 
 complexity(A::ParMatrix{T}) where {T} = elementwise_multiplication_cost(T)*A.n*A.m
 
 function init!(A::ParMatrix{T}, d::Parameters) where {T<:Real}
-    d[A] = rand(T, A.m, A.n)/convert(T, sqrt(A.m*A.n))
+    if A.n == 1
+        d[A] = zeros(T, A.m, A.n)
+        return
+    end
+    scale = sqrt(24.0f0 / sum((A.m, A.n)))
+    d[A] = (rand(T, (A.n, A.m)) .- 0.5f0) .* scale
+    d[A] = permutedims(d[A], [2, 1])
 end
 
 function init!(A::ParMatrix{T}, d::Parameters) where {T<:Complex}
-    d[A] = rand(T, A.m, A.n)/convert(real(T), sqrt(A.m*A.n))
+    if A.n == 1
+        d[A] = zeros(T, A.m, A.n)
+        return
+    end
+    d[A] = rand(T, A.n, A.m)/convert(real(T), sqrt(A.m*A.n))
+    d[A] = permutedims(d[A], [2, 1])
 end
 
 (A::ParParameterized{T,T,Linear,ParMatrix{T},V})(x::X) where {T,V,X<:AbstractVector{T}} = A.params*x
 (A::ParParameterized{T,T,Linear,ParMatrix{T},V})(x::X) where {T,V,X<:AbstractMatrix{T}} = A.params*x
 (A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParMatrix{T}},V})(x::X) where {T,V,X<:AbstractVector{T}} = A.params[A.op.op]'*x
 (A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParMatrix{T}},V})(x::X) where {T,V,X<:AbstractMatrix{T}} = A.params[A.op.op]'*x
+
 *(x::X, A::ParParameterized{T,T,Linear,ParMatrix{T},V}) where {T,V,X<:AbstractVector{T}} = x*A.params
 *(x::X, A::ParParameterized{T,T,Linear,ParMatrix{T},V}) where {T,V,X<:AbstractMatrix{T}} = x*A.params
 *(x::X, A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParMatrix{T}},V}) where {T,V,X<:AbstractVector{T}} = x*A.params[A.op.op]'
 *(x::X, A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParMatrix{T}},V}) where {T,V,X<:AbstractMatrix{T}} = x*A.params[A.op.op]'
++(x::X, A::ParParameterized{T,T,Linear,ParMatrix{T},V}) where {T,V,X<:AbstractVector{T}} = x.+A.params
++(x::X, A::ParParameterized{T,T,Linear,ParMatrix{T},V}) where {T,V,X<:AbstractArray{T}} = x.+A.params
++(x::X, A::ParParameterized{T,T,Linear,ParMatrix{T},V}) where {T,V,X<:AbstractMatrix{T}} = x.+A.params
++(x::X, A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParMatrix{T}},V}) where {T,V,X<:AbstractVector{T}} = x+A.params[A.op.op]'
++(x::X, A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParMatrix{T}},V}) where {T,V,X<:AbstractArray{T}} = x+A.params[A.op.op]'
++(x::X, A::ParParameterized{T,T,Linear,ParAdjoint{T,T,Parametric,ParMatrix{T}},V}) where {T,V,X<:AbstractMatrix{T}} = x+A.params[A.op.op]'
+
 function to_Dict(A::ParMatrix{T}) where {T}
     rv = Dict{String, Any}(
         "type" => "ParMatrix",
diff --git a/src/ParOperator.jl b/src/ParOperator.jl
index 617b7fc..f69f599 100644
--- a/src/ParOperator.jl
+++ b/src/ParOperator.jl
@@ -132,7 +132,7 @@ const Parameters = Dict{<:ParOperator,Any}
 """
 Move objects to cpu.
 """
 cpu(x::CuArray{<:Number}) = Array(x)
-cpu(x::Vector{CuArray}) = [cpu(y) fpr y in x]
+cpu(x::Vector{CuArray}) = [cpu(y) for y in x]
 cpu(x::AbstractArray) = x
 cpu(x::Parameters) = Dict(k => cpu(v) for (k, v) in pairs(x))
@@ -141,7 +141,7 @@ if CUDA.functional()
     """
     Move objects to gpu.
    """
     gpu(x::AbstractArray{<:Number}) = CuArray(x)
-    gpu(x::Vector{<:AbstractArray}) = [gpu(y) fpr y in x]
+    gpu(x::Vector{<:AbstractArray}) = [gpu(y) for y in x]
     gpu(x::CuArray) = x
     gpu(x::Parameters) = Dict(k => gpu(v) for (k, v) in pairs(x))
 end
diff --git a/src/ParReduce.jl b/src/ParReduce.jl
new file mode 100644
index 0000000..e2d32c4
--- /dev/null
+++ b/src/ParReduce.jl
@@ -0,0 +1,47 @@
+export ParReduce
+
+"""
+Reduction operator. Reduces across the given communicator.
+"""
+struct ParReduce{T} <: ParOperator{T,T,Linear,NonParametric,External}
+    comm::MPI.Comm
+
+    ParReduce() = new{Float64}(MPI.COMM_WORLD)
+    ParReduce(T::DataType) = new{T}(MPI.COMM_WORLD)
+    ParReduce(T::DataType, comm::MPI.Comm) = new{T}(comm)
+    ParReduce(comm::MPI.Comm) = new{Float64}(comm)
+end
+
+function (A::ParReduce{T})(x::X) where {T,X<:AbstractVector{T}}
+    device = get_device(x)
+    if device == "cpu"
+        return MPI.Allreduce(x, MPI.SUM, A.comm)
+    elseif device == "gpu"
+        return MPI.Allreduce(x |> cpu, MPI.SUM, A.comm) |> gpu
+    end
+end
+
+function (A::ParReduce{T})(x::X) where {T,X<:AbstractArray{T}}
+    device = get_device(x)
+    if device == "cpu"
+        return MPI.Allreduce(x, MPI.SUM, A.comm)
+    elseif device == "gpu"
+        return MPI.Allreduce(x |> cpu, MPI.SUM, A.comm) |> gpu
+    end
+end
+
+function ChainRulesCore.rrule(A::ParReduce{T}, x::X) where {T,X<:AbstractVector{T}}
+    op_out = A(x)
+    function pullback(op)
+        return NoTangent(), op
+    end
+    return op_out, pullback
+end
+
+function ChainRulesCore.rrule(A::ParReduce{T}, x::X) where {T,X<:AbstractArray{T}}
+    op_out = A(x)
+    function pullback(op)
+        return NoTangent(), op
+    end
+    return op_out, pullback
+end
diff --git a/src/ParRepartition.jl b/src/ParRepartition.jl
index 5fe9477..f0f07b9 100644
--- a/src/ParRepartition.jl
+++ b/src/ParRepartition.jl
@@ -7,8 +7,8 @@ mutable struct ParRepartition{T,N} <: ParLinearOperator{T,T,NonParametric,Extern
     global_size::NTuple{N, Integer}
     local_size_in::NTuple{N, Integer}
     local_size_out::NTuple{N, Integer}
-    send_data::OrderedDict{Integer, Tuple{NTuple{N, UnitRange{Integer}}, Option{Vector{T}}}}
-    recv_data::OrderedDict{Integer, Tuple{NTuple{N, UnitRange{Integer}}, Option{Vector{T}}}}
+    send_data::OrderedDict{Int32, Tuple{NTuple{N, UnitRange{Int32}}, Option{Vector{T}}}}
+    recv_data::OrderedDict{Int32, Tuple{NTuple{N, UnitRange{Int32}}, Option{Vector{T}}}}
     batch_size::Option{Integer}
 
     function ParRepartition(T, comm_in, comm_out, global_size)
@@ -194,4 +194,12 @@ end
 function (R::ParRepartition{T,N})(x::X) where {T,N,X<:AbstractVector{T}}
     y = R(reshape(x, length(x), 1))
     return vec(y)
-end
\ No newline at end of file
+end
+
+function ChainRulesCore.rrule(A::ParRepartition{T,N}, x::X) where {T,N,X<:AbstractMatrix{T}}
+    op_out = A(x)
+    function pullback(op)
+        return NoTangent(), A'(op)
+    end
+    return op_out, pullback
+end
diff --git a/src/ParRestriction.jl b/src/ParRestriction.jl
index 488cb39..b814e34 100644
--- a/src/ParRestriction.jl
+++ b/src/ParRestriction.jl
@@ -53,3 +53,19 @@ function from_Dict(::Type{ParRestriction}, d)
     dtype = Data_TYPES[ts]
     ParRestriction(dtype, d["n"], ranges)
 end
+
+function ChainRulesCore.rrule(A::ParRestriction{T}, x::X) where {T,X<:AbstractMatrix{T}}
+    op_out = A(x)
+    function pullback(op)
+        return (NoTangent(), A'(op))
+    end
+    return op_out, pullback
+end
+
+function ChainRulesCore.rrule(A::ParAdjoint{T,T,NonParametric,ParRestriction{T}}, x::X) where {T,X<:AbstractMatrix{T}}
+    op_out = A(x)
+    function pullback(op)
+        return (NoTangent(), A.op(op))
+    end
+    return op_out, pullback
+end
diff --git a/src/ParTensor.jl b/src/ParTensor.jl
new file mode 100644
index 0000000..2a8cf78
--- /dev/null
+++ b/src/ParTensor.jl
@@ -0,0 +1,199 @@
+export ParTensor
+
+using OMEinsum
+using Flux:batched_mul
+
+"""
+Dense N-dimensional tensor operator.
+"""
+struct ParTensor{N,M,O,T} <: ParLinearOperator{T,T,Parametric,External}
+    weight_order::Tuple{Vararg{Int,N}}
+    weight_shape::Tuple{Vararg{Int,N}}
+
+    input_order::Tuple{Vararg{Int,M}}
+    input_shape::Tuple{Vararg{Int,M}}
+
+    target_order::Tuple{Vararg{Int,O}}
+    target_shape::Tuple{Vararg{Int,O}}
+    id::Any
+
+    ParTensor(T::DataType, wo::Tuple{Vararg{Int,N}}, ws::Tuple{Vararg{Int,N}}, io::Tuple{Vararg{Int,M}}, is::Tuple{Vararg{Int,M}}, to::Tuple{Vararg{Int,O}}, ts::Tuple{Vararg{Int,O}}, id) where {N, M, O} = new{N,M,O,T}(wo, ws, io, is, to, ts, id)
+    ParTensor(wo::Tuple{Vararg{Int,N}}, ws::Tuple{Vararg{Int,N}}, io::Tuple{Vararg{Int,M}}, is::Tuple{Vararg{Int,M}}, to::Tuple{Vararg{Int,O}}, ts::Tuple{Vararg{Int,O}}, id) where {N, M, O} = new{N,M,O,Float64}(wo, ws, io, is, to, ts, id)
+    ParTensor(T::DataType, wo::Tuple{Vararg{Int,N}}, ws::Tuple{Vararg{Int,N}}, io::Tuple{Vararg{Int,M}}, is::Tuple{Vararg{Int,M}}, to::Tuple{Vararg{Int,O}}, ts::Tuple{Vararg{Int,O}}) where {N, M, O} = new{N,M,O,T}(wo, ws, io, is, to, ts, uuid4(Random.GLOBAL_RNG))
+    ParTensor(wo::Tuple{Vararg{Int,N}}, ws::Tuple{Vararg{Int,N}}, io::Tuple{Vararg{Int,M}}, is::Tuple{Vararg{Int,M}}, to::Tuple{Vararg{Int,O}}, ts::Tuple{Vararg{Int,O}}) where {N, M, O} = new{N,M,O,Float64}(wo, ws, io, is, to, ts, uuid4(Random.GLOBAL_RNG))
+end
+
+Domain(A::ParTensor) = prod(A.input_shape)
+Range(A::ParTensor) = prod(A.target_shape)
+
+function init!(A::ParTensor{N,M,O,T}, d::Parameters) where {N,M,O,T<:Real}
+    d[A] = rand(T, A.weight_shape...) ./ convert(T, sqrt(prod(A.weight_shape)))
+end
+
+function init!(A::ParTensor{N,M,O,T}, d::Parameters) where {N,M,O,T<:Complex}
+    d[A] = rand(T, A.weight_shape...) ./ convert(real(T), sqrt(prod(A.weight_shape)))
+end
+
+# TODO: Abstract usage of OMEinsum to another controller
+function (A::ParParameterized{T,T,Linear,ParTensor{4,M,O,T},V})(x::X) where {M,O,T,V,X<:AbstractMatrix{T}}
+    # Hacky batched mul just for ML4Seismic
+    b = size(x)[2]
+    ic = A.op.weight_shape[1]
+    oc = A.op.weight_shape[2]
+    nt = A.op.weight_shape[3]
+    nxy = A.op.weight_shape[4]
+
+    # input from it(xy)b -> bi(txy)
+    x = reshape(x, (A.op.input_shape..., b))
+    x = permutedims(x, [4,1,2,3])
+    x = reshape(x, b, ic, :)
+
+    # params from iot(xy) -> io(txy)
+    params = reshape(A.params, ic, oc, :)
+
+    # output from bo(txy) -> (otxy)b
+    output = batched_mul(x, params)
+    output = reshape(output, b, oc, nt, nxy)
+    output = permutedims(output, [2,3,4,1])
+    output = reshape(output, :, b)
+
+    return output
+end
+
+function (A::ParParameterized{T,T,Linear,ParTensor{5,M,O,T},V})(x::X) where {M,O,T,V,X<:AbstractMatrix{T}}
+    # Hacky batched mul just for ML4Seismic
+    b = size(x)[2]
+    oc = A.op.weight_shape[1]
+    ic = A.op.weight_shape[2]
+    nx = A.op.weight_shape[3]
+    ny = A.op.weight_shape[4]
+    nt = A.op.weight_shape[5]
+
+    # input from ixytb -> bi(xyt)
+    input = reshape(x, (A.op.input_shape..., b))
+    input = permutedims(input, [5,1,2,3,4])
+    input = reshape(input, b, ic, :)
+
+    # params from oixyt -> io(xyt)
+    params = permutedims(A.params, [2,1,3,4,5])
+    params = reshape(params, ic, oc, :)
+
+    # output from bo(xyt) -> (oxyt)b
+    output = batched_mul(input, params)
+    output = reshape(output, b, oc, nx, ny, nt)
+    output = permutedims(output, [2,3,4,5,1])
+    output = reshape(output, :, b)
+
+    return output
+end
+
+function (A::ParParameterized{T,T,Linear,ParTensor{6,M,O,T},V})(x::X) where {M,O,T,V,X<:AbstractMatrix{T}}
+    # Hacky batched mul just for ML4Seismic
+    b = size(x)[2]
+    oc = A.op.weight_shape[1]
+    ic = A.op.weight_shape[2]
+    nx = A.op.weight_shape[3]
+    ny = A.op.weight_shape[4]
+    nz = A.op.weight_shape[5]
+    nt = A.op.weight_shape[6]
+
+    # input from ixyztb -> bi(xyzt)
+    input = reshape(x, (A.op.input_shape..., b))
+    input = permutedims(input, [6,1,2,3,4,5])
+    input = reshape(input, b, ic, :)
+
+    # params from oixyzt -> io(xyzt)
+    params = permutedims(A.params, [2,1,3,4,5,6])
+    params = reshape(params, ic, oc, :)
+
+    # output from bo(xyzt) -> (oxyzt)b
+    output = batched_mul(input, params)
+    output = reshape(output, b, oc, nx, ny, nz, nt)
+    output = permutedims(output, [2,3,4,5,6,1])
+    output = reshape(output, :, b)
+
+    return output
+end
+
+# TODO: Ideally we want the following, since it is an abstraction for any einsum; currently blocked by a bug in Julia
+# (A::ParParameterized{T,T,Linear,ParTensor{N,M,O,T},V})(x::X) where {N,M,O,T,V,X<:AbstractVector{T}} = vec(einsum(EinCode((A.op.weight_order,A.op.input_order),A.op.target_order),(A.params,reshape(x, A.op.input_shape))))
+# (A::ParParameterized{T,T,Linear,ParTensor{N,M,O,T},V})(x::X) where {N,M,O,T,V,X<:AbstractVector{T}} = vec(einsum(EinCode((A.op.weight_order,A.op.input_order),A.op.target_order),(A.params |> cpu,reshape(x, A.op.input_shape) |> cpu))) |> gpu
+
+function to_Dict(A::ParTensor{N,M,O,T}) where {N,M,O,T}
+    rv = Dict{String, Any}(
+        "type" => "ParTensor",
+        "T" => string(T),
+        "ws" => A.weight_shape,
+        "wo" => A.weight_order,
+        "is" => A.input_shape,
+        "io" => A.input_order,
+        "to" => A.target_order,
+        "ts" => A.target_shape,
+    )
+    if typeof(A.id) == String
+        rv["id"] = A.id
+    elseif typeof(A.id) == UUID
+        rv["id"] = "UUID:$(string(A.id))"
+    else
+        throw(ParException("I don't know how to encode id of type $(typeof(A.id))"))
+    end
+    rv
+end
+
+function from_Dict(::Type{ParTensor}, d)
+    ts = d["T"]
+    if !haskey(Data_TYPES, ts)
+        throw(ParException("unknown data type `$ts`"))
+    end
+    dtype = Data_TYPES[ts]
+    mid = d["id"]
+    if startswith(mid, "UUID:")
+        mid = UUID(mid[6:end])
+    end
+    ParTensor(dtype, d["ws"], d["wo"], d["is"], d["io"], d["to"], d["ts"], mid)
+end
+
+function distribute(A::ParTensor{N,M,O,T}, dims_in, comm::MPI.Comm = MPI.COMM_WORLD) where {N,M,O,T}
+
+    @assert length(dims_in) == length(A.input_shape)
+
+    comm_cart = MPI.Cart_create(comm, dims_in)
+    coords = MPI.Cart_coords(comm_cart)
+
+    # TODO: Also assert comm size and dims_in product
+
+    combined_tuples = tuple(A.input_order..., A.weight_order..., A.target_order...)
+    count_occurrences = (element_to_count) -> sum(element == element_to_count for element in combined_tuples)
+
+    new_input_shape = collect(A.input_shape)
+    new_target_shape = collect(A.target_shape)
+    new_weight_shape = collect(A.weight_shape)
+
+    for (i, dim) in enumerate(A.input_order)
+        dist_across = dims_in[dim]
+        if count_occurrences(dim) == 2
+            # Do not distribute across the dimensions on which the convolution is performed
+            @assert dist_across == 1
+        end
+
+        # TODO: For now, only supports perfect distribution
+        @assert A.input_shape[i] % dist_across == 0
+        new_input_shape[i] = A.input_shape[i] ÷ dist_across
+
+        for (j, dim_j) in enumerate(A.weight_order)
+            if dim_j == dim
+                new_weight_shape[j] = A.weight_shape[j] ÷ dist_across
+                break
+            end
+        end
+
+        for (j, dim_j) in enumerate(A.target_order)
+            if dim_j == dim
+                new_target_shape[j] = A.target_shape[j] ÷ dist_across
+                break
+            end
+        end
+    end
+
+    return ParTensor(T, A.weight_order, tuple(new_weight_shape...), A.input_order, tuple(new_input_shape...), A.target_order, tuple(new_target_shape...), "$(A.id):($(join(coords, ',')))")
+end
diff --git a/src/ParametricOperators.jl b/src/ParametricOperators.jl
index d8a0e3d..8285897 100644
--- a/src/ParametricOperators.jl
+++ b/src/ParametricOperators.jl
@@ -32,6 +32,7 @@ include("ASTOptimization.jl")
 include("ParDistributed.jl")
 include("ParBroadcasted.jl")
 include("ParRepartition.jl")
+include("ParReduce.jl")
 
 # Operator wrappers
 include("ParIdentity.jl") # Include above for use in transforms, etc.
@@ -44,6 +45,7 @@ include("ParKron.jl")
 
 # Operator definitions
 include("ParMatrix.jl")
+include("ParTensor.jl")
 include("ParDiagonal.jl")
 include("ParDFT.jl")
 include("ParRestriction.jl")
@@ -51,4 +53,4 @@ include("ParRestriction.jl")
 
 # Operator serialization
 include("ASTSerialization.jl")
-end
\ No newline at end of file
+end
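
Reviewer note (not part of the patch): a minimal usage sketch of the new ParReduce operator added in src/ParReduce.jl, assuming an initialized MPI environment; the variable names below are illustrative only.

    using MPI, ParametricOperators

    MPI.Init()

    # ParReduce(T) reduces over MPI.COMM_WORLD by default; a specific communicator
    # can be passed instead via ParReduce(T, comm).
    R = ParReduce(Float32)

    # Each rank contributes a local vector; applying R performs an MPI.Allreduce
    # with MPI.SUM, so every rank receives the elementwise sum across all ranks.
    x = fill(1.0f0, 4)
    y = R(x)    # == MPI.Comm_size(MPI.COMM_WORLD) .* x on every rank

    MPI.Finalize()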