Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix excessive synchronization in contract on Dagger.DArray #124

Merged
merged 26 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
000792c
Fix `contract` on `Dagger.DArray` tensors
mofeing Nov 27, 2023
d12f1b7
Fix parent function call in `contract`
mofeing Nov 27, 2023
6d0e36c
Format code
mofeing May 29, 2024
95f5494
Replace `OMEinsum` code for `Tenet.contract` call
mofeing May 29, 2024
aef82c1
Refactor code
mofeing May 29, 2024
40264e2
Refactor `Dagger.@spawn` call
mofeing May 29, 2024
afb42be
Throw error if binary contraction between `Dagger.DArray` and non-`DA…
mofeing May 29, 2024
8ed6eb6
Test `Dagger` extension
mofeing May 29, 2024
8c29392
Include `Dagger` tests to integration tests
mofeing May 29, 2024
d75bd3f
Format code
mofeing May 29, 2024
a5af37b
Fix statement of broken test
mofeing May 29, 2024
532cb20
Fix `Base.size` on `Contract`
mofeing May 29, 2024
65424a0
Fix `Dagger.Blocks` on `Contract`
mofeing May 29, 2024
79ff56c
Fix indexing and iteration of `CartesianIndex`
mofeing May 29, 2024
565887d
Fix broadcasting
mofeing May 29, 2024
933d331
Fix chunk iteration logic
mofeing May 29, 2024
a5e607a
Import `EagerThunk`, `DArray`
mofeing May 29, 2024
8d2ed36
Format code
mofeing May 29, 2024
a3c7fe6
Fix library loading
mofeing May 29, 2024
cfb7ed5
Fix minor issue
mofeing May 29, 2024
15a3d85
Clean workers at the end
mofeing May 29, 2024
6661477
Fix `eltype` of subdomains
mofeing May 30, 2024
abc85ca
Fix typo in call
mofeing May 30, 2024
4d7752c
Comment "block-unblock" test
mofeing May 30, 2024
0113afd
Wrap tests over a `try-finally` block to always kill workers cleanly
mofeing May 30, 2024
7231b46
Remove unused test
mofeing May 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
Expand All @@ -30,6 +31,7 @@ TenetAdaptExt = "Adapt"
TenetChainRulesCoreExt = "ChainRulesCore"
TenetChainRulesExt = "ChainRules"
TenetChainRulesTestUtilsExt = ["ChainRulesCore", "ChainRulesTestUtils"]
TenetDaggerExt = "Dagger"
TenetFiniteDifferencesExt = "FiniteDifferences"
TenetGraphMakieExt = ["GraphMakie", "Makie"]

Expand All @@ -39,6 +41,7 @@ Adapt = "4"
ChainRules = "1.0"
ChainRulesCore = "1.0"
Combinatorics = "1.0"
Dagger = "0.18"
DeltaArrays = "0.1.1"
EinExprs = "0.5, 0.6"
GraphMakie = "0.4,0.5"
Expand Down
135 changes: 135 additions & 0 deletions ext/TenetDaggerExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
module TenetDaggerExt

using Tenet
using Dagger: Dagger, ArrayOp, Context, ArrayDomain, EagerThunk, DArray

struct Contract{T,N} <: ArrayOp{T,N}
ic::Vector{Symbol}
a::ArrayOp
ia::Vector{Symbol}
b::ArrayOp
ib::Vector{Symbol}

function Contract(ic, a, ia, b, ib)
allunique(ia) || throw(ErrorException("ia must have unique indices"))
allunique(ib) || throw(ErrorException("ib must have unique indices"))
allunique(ic) || throw(ErrorException("ic must have unique indices"))
ic ⊆ ia ∪ ib || throw(ErrorException("ic must be a subset of ia ∪ ib"))
return new{promote_type(eltype(a), eltype(b)),length(ic)}(ic, a, ia, b, ib)
end
end

function Base.size(x::Contract)
return Tuple(
Iterators.map(x.ic) do i
if i ∈ x.ia
size(x.a, findfirst(==(i), x.ia))
elseif i ∈ x.ib
size(x.b, findfirst(==(i), x.ib))
else
throw(ErrorException("index $i not found in a nor b"))
end
end,
)
end

function Dagger.Blocks(x::Contract)
return Dagger.Blocks(map(x.ic) do i
j = findfirst(==(i), x.ia)
isnothing(j) || return x.a.partitioning.blocksize[j]

j = findfirst(==(i), x.ib)
isnothing(j) || return x.b.partitioning.blocksize[j]

throw(ErrorException("index :$i not found in a nor b"))
end...)
end

function selectdims(a, proj::Pair...)
return reduce(proj; init=a) do acc, (d, i)
selectdim(acc, d, i)
end
end

contractfn(ic, chunk_a, ia, chunk_b, ib) = parent(contract(Tensor(chunk_a, ia), Tensor(chunk_b, ib); out=ic))

function Dagger.stage(ctx::Context, op::Contract{T,N}) where {T,N}
domain = Dagger.ArrayDomain([1:l for l in size(op)])
partitioning = Dagger.Blocks(op)

# NOTE careful with ÷ for dividing into partitions
subdomains = Array{ArrayDomain{N,NTuple{2,UnitRange{Int}}}}(undef, map(÷, size(op), partitioning.blocksize))
for indices in eachindex(IndexCartesian(), subdomains)
subdomains[indices] = ArrayDomain(
map(Tuple(indices), partitioning.blocksize) do i, step
(i - 1) * step .+ (1:step)
end,
)
end

suminds = setdiff(op.ia ∪ op.ib, op.ic)
inner_perm_a = sortperm(map(i -> findfirst(==(i), op.ia), suminds))
inner_perm_b = sortperm(map(i -> findfirst(==(i), op.ib), suminds))

mask_a = op.ic .∈ (op.ia,)
mask_b = op.ic .∈ (op.ib,)
outer_perm_a = map(i -> findfirst(==(i), op.ia), op.ic[mask_a])
outer_perm_b = map(i -> findfirst(==(i), op.ib), op.ic[mask_b])

chunks = similar(subdomains, EagerThunk)
for indices in eachindex(IndexCartesian(), chunks)
outer_indices_a = Tuple(indices)[mask_a]
chunks_a = dropdims(
reduce(zip(outer_perm_a, outer_indices_a); init=Dagger.chunks(op.a)) do acc, (d, i)
selectdim(acc, d, i:i)
end;
dims=Tuple(outer_perm_a),
)
chunks_a = permutedims(chunks_a, inner_perm_a)

outer_indices_b = Tuple(indices)[mask_b]
chunks_b = dropdims(
reduce(zip(outer_perm_b, outer_indices_b); init=Dagger.chunks(op.b)) do acc, (d, i)
selectdim(acc, d, i:i)
end;
dims=Tuple(outer_perm_b),
)
chunks_b = permutedims(chunks_b, inner_perm_b)

chunks[indices] = Dagger.treereduce(
Dagger.AddComputeOp,
map(chunks_a, chunks_b) do chunk_a, chunk_b
# TODO add ThunkOptions: alloc_util, occupancy, ...
Dagger.@spawn contractfn(op.ic, chunk_a, op.ia, chunk_b, op.ib)
end,
)
end

return DArray(T, domain, subdomains, chunks, partitioning)
end

function Tenet.contract(
a::Tensor{Ta,Na,Aa}, b::Tensor{Tb,Nb,Ab}; dims=(∩(inds(a), inds(b))), out=nothing
) where {Ta,Tb,Na,Nb,Aa<:Dagger.DArray{Ta,Na},Ab<:Dagger.DArray{Tb,Nb}}
ia = collect(inds(a))
ib = collect(inds(b))
i = ∩(dims, ia, ib)

ic::Vector{Symbol} = if isnothing(out)
setdiff(ia ∪ ib, i isa Base.AbstractVecOrTuple ? i : (i,))::Vector{Symbol}
else
out
end

data = Dagger._to_darray(Contract(ic, parent(a), ia, parent(b), ib))
return Tensor(data, ic)
end

Tenet.contract(a::Tensor, b::Tensor{T,N,A}; kwargs...) where {T,N,A<:Dagger.DArray} = contract(b, a; kwargs...)
function Tenet.contract(a::Tensor{T,N,A}, b::Tensor; kwargs...) where {T,N,A<:Dagger.DArray}
throw(
ArgumentError("contract on a Dagger.DArray-backed Tensor with a non-DArray-backed Tensor is not yet supported")
)
end

end
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
Expand Down
45 changes: 45 additions & 0 deletions test/integration/Dagger_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using Tenet
using Dagger
using Distributed

@testset "Dagger" begin
addprocs(1)
@everywhere using Dagger, Tenet

try
@testset "Tensor" begin
data = rand(4, 4)
block_array = DArray(data, Dagger.Blocks(2, 2))
indices = (:i, :j)

tensor = Tensor(data, indices)
block_tensor = Tensor(block_array, indices)

@test inds(block_tensor) == inds(tensor)
@test Array(parent(block_tensor)) ≈ parent(tensor)
end

@testset "contract" begin
@testset "block-block" begin
data1, data2 = rand(4, 4), rand(4, 4)
block_array1 = distribute(data1, Dagger.Blocks(2, 2))
block_array2 = distribute(data2, Dagger.Blocks(2, 2))

tensor1 = Tensor(data1, [:i, :j])
tensor2 = Tensor(data2, [:j, :k])
block_tensor1 = Tensor(block_array1, [:i, :j])
block_tensor2 = Tensor(block_array2, [:j, :k])

contracted_tensor = contract(tensor1, tensor2)
contracted_block_tensor = contract(block_tensor1, block_tensor2)

@test parent(contracted_block_tensor) isa DArray
@test inds(contracted_block_tensor) == [:i, :k]
@test all(==((2, 2)) ∘ size, Dagger.domainchunks(parent(contracted_block_tensor)))
@test collect(parent(contracted_block_tensor)) ≈ parent(contracted_tensor)
end
end
finally
rmprocs(workers())
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ if VERSION >= v"1.10"
@testset "Integration tests" verbose = true begin
include("integration/ChainRules_test.jl")
# include("integration/BlockArray_test.jl")
include("integration/Dagger_test.jl")
include("integration/Makie_test.jl")
end
end
Expand Down
Loading