From 3c734ae7b923064326e3f9d5e67f24cadb41fbb7 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Sun, 11 Feb 2024 10:10:38 +0100 Subject: [PATCH] Formatter --- ext/TensorOperationscuTENSORExt.jl | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/ext/TensorOperationscuTENSORExt.jl b/ext/TensorOperationscuTENSORExt.jl index 3eaf3a0a..78e20ab6 100644 --- a/ext/TensorOperationscuTENSORExt.jl +++ b/ext/TensorOperationscuTENSORExt.jl @@ -38,7 +38,6 @@ const CuStridedView = StridedViewsCUDAExt.CuStridedView const SUPPORTED_CUARRAYS = Union{AnyCuArray,CuStridedView} const cuTENSORBackend = TO.Backend{:cuTENSOR} - function TO.tensorscalar(C::SUPPORTED_CUARRAYS) return ndims(C) == 0 ? tensorscalar(collect(C)) : throw(DimensionMismatch()) end @@ -60,7 +59,8 @@ end # making sure that if no backend is specified, the cuTENSOR backend is used: -function TO.tensoradd!(C::SUPPORTED_CUARRAYS, pC::Index2Tuple, A::SUPPORTED_CUARRAYS, conjA::Symbol, +function TO.tensoradd!(C::SUPPORTED_CUARRAYS, pC::Index2Tuple, A::SUPPORTED_CUARRAYS, + conjA::Symbol, α::Number, β::Number) return tensoradd!(C, pC, A, conjA, α, β, cuTENSORBackend()) end @@ -171,7 +171,7 @@ function TO.tensorcontract!(C::CuArray, pC::Index2Tuple, Ainds, Binds, Cinds = collect.(TO.contract_labels(pA, pB, pC)) opA = tensorop(A, conjA) opB = tensorop(B, conjB) - + # dispatch to cuTENSOR return cuTENSOR.contract!(α, A, Ainds, opA, @@ -206,11 +206,11 @@ function plan_trace(@nospecialize(A::AbstractArray), Ainds::cuTENSOR.ModeType, !cuTENSOR.is_unary(opA) && throw(ArgumentError("opA must be a unary op!")) !cuTENSOR.is_unary(opC) && throw(ArgumentError("opC must be a unary op!")) !cuTENSOR.is_binary(opReduce) && throw(ArgumentError("opReduce must be a binary op!")) - + # TODO: check if this can be avoided, available in caller # TODO: cuTENSOR will allocate sizes and strides anyways, could use that here _, cindA1, cindA2 = TO.trace_indices(tuple(Ainds...), tuple(Cinds...)) - + # add strides of cindA2 to strides of cindA1 -> selects diagonal stA = strides(A) for (i, j) in zip(cindA1, cindA2) @@ -218,26 +218,26 @@ function plan_trace(@nospecialize(A::AbstractArray), Ainds::cuTENSOR.ModeType, end szA = TT.deleteat(size(A), cindA2) stA′ = TT.deleteat(stA, cindA2) - + descA = cuTENSOR.CuTensorDescriptor(A; size=szA, strides=stA′) descC = cuTENSOR.CuTensorDescriptor(C) - + modeA = collect(Cint, deleteat!(Ainds, cindA2)) modeC = collect(Cint, Cinds) - + actual_compute_type = if compute_type === nothing cuTENSOR.reduction_compute_types[(eltype(A), eltype(C))] else compute_type end - + desc = Ref{cuTENSOR.cutensorOperationDescriptor_t}() cuTENSOR.cutensorCreateReduction(cuTENSOR.handle(), - desc, - descA, modeA, opA, - descC, modeC, opC, - descC, modeC, opReduce, - actual_compute_type) + desc, + descA, modeA, opA, + descC, modeC, opC, + descC, modeC, opReduce, + actual_compute_type) plan_pref = Ref{cuTENSOR.cutensorPlanPreference_t}() cuTENSOR.cutensorCreatePlanPreference(cuTENSOR.handle(), plan_pref, algo, jit)