From ef8c837e14a5a1e3856be84ee1b32f5e202b9883 Mon Sep 17 00:00:00 2001
From: lkdvos
Date: Wed, 17 Apr 2024 10:25:15 +0000
Subject: [PATCH] Clean-up cuTENSOR implementation

---
 ext/TensorOperationscuTENSORExt.jl | 494 +++++++++++++++--------------
 1 file changed, 255 insertions(+), 239 deletions(-)

diff --git a/ext/TensorOperationscuTENSORExt.jl b/ext/TensorOperationscuTENSORExt.jl
index 2fd06b21..e2dee466 100644
--- a/ext/TensorOperationscuTENSORExt.jl
+++ b/ext/TensorOperationscuTENSORExt.jl
@@ -1,25 +1,37 @@
 module TensorOperationscuTENSORExt
 
-#=
-in general, the cuTENSOR operations work as follows:
-1. create a plan for the operation
-    - make tensor descriptors for the input and output arrays
-    - describe the operation to be performed:
-        labels for permutations
-        scalar factors
-        unary operations (e.g. conjugation)
-        binary reduction operations (e.g. addition)
-        scalar compute type
-2. execute the plan on given tensors
-    - forward pointers to the input and output arrays
-=#
-
 using TensorOperations
 using TensorOperations: TensorOperations as TO
 using cuTENSOR
-using cuTENSOR: CUDA
+using cuTENSOR: OP_IDENTITY, OP_CONJ, OP_ADD
+using cuTENSOR: is_unary, is_binary
+using cuTENSOR: handle, stream
+using cuTENSOR: cutensorWorksizePreference_t, cutensorAlgo_t, cutensorOperationDescriptor_t,
+                cutensorOperator_t, cutensorJitMode_t, cutensorPlanPreference_t,
+                cutensorComputeDescriptorEnum
+using cuTENSOR: WORKSPACE_DEFAULT, ALGO_DEFAULT, JIT_MODE_NONE
+using cuTENSOR: cutensorCreatePlanPreference, cutensorPlan, CuTensorPlan,
+                CuTensorDescriptor, ModeType
+
+# elementwise binary
+using cuTENSOR: elementwise_binary_compute_types, cutensorCreateElementwiseBinary,
+                cutensorElementwiseBinaryExecute
+import cuTENSOR: plan_elementwise_binary, elementwise_binary_execute!
+
+# permute
+using cuTENSOR: permutation_compute_types, cutensorCreatePermutation, cutensorPermute
+import cuTENSOR: plan_permutation, permute!
+
+# contract
+using cuTENSOR: contraction_compute_types, cutensorCreateContraction, cutensorContract
+import cuTENSOR: plan_contraction, contract!
+
+# reduce
+using cuTENSOR: reduction_compute_types, cutensorCreateReduction, cutensorReduce
+import cuTENSOR: plan_reduction, reduce!
+using cuTENSOR: CUDA
 using CUDA: CuArray, AnyCuArray
 # this might be dependency-piracy, but removes a dependency from the main package
 using CUDA.Adapt: adapt
@@ -46,13 +58,13 @@ function TO.tensorscalar(C::CuStridedView)
 end
 
 function tensorop(A::AnyCuArray, conjA::Symbol=:N)
-    return (eltype(A) <: Real || conjA === :N) ? cuTENSOR.OP_IDENTITY : cuTENSOR.OP_CONJ
+    return (eltype(A) <: Real || conjA === :N) ? OP_IDENTITY : OP_CONJ
 end
 function tensorop(A::CuStridedView, conjA::Symbol=:N)
     return if (eltype(A) <: Real || !xor(conjA === :C, A.op === conj))
-        cuTENSOR.OP_IDENTITY
+        OP_IDENTITY
     else
-        cuTENSOR.OP_CONJ
+        OP_CONJ
     end
 end
 
@@ -79,17 +91,14 @@ for ArrayType in SUPPORTED_CUARRAYS
                                α::Number, β::Number)
         return tensortrace!(C, pC, A, pA, conjA, α, β, cuTENSORBackend())
     end
-
     @eval function TO.tensoradd_type(TC, pC::Index2Tuple, ::$ArrayType, conjA::Symbol)
         return CUDA.CuArray{TC,TO.numind(pC)}
     end
-
    @eval function TO.tensorcontract_type(TC, pC::Index2Tuple,
                                           ::$ArrayType, pA::Index2Tuple, conjA::Symbol,
                                           ::$ArrayType, pB::Index2Tuple, conjB::Symbol)
         return CUDA.CuArray{TC,TO.numind(pC)}
     end
-
     @eval TO.tensorfree!(C::$ArrayType) = TO.tensorfree!(C::$ArrayType, cuTENSORBackend())
 end
 
@@ -141,7 +150,7 @@ function TO.tensoralloc_contract(TC, pC,
     return tensoralloc(ttype, structure, istemp)::ttype
 end
 
-function TO.tensorfree!(C::CuStridedView, backend::cuTENSORBackend)
+function TO.tensorfree!(C::CuStridedView, ::cuTENSORBackend)
     CUDA.unsafe_free!(parent(C))
     return nothing
 end
@@ -184,14 +193,14 @@ function TO.tensoradd!(C::CuStridedView, pC::Index2Tuple,
 
     # dispatch to cuTENSOR
     return if iszero(β)
-        cuTENSOR.permute!(α, A, Ainds, opA, C, Cinds)
+        permute!(α, A, Ainds, opA, C, Cinds)
     else
-        cuTENSOR.elementwise_binary_execute!(α,
-                                             A, Ainds, opA,
-                                             β,
-                                             C, Cinds, cuTENSOR.OP_IDENTITY,
-                                             C, Cinds,
-                                             cuTENSOR.OP_ADD)
+        elementwise_binary_execute!(α,
+                                    A, Ainds, opA,
+                                    β,
+                                    C, Cinds, OP_IDENTITY,
+                                    C, Cinds,
+                                    OP_ADD)
     end
 end
 
@@ -205,12 +214,12 @@ function TO.tensorcontract!(C::CuStridedView, pC::Index2Tuple,
     opB = tensorop(B, conjB)
 
     # dispatch to cuTENSOR
-    return cuTENSOR.contract!(α,
-                              A, Ainds, opA,
-                              B, Binds, opB,
-                              β,
-                              C, Cinds, cuTENSOR.OP_IDENTITY,
-                              cuTENSOR.OP_IDENTITY)
+    return contract!(α,
+                     A, Ainds, opA,
+                     B, Binds, opB,
+                     β,
+                     C, Cinds, OP_IDENTITY,
+                     OP_IDENTITY)
 end
 
 function TO.tensortrace!(C::CuStridedView, pC::Index2Tuple,
@@ -221,60 +230,8 @@ function TO.tensortrace!(C::CuStridedView, pC::Index2Tuple,
     opA = tensorop(A, conjA)
 
     # map to reduction operation
-    plan = plan_trace(A, Ainds, opA, C, Cinds, cuTENSOR.OP_IDENTITY, cuTENSOR.OP_ADD)
-    return cuTENSOR.reduce!(plan, α, A, β, C)
-end
-
-function plan_trace(@nospecialize(A::AbstractArray), Ainds::cuTENSOR.ModeType,
-                    opA::cuTENSOR.cutensorOperator_t,
-                    @nospecialize(C::AbstractArray), Cinds::cuTENSOR.ModeType,
-                    opC::cuTENSOR.cutensorOperator_t,
-                    opReduce::cuTENSOR.cutensorOperator_t;
-                    jit::cuTENSOR.cutensorJitMode_t=cuTENSOR.JIT_MODE_NONE,
-                    workspace::cuTENSOR.cutensorWorksizePreference_t=cuTENSOR.WORKSPACE_DEFAULT,
-                    algo::cuTENSOR.cutensorAlgo_t=cuTENSOR.ALGO_DEFAULT,
-                    compute_type::Union{DataType,cuTENSOR.cutensorComputeDescriptorEnum,
-                                        Nothing}=nothing)
-    !cuTENSOR.is_unary(opA) && throw(ArgumentError("opA must be a unary op!"))
-    !cuTENSOR.is_unary(opC) && throw(ArgumentError("opC must be a unary op!"))
-    !cuTENSOR.is_binary(opReduce) && throw(ArgumentError("opReduce must be a binary op!"))
-
-    # TODO: check if this can be avoided, available in caller
-    # TODO: cuTENSOR will allocate sizes and strides anyways, could use that here
-    _, cindA1, cindA2 = TO.trace_indices(tuple(Ainds...), tuple(Cinds...))
-
-    # add strides of cindA2 to strides of cindA1 -> selects diagonal
-    stA = strides(A)
-    for (i, j) in zip(cindA1, cindA2)
-        stA = Base.setindex(stA, stA[i] + stA[j], i)
-    end
-    szA = TT.deleteat(size(A), cindA2)
-    stA′ = TT.deleteat(stA, cindA2)
-
-    descA = CuTensorDescriptor(A; size=szA, strides=stA′)
-    descC = CuTensorDescriptor(C)
-
-    modeA = collect(Cint, deleteat!(Ainds, cindA2))
-    modeC = collect(Cint, Cinds)
-
-    actual_compute_type = if compute_type === nothing
-        cuTENSOR.reduction_compute_types[(eltype(A), eltype(C))]
-    else
-        compute_type
-    end
-
-    desc = Ref{cuTENSOR.cutensorOperationDescriptor_t}()
-    cuTENSOR.cutensorCreateReduction(cuTENSOR.handle(),
-                                     desc,
-                                     descA, modeA, opA,
-                                     descC, modeC, opC,
-                                     descC, modeC, opReduce,
-                                     actual_compute_type)
-
-    plan_pref = Ref{cuTENSOR.cutensorPlanPreference_t}()
-    cuTENSOR.cutensorCreatePlanPreference(cuTENSOR.handle(), plan_pref, algo, jit)
-
-    return cuTENSOR.CuTensorPlan(desc[], plan_pref[]; workspacePref=workspace)
+    plan = plan_trace(A, Ainds, opA, C, Cinds, OP_IDENTITY, OP_ADD)
+    return reduce!(plan, α, A, β, C)
 end
 
 #-------------------------------------------------------------------------------------------
@@ -285,43 +242,35 @@ end
 # StridedViews should always work. The following is a lot of code duplication from
 # cuTENSOR.jl, but for now this will have to do.
 
-using cuTENSOR: cutensorWorksizePreference_t, cutensorAlgo_t, cutensorComputeDescriptorEnum,
-                CuTensorPlan, ModeType, cutensorOperator_t, cutensorJitMode_t,
-                WORKSPACE_DEFAULT, ALGO_DEFAULT, JIT_MODE_NONE, CuTensorDescriptor,
-                is_unary, is_binary, cutensorOperationDescriptor_t,
-                cutensorCreateContraction,
-                cutensorCreatePermutation, cutensorReduce, cutensorPlanPreference_t,
-                plan_contraction, cutensorCreatePlanPreference, cutensorPermute,
-                cutensorElementwiseBinaryExecute, cutensorContract,
-                cutensorCreateElementwiseBinary, cutensorElementwiseBinaryExecute
-
-function cuTENSOR.elementwise_binary_execute!(@nospecialize(alpha::Number),
-                                              @nospecialize(A::CuStridedView),
-                                              Ainds::ModeType,
-                                              opA::cutensorOperator_t,
-                                              @nospecialize(gamma::Number),
-                                              @nospecialize(C::CuStridedView),
-                                              Cinds::ModeType,
-                                              opC::cutensorOperator_t,
-                                              @nospecialize(D::CuStridedView),
-                                              Dinds::ModeType,
-                                              opAC::cutensorOperator_t;
-                                              workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
-                                              algo::cutensorAlgo_t=ALGO_DEFAULT,
-                                              compute_type::Union{DataType,
-                                                                  cutensorComputeDescriptorEnum,
-                                                                  Nothing}=nothing,
-                                              plan::Union{CuTensorPlan,Nothing}=nothing)
+# elementwise_binary_execute
+# --------------------------
+function elementwise_binary_execute!(@nospecialize(alpha::Number),
+                                     @nospecialize(A::CuStridedView),
+                                     Ainds::ModeType,
+                                     opA::cutensorOperator_t,
+                                     @nospecialize(gamma::Number),
+                                     @nospecialize(C::CuStridedView),
+                                     Cinds::ModeType,
+                                     opC::cutensorOperator_t,
+                                     @nospecialize(D::CuStridedView),
+                                     Dinds::ModeType,
+                                     opAC::cutensorOperator_t;
+                                     workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
+                                     algo::cutensorAlgo_t=ALGO_DEFAULT,
+                                     compute_type::Union{DataType,
+                                                         cutensorComputeDescriptorEnum,
+                                                         Nothing}=nothing,
+                                     plan::Union{CuTensorPlan,Nothing}=nothing)
     actual_plan = if plan === nothing
-        cuTENSOR.plan_elementwise_binary(A, Ainds, opA,
-                                         C, Cinds, opC,
-                                         D, Dinds, opAC;
-                                         workspace, algo, compute_type)
+        plan_elementwise_binary(A, Ainds, opA,
+                                C, Cinds, opC,
+                                D, Dinds, opAC;
+                                workspace, algo, compute_type)
     else
         plan
     end
 
-    cuTENSOR.elementwise_binary_execute!(actual_plan, alpha, A, gamma, C, D)
+    elementwise_binary_execute!(actual_plan, alpha, A, gamma, C, D)
 
     if plan === nothing
         CUDA.unsafe_free!(actual_plan)
@@ -330,32 +279,32 @@ function cuTENSOR.elementwise_binary_execute!(@nospecialize(alpha::Number),
     return D
 end
 
-function cuTENSOR.elementwise_binary_execute!(plan::CuTensorPlan,
-                                              @nospecialize(alpha::Number),
-                                              @nospecialize(A::CuStridedView),
-                                              @nospecialize(gamma::Number),
-                                              @nospecialize(C::CuStridedView),
-                                              @nospecialize(D::CuStridedView))
+function elementwise_binary_execute!(plan::CuTensorPlan,
+                                     @nospecialize(alpha::Number),
+                                     @nospecialize(A::CuStridedView),
+                                     @nospecialize(gamma::Number),
+                                     @nospecialize(C::CuStridedView),
+                                     @nospecialize(D::CuStridedView))
     scalar_type = plan.scalar_type
-    cutensorElementwiseBinaryExecute(cuTENSOR.handle(), plan,
+    cutensorElementwiseBinaryExecute(handle(), plan,
                                      Ref{scalar_type}(alpha), A,
                                      Ref{scalar_type}(gamma), C, D,
-                                     cuTENSOR.stream())
+                                     stream())
     return D
 end
 
-function cuTENSOR.plan_elementwise_binary(@nospecialize(A::CuStridedView), Ainds::ModeType,
-                                          opA::cutensorOperator_t,
-                                          @nospecialize(C::CuStridedView), Cinds::ModeType,
-                                          opC::cutensorOperator_t,
-                                          @nospecialize(D::CuStridedView), Dinds::ModeType,
-                                          opAC::cutensorOperator_t;
-                                          jit::cutensorJitMode_t=JIT_MODE_NONE,
-                                          workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
-                                          algo::cutensorAlgo_t=ALGO_DEFAULT,
-                                          compute_type::Union{DataType,
-                                                              cutensorComputeDescriptorEnum,
-                                                              Nothing}=eltype(C))
+function plan_elementwise_binary(@nospecialize(A::CuStridedView), Ainds::ModeType,
+                                 opA::cutensorOperator_t,
+                                 @nospecialize(C::CuStridedView), Cinds::ModeType,
+                                 opC::cutensorOperator_t,
+                                 @nospecialize(D::CuStridedView), Dinds::ModeType,
+                                 opAC::cutensorOperator_t;
+                                 jit::cutensorJitMode_t=JIT_MODE_NONE,
+                                 workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
+                                 algo::cutensorAlgo_t=ALGO_DEFAULT,
+                                 compute_type::Union{DataType,
+                                                     cutensorComputeDescriptorEnum,
+                                                     Nothing}=eltype(C))
     !is_unary(opA) && throw(ArgumentError("opA must be a unary op!"))
     !is_unary(opC) && throw(ArgumentError("opC must be a unary op!"))
     !is_binary(opAC) && throw(ArgumentError("opAC must be a binary op!"))
@@ -368,13 +317,13 @@ function cuTENSOR.plan_elementwise_binary(@nospecialize(A::CuStridedView), Ainds
     modeD = modeC
 
     actual_compute_type = if compute_type === nothing
-        cuTENSOR.elementwise_binary_compute_types[(eltype(A), eltype(C))]
+        elementwise_binary_compute_types[(eltype(A), eltype(C))]
     else
         compute_type
     end
 
     desc = Ref{cutensorOperationDescriptor_t}()
-    cutensorCreateElementwiseBinary(cuTENSOR.handle(),
+    cutensorCreateElementwiseBinary(handle(),
                                     desc,
                                     descA, modeA, opA,
                                     descC, modeC, opC,
@@ -383,20 +332,22 @@ function cuTENSOR.plan_elementwise_binary(@nospecialize(A::CuStridedView), Ainds
                                     actual_compute_type)
 
     plan_pref = Ref{cutensorPlanPreference_t}()
-    cutensorCreatePlanPreference(cuTENSOR.handle(), plan_pref, algo, jit)
+    cutensorCreatePlanPreference(handle(), plan_pref, algo, jit)
 
     return CuTensorPlan(desc[], plan_pref[]; workspacePref=workspace)
 end
 
-function cuTENSOR.permute!(@nospecialize(alpha::Number),
-                           @nospecialize(A::CuStridedView), Ainds::ModeType,
-                           opA::cutensorOperator_t,
-                           @nospecialize(B::CuStridedView), Binds::ModeType;
-                           workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
-                           algo::cutensorAlgo_t=ALGO_DEFAULT,
-                           compute_type::Union{DataType,cutensorComputeDescriptorEnum,
-                                               Nothing}=nothing,
-                           plan::Union{CuTensorPlan,Nothing}=nothing)
+# permute!
+# -------- +function permute!(@nospecialize(alpha::Number), + @nospecialize(A::CuStridedView), Ainds::ModeType, + opA::cutensorOperator_t, + @nospecialize(B::CuStridedView), Binds::ModeType; + workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT, + algo::cutensorAlgo_t=ALGO_DEFAULT, + compute_type::Union{DataType,cutensorComputeDescriptorEnum, + Nothing}=nothing, + plan::Union{CuTensorPlan,Nothing}=nothing) actual_plan = if plan === nothing plan_permutation(A, Ainds, opA, B, Binds; @@ -405,7 +356,7 @@ function cuTENSOR.permute!(@nospecialize(alpha::Number), plan end - cuTENSOR.permute!(actual_plan, alpha, A, B) + permute!(actual_plan, alpha, A, B) if plan === nothing CUDA.unsafe_free!(actual_plan) @@ -414,14 +365,12 @@ function cuTENSOR.permute!(@nospecialize(alpha::Number), return B end -function cuTENSOR.permute!(plan::CuTensorPlan, - @nospecialize(alpha::Number), - @nospecialize(A::CuStridedView), - @nospecialize(B::CuStridedView)) +function permute!(plan::CuTensorPlan, + @nospecialize(alpha::Number), + @nospecialize(A::CuStridedView), + @nospecialize(B::CuStridedView)) scalar_type = plan.scalar_type - cutensorPermute(cuTENSOR.handle(), plan, - Ref{scalar_type}(alpha), A, B, - cuTENSOR.stream()) + cutensorPermute(handle(), plan, Ref{scalar_type}(alpha), A, B, stream()) return B end @@ -440,38 +389,40 @@ function plan_permutation(@nospecialize(A::CuStridedView), Ainds::ModeType, modeB = collect(Cint, Binds) actual_compute_type = if compute_type === nothing - cuTENSOR.permutation_compute_types[(eltype(A), eltype(B))] + permutation_compute_types[(eltype(A), eltype(B))] else compute_type end desc = Ref{cutensorOperationDescriptor_t}() - cutensorCreatePermutation(cuTENSOR.handle(), desc, + cutensorCreatePermutation(handle(), desc, descA, modeA, opA, descB, modeB, actual_compute_type) plan_pref = Ref{cutensorPlanPreference_t}() - cutensorCreatePlanPreference(cuTENSOR.handle(), plan_pref, algo, jit) + cutensorCreatePlanPreference(handle(), plan_pref, algo, jit) return CuTensorPlan(desc[], plan_pref[]; workspacePref=workspace) end -function cuTENSOR.contract!(@nospecialize(alpha::Number), - @nospecialize(A::CuStridedView), Ainds::ModeType, - opA::cutensorOperator_t, - @nospecialize(B::CuStridedView), Binds::ModeType, - opB::cutensorOperator_t, - @nospecialize(beta::Number), - @nospecialize(C::CuStridedView), Cinds::ModeType, - opC::cutensorOperator_t, - opOut::cutensorOperator_t; - jit::cutensorJitMode_t=JIT_MODE_NONE, - workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT, - algo::cutensorAlgo_t=ALGO_DEFAULT, - compute_type::Union{DataType,cutensorComputeDescriptorEnum, - Nothing}=nothing, - plan::Union{CuTensorPlan,Nothing}=nothing) +# contract! 
+# ---------
+function contract!(@nospecialize(alpha::Number),
+                   @nospecialize(A::CuStridedView), Ainds::ModeType,
+                   opA::cutensorOperator_t,
+                   @nospecialize(B::CuStridedView), Binds::ModeType,
+                   opB::cutensorOperator_t,
+                   @nospecialize(beta::Number),
+                   @nospecialize(C::CuStridedView), Cinds::ModeType,
+                   opC::cutensorOperator_t,
+                   opOut::cutensorOperator_t;
+                   jit::cutensorJitMode_t=JIT_MODE_NONE,
+                   workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
+                   algo::cutensorAlgo_t=ALGO_DEFAULT,
+                   compute_type::Union{DataType,cutensorComputeDescriptorEnum,
+                                       Nothing}=nothing,
+                   plan::Union{CuTensorPlan,Nothing}=nothing)
     actual_plan = if plan === nothing
         plan_contraction(A, Ainds, opA, B, Binds, opB, C, Cinds, opC, opOut;
                          jit, workspace, algo, compute_type)
     else
         plan
     end
@@ -479,7 +430,7 @@ function cuTENSOR.contract!(@nospecialize(alpha::Number),
 
-    cuTENSOR.contract!(actual_plan, alpha, A, B, beta, C)
+    contract!(actual_plan, alpha, A, B, beta, C)
 
     if plan === nothing
         CUDA.unsafe_free!(actual_plan)
     end
@@ -488,33 +439,33 @@ function cuTENSOR.contract!(@nospecialize(alpha::Number),
     return C
 end
 
-function cuTENSOR.contract!(plan::CuTensorPlan,
-                            @nospecialize(alpha::Number),
-                            @nospecialize(A::CuStridedView),
-                            @nospecialize(B::CuStridedView),
-                            @nospecialize(beta::Number),
-                            @nospecialize(C::CuStridedView))
+function contract!(plan::CuTensorPlan,
+                   @nospecialize(alpha::Number),
+                   @nospecialize(A::CuStridedView),
+                   @nospecialize(B::CuStridedView),
+                   @nospecialize(beta::Number),
+                   @nospecialize(C::CuStridedView))
     scalar_type = plan.scalar_type
-    cutensorContract(cuTENSOR.handle(), plan,
+    cutensorContract(handle(), plan,
                      Ref{scalar_type}(alpha), A, B,
                      Ref{scalar_type}(beta), C, C,
-                     plan.workspace, sizeof(plan.workspace), cuTENSOR.stream())
+                     plan.workspace, sizeof(plan.workspace), stream())
    return C
 end
 
-function cuTENSOR.plan_contraction(@nospecialize(A::CuStridedView), Ainds::ModeType,
-                                   opA::cutensorOperator_t,
-                                   @nospecialize(B::CuStridedView), Binds::ModeType,
-                                   opB::cutensorOperator_t,
-                                   @nospecialize(C::CuStridedView), Cinds::ModeType,
-                                   opC::cutensorOperator_t,
-                                   opOut::cutensorOperator_t;
-                                   jit::cutensorJitMode_t=JIT_MODE_NONE,
-                                   workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
-                                   algo::cutensorAlgo_t=ALGO_DEFAULT,
-                                   compute_type::Union{DataType,
-                                                       cutensorComputeDescriptorEnum,
-                                                       Nothing}=nothing)
+function plan_contraction(@nospecialize(A::CuStridedView), Ainds::ModeType,
+                          opA::cutensorOperator_t,
+                          @nospecialize(B::CuStridedView), Binds::ModeType,
+                          opB::cutensorOperator_t,
+                          @nospecialize(C::CuStridedView), Cinds::ModeType,
+                          opC::cutensorOperator_t,
+                          opOut::cutensorOperator_t;
+                          jit::cutensorJitMode_t=JIT_MODE_NONE,
+                          workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
+                          algo::cutensorAlgo_t=ALGO_DEFAULT,
+                          compute_type::Union{DataType,
+                                              cutensorComputeDescriptorEnum,
+                                              Nothing}=nothing)
     !is_unary(opA) && throw(ArgumentError("opA must be a unary op!"))
     !is_unary(opB) && throw(ArgumentError("opB must be a unary op!"))
     !is_unary(opC) && throw(ArgumentError("opC must be a unary op!"))
@@ -528,13 +479,13 @@ function cuTENSOR.plan_contraction(@nospecialize(A::CuStridedView), Ainds::ModeT
     modeC = collect(Cint, Cinds)
 
     actual_compute_type = if compute_type === nothing
-        cuTENSOR.contraction_compute_types[(eltype(A), eltype(B), eltype(C))]
+        contraction_compute_types[(eltype(A), eltype(B), eltype(C))]
     else
         compute_type
     end
 
     desc = Ref{cutensorOperationDescriptor_t}()
-    cutensorCreateContraction(cuTENSOR.handle(),
+    cutensorCreateContraction(handle(),
                               desc,
                               descA, modeA, opA,
                               descB, modeB, opB,
@@ -543,23 +494,25 @@ function cuTENSOR.plan_contraction(@nospecialize(A::CuStridedView), Ainds::ModeT
                               actual_compute_type)
 
     plan_pref = Ref{cutensorPlanPreference_t}()
-    cutensorCreatePlanPreference(cuTENSOR.handle(), plan_pref, algo, jit)
+    cutensorCreatePlanPreference(handle(), plan_pref, algo, jit)
 
     return CuTensorPlan(desc[], plan_pref[]; workspacePref=workspace)
 end
 
-function cuTENSOR.reduce!(@nospecialize(alpha::Number),
-                          @nospecialize(A::CuStridedView), Ainds::ModeType,
-                          opA::cutensorOperator_t,
-                          @nospecialize(beta::Number),
-                          @nospecialize(C::CuStridedView), Cinds::ModeType,
-                          opC::cutensorOperator_t,
-                          opReduce::cutensorOperator_t;
-                          workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
-                          algo::cutensorAlgo_t=ALGO_DEFAULT,
-                          compute_type::Union{DataType,cutensorComputeDescriptorEnum,
-                                              Nothing}=nothing,
-                          plan::Union{CuTensorPlan,Nothing}=nothing)
+# reduce!
+# -------
+function reduce!(@nospecialize(alpha::Number),
+                 @nospecialize(A::CuStridedView), Ainds::ModeType,
+                 opA::cutensorOperator_t,
+                 @nospecialize(beta::Number),
+                 @nospecialize(C::CuStridedView), Cinds::ModeType,
+                 opC::cutensorOperator_t,
+                 opReduce::cutensorOperator_t;
+                 workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
+                 algo::cutensorAlgo_t=ALGO_DEFAULT,
+                 compute_type::Union{DataType,cutensorComputeDescriptorEnum,
+                                     Nothing}=nothing,
+                 plan::Union{CuTensorPlan,Nothing}=nothing)
     actual_plan = if plan === nothing
         plan_reduction(A, Ainds, opA, C, Cinds, opC, opReduce;
                        workspace, algo, compute_type)
@@ -567,7 +520,7 @@ function cuTENSOR.reduce!(@nospecialize(alpha::Number),
         plan
     end
 
-    cuTENSOR.reduce!(actual_plan, alpha, A, beta, C)
+    reduce!(actual_plan, alpha, A, beta, C)
 
     if plan === nothing
         CUDA.unsafe_free!(actual_plan)
@@ -576,29 +529,29 @@ function cuTENSOR.reduce!(@nospecialize(alpha::Number),
     return C
 end
 
-function cuTENSOR.reduce!(plan::CuTensorPlan,
-                          @nospecialize(alpha::Number),
-                          @nospecialize(A::CuStridedView),
-                          @nospecialize(beta::Number),
-                          @nospecialize(C::CuStridedView))
+function reduce!(plan::CuTensorPlan,
+                 @nospecialize(alpha::Number),
+                 @nospecialize(A::CuStridedView),
+                 @nospecialize(beta::Number),
+                 @nospecialize(C::CuStridedView))
     scalar_type = plan.scalar_type
-    cuTENSOR.cutensorReduce(cuTENSOR.handle(), plan,
-                            Ref{scalar_type}(alpha), A,
-                            Ref{scalar_type}(beta), C, C,
-                            plan.workspace, sizeof(plan.workspace), cuTENSOR.stream())
+    cutensorReduce(handle(), plan,
+                   Ref{scalar_type}(alpha), A,
+                   Ref{scalar_type}(beta), C, C,
+                   plan.workspace, sizeof(plan.workspace), stream())
     return C
 end
 
-function cuTENSOR.plan_reduction(@nospecialize(A::CuStridedView), Ainds::ModeType,
-                                 opA::cutensorOperator_t,
-                                 @nospecialize(C::CuStridedView), Cinds::ModeType,
-                                 opC::cutensorOperator_t,
-                                 opReduce::cutensorOperator_t;
-                                 jit::cutensorJitMode_t=JIT_MODE_NONE,
-                                 workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
-                                 algo::cutensorAlgo_t=ALGO_DEFAULT,
-                                 compute_type::Union{DataType,cutensorComputeDescriptorEnum,
-                                                     Nothing}=nothing)
+function plan_reduction(@nospecialize(A::CuStridedView), Ainds::ModeType,
+                        opA::cutensorOperator_t,
+                        @nospecialize(C::CuStridedView), Cinds::ModeType,
+                        opC::cutensorOperator_t,
+                        opReduce::cutensorOperator_t;
+                        jit::cutensorJitMode_t=JIT_MODE_NONE,
+                        workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
+                        algo::cutensorAlgo_t=ALGO_DEFAULT,
+                        compute_type::Union{DataType,cutensorComputeDescriptorEnum,
+                                            Nothing}=nothing)
     !is_unary(opA) && throw(ArgumentError("opA must be a unary op!"))
     !is_unary(opC) && throw(ArgumentError("opC must be a unary op!"))
unary op!")) !is_binary(opReduce) && throw(ArgumentError("opReduce must be a binary op!")) @@ -609,13 +562,76 @@ function cuTENSOR.plan_reduction(@nospecialize(A::CuStridedView), Ainds::ModeTyp modeC = collect(Cint, Cinds) actual_compute_type = if compute_type === nothing - cuTENSOR.reduction_compute_types[(eltype(A), eltype(C))] + reduction_compute_types[(eltype(A), eltype(C))] + else + compute_type + end + + desc = Ref{cutensorOperationDescriptor_t}() + cutensorCreateReduction(handle(), + desc, + descA, modeA, opA, + descC, modeC, opC, + descC, modeC, opReduce, + actual_compute_type) + + plan_pref = Ref{cutensorPlanPreference_t}() + cutensorCreatePlanPreference(handle(), plan_pref, algo, jit) + + return CuTensorPlan(desc[], plan_pref[]; workspacePref=workspace) +end + +# trace! +# ------ +# not actually part of cuTENSOR, just a special case of reduce +function trace!(plan::CuTensorPlan, + @nospecialize(alpha::Number), + @nospecialize(A::CuStridedView), + @nospecialize(beta::Number), + @nospecialize(C::CuStridedView)) + return reduce!(plan, alpha, A, beta, C) +end + +function plan_trace(@nospecialize(A::AbstractArray), Ainds::ModeType, + opA::cutensorOperator_t, + @nospecialize(C::AbstractArray), Cinds::ModeType, + opC::cutensorOperator_t, + opReduce::cutensorOperator_t; + jit::cutensorJitMode_t=JIT_MODE_NONE, + workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT, + algo::cutensorAlgo_t=ALGO_DEFAULT, + compute_type::Union{DataType,cutensorComputeDescriptorEnum, + Nothing}=nothing) + !is_unary(opA) && throw(ArgumentError("opA must be a unary op!")) + !is_unary(opC) && throw(ArgumentError("opC must be a unary op!")) + !is_binary(opReduce) && throw(ArgumentError("opReduce must be a binary op!")) + + # TODO: check if this can be avoided, available in caller + # TODO: cuTENSOR will allocate sizes and strides anyways, could use that here + _, cindA1, cindA2 = TO.trace_indices(tuple(Ainds...), tuple(Cinds...)) + + # add strides of cindA2 to strides of cindA1 -> selects diagonal + stA = strides(A) + for (i, j) in zip(cindA1, cindA2) + stA = Base.setindex(stA, stA[i] + stA[j], i) + end + szA = TT.deleteat(size(A), cindA2) + stA′ = TT.deleteat(stA, cindA2) + + descA = CuTensorDescriptor(A; size=szA, strides=stA′) + descC = CuTensorDescriptor(C) + + modeA = collect(Cint, deleteat!(Ainds, cindA2)) + modeC = collect(Cint, Cinds) + + actual_compute_type = if compute_type === nothing + reduction_compute_types[(eltype(A), eltype(C))] else compute_type end desc = Ref{cutensorOperationDescriptor_t}() - cutensorCreateReduction(cuTENSOR.handle(), + cutensorCreateReduction(handle(), desc, descA, modeA, opA, descC, modeC, opC, @@ -623,7 +639,7 @@ function cuTENSOR.plan_reduction(@nospecialize(A::CuStridedView), Ainds::ModeTyp actual_compute_type) plan_pref = Ref{cutensorPlanPreference_t}() - cutensorCreatePlanPreference(cuTENSOR.handle(), plan_pref, algo, jit) + cutensorCreatePlanPreference(handle(), plan_pref, algo, jit) return CuTensorPlan(desc[], plan_pref[]; workspacePref=workspace) end