Formatter
lkdvos committed Feb 11, 2024
1 parent f066505 commit 3c734ae
Showing 1 changed file with 14 additions and 14 deletions.
ext/TensorOperationscuTENSORExt.jl (28 changes: 14 additions & 14 deletions)
@@ -38,7 +38,6 @@ const CuStridedView = StridedViewsCUDAExt.CuStridedView
const SUPPORTED_CUARRAYS = Union{AnyCuArray,CuStridedView}
const cuTENSORBackend = TO.Backend{:cuTENSOR}

-
function TO.tensorscalar(C::SUPPORTED_CUARRAYS)
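    # a zero-dimensional CuArray is collected to the host first, since scalar
    # indexing of GPU memory is disallowed by default in CUDA.jl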
return ndims(C) == 0 ? tensorscalar(collect(C)) : throw(DimensionMismatch())
end
@@ -60,7 +59,8 @@ end

# making sure that if no backend is specified, the cuTENSOR backend is used:

-function TO.tensoradd!(C::SUPPORTED_CUARRAYS, pC::Index2Tuple, A::SUPPORTED_CUARRAYS, conjA::Symbol,
+function TO.tensoradd!(C::SUPPORTED_CUARRAYS, pC::Index2Tuple, A::SUPPORTED_CUARRAYS,
+                       conjA::Symbol,
α::Number, β::Number)
return tensoradd!(C, pC, A, conjA, α, β, cuTENSORBackend())
end
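As a usage sketch (not part of this commit; shapes and values are illustrative, and the call uses the v4-style positional signature shown above), a call that omits the backend argument is routed to the cuTENSOR backend automatically:

using CUDA, TensorOperations

A = CUDA.rand(Float32, 4, 4)
C = CUDA.zeros(Float32, 4, 4)
# no backend argument: the method above forwards to cuTENSORBackend(),
# so C = 1*A + 0*C is computed by cuTENSOR on the GPU
tensoradd!(C, ((1, 2), ()), A, :N, true, false)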
@@ -171,7 +171,7 @@ function TO.tensorcontract!(C::CuArray, pC::Index2Tuple,
Ainds, Binds, Cinds = collect.(TO.contract_labels(pA, pB, pC))
opA = tensorop(A, conjA)
opB = tensorop(B, conjB)
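    # `tensorop` (a helper defined elsewhere in this extension) presumably maps the
    # :N/:C conjugation flags onto the corresponding cuTENSOR unary operators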

# dispatch to cuTENSOR
return cuTENSOR.contract!(α,
A, Ainds, opA,
@@ -206,38 +206,38 @@ function plan_trace(@nospecialize(A::AbstractArray), Ainds::cuTENSOR.ModeType,
!cuTENSOR.is_unary(opA) && throw(ArgumentError("opA must be a unary op!"))
!cuTENSOR.is_unary(opC) && throw(ArgumentError("opC must be a unary op!"))
!cuTENSOR.is_binary(opReduce) && throw(ArgumentError("opReduce must be a binary op!"))

# TODO: check if this can be avoided, available in caller
# TODO: cuTENSOR will allocate sizes and strides anyways, could use that here
_, cindA1, cindA2 = TO.trace_indices(tuple(Ainds...), tuple(Cinds...))

# add strides of cindA2 to strides of cindA1 -> selects diagonal
stA = strides(A)
for (i, j) in zip(cindA1, cindA2)
stA = Base.setindex(stA, stA[i] + stA[j], i)
end
szA = TT.deleteat(size(A), cindA2)
stA′ = TT.deleteat(stA, cindA2)
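    # worked example of the stride trick (illustrative): for a column-major n×n
    # matrix with strides (1, n), merging the two traced strides gives a single
    # stride of n + 1, so the resulting view walks A[1,1], A[2,2], ..., i.e. the
    # diagonal that the reduction below sums over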

descA = cuTENSOR.CuTensorDescriptor(A; size=szA, strides=stA′)
descC = cuTENSOR.CuTensorDescriptor(C)

modeA = collect(Cint, deleteat!(Ainds, cindA2))
modeC = collect(Cint, Cinds)

actual_compute_type = if compute_type === nothing
cuTENSOR.reduction_compute_types[(eltype(A), eltype(C))]
else
compute_type
end
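    # if the caller did not request a compute type, fall back to the default
    # accumulation precision that cuTENSOR.jl associates with this pair of
    # element types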

desc = Ref{cuTENSOR.cutensorOperationDescriptor_t}()
cuTENSOR.cutensorCreateReduction(cuTENSOR.handle(),
-                                     desc,
-                                     descA, modeA, opA,
-                                     descC, modeC, opC,
-                                     descC, modeC, opReduce,
-                                     actual_compute_type)
+                                     desc,
+                                     descA, modeA, opA,
+                                     descC, modeC, opC,
+                                     descC, modeC, opReduce,
+                                     actual_compute_type)
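    # `descC`/`modeC` are passed twice: the cuTENSOR reduction descriptor takes
    # separate C (accumulated operand) and D (output) tensors, and reusing the
    # same descriptor for both accumulates the traced result in place into C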

plan_pref = Ref{cuTENSOR.cutensorPlanPreference_t}()
cuTENSOR.cutensorCreatePlanPreference(cuTENSOR.handle(), plan_pref, algo, jit)
