From 87313c48d42224e4d57b9d3d8a18237208f62ff6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 29 Nov 2024 13:45:17 +0100 Subject: [PATCH 01/12] Add precompile statements new precompilation approach --- src/TensorOperations.jl | 2 ++ src/precompile.jl | 64 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) create mode 100644 src/precompile.jl diff --git a/src/TensorOperations.jl b/src/TensorOperations.jl index 9a58c87..fa5c0b5 100644 --- a/src/TensorOperations.jl +++ b/src/TensorOperations.jl @@ -77,4 +77,6 @@ function __init__() @require_extensions end +include("precompile.jl") + end # module diff --git a/src/precompile.jl b/src/precompile.jl new file mode 100644 index 0000000..7716f11 --- /dev/null +++ b/src/precompile.jl @@ -0,0 +1,64 @@ +const PRECOMPILE_ELTYPES = (Float64, ComplexF64) + +# tensoradd! +# ---------- +const PRECOMPILE_ADD_NDIMS = 5 + +for T in PRECOMPILE_ELTYPES + for N in 0:PRECOMPILE_ADD_NDIMS + C = Array{T,N} + A = Array{T,N} + pA = Index2Tuple{N,0} + + precompile(tensoradd!, (C, A, pA, Bool, One, Zero)) + precompile(tensoradd!, (C, A, pA, Bool, T, Zero)) + precompile(tensoradd!, (C, A, pA, Bool, T, T)) + + precompile(tensoralloc_add, (T, A, pA, Bool, Val{true})) + precompile(tensoralloc_add, (T, A, pA, Bool, Val{false})) + end +end + +# tensortrace! +# ------------ +const PRECOMPILE_TRACE_NDIMS = (4, 2) + +for T in PRECOMPILE_ELTYPES + for N1 in 0:PRECOMPILE_TRACE_NDIMS[1], N2 in 0:PRECOMPILE_TRACE_NDIMS[2] + C = Array{T,N1} + A = Array{T,N1 + 2N2} + p = Index2Tuple{N1,0} + q = Index2Tuple{N2,N2} + + precompile(tensortrace!, (C, A, p, q, Bool, One, Zero)) + precompile(tensortrace!, (C, A, p, q, Bool, T, Zero)) + precompile(tensortrace!, (C, A, p, q, Bool, T, T)) + + # allocation re-uses tensoralloc_add + end +end + +# tensorcontract! 
+# --------------- +const PRECOMPILE_CONTRACT_NDIMS = (3, 2, 3) + +for T in PRECOMPILE_ELTYPES + for N1 in 0:PRECOMPILE_CONTRACT_NDIMS[1], N2 in 0:PRECOMPILE_CONTRACT_NDIMS[2], + N3 in 0:PRECOMPILE_CONTRACT_NDIMS[3] + + NA = N1 + N2 + NB = N2 + N3 + NC = N1 + N3 + C, A, B = Array{T,NC}, Array{T,NA}, Array{T,NB} + pA = Index2Tuple{N1,N2} + pB = Index2Tuple{N2,N3} + pAB = Index2Tuple{NC,0} + + precompile(tensorcontract!, (C, A, pA, Bool, B, pB, Bool, pAB, One, Zero)) + precompile(tensorcontract!, (C, A, pA, Bool, B, pB, Bool, pAB, T, Zero)) + precompile(tensorcontract!, (C, A, pA, Bool, B, pB, Bool, pAB, T, T)) + + precompile(tensoralloc_contract, (T, A, pA, Bool, B, pB, Bool, pAB, Val{true})) + precompile(tensoralloc_contract, (T, A, pA, Bool, B, pB, Bool, pAB, Val{false})) + end +end From 5b7fc776b7598cc291e6b654319fb50e69f8a728 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 27 Feb 2025 10:56:05 -0500 Subject: [PATCH 02/12] Bump v5.1.5 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9fce9b0..080456e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorOperations" uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" authors = ["Lukas Devos ", "Maarten Van Damme ", "Jutho Haegeman "] -version = "5.1.4" +version = "5.1.5" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" From 2e3040461449b76e776616c94d249df9d199af35 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 27 Feb 2025 11:10:45 -0500 Subject: [PATCH 03/12] Use `Zero()` instead of `false` --- src/implementation/functions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/implementation/functions.jl b/src/implementation/functions.jl index 7444b0b..57f567b 100644 --- a/src/implementation/functions.jl +++ b/src/implementation/functions.jl @@ -79,7 +79,7 @@ See also [`tensorcopy`](@ref) and [`tensoradd!`](@ref) """ function tensorcopy!(C, A, pA::Index2Tuple, conjA::Bool=false, α::Number=One(), backend=DefaultBackend(), allocator=DefaultAllocator()) - return tensoradd!(C, A, pA, conjA, α, false, backend, allocator) + return tensoradd!(C, A, pA, conjA, α, Zero(), backend, allocator) end # ------------------------------------------------------------------------------------------ From 6ade3c144011590e46d713dd279a44bc2cfbb34d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 28 Feb 2025 09:51:21 -0500 Subject: [PATCH 04/12] Make precompilation configurable --- .gitignore | 3 +- Project.toml | 4 ++ src/precompile.jl | 141 ++++++++++++++++++++++++++++++---------------- 3 files changed, 100 insertions(+), 48 deletions(-) diff --git a/.gitignore b/.gitignore index 0ee3d17..5e4280e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *.jl.cov *.jl.*.cov *.jl.mem -Manifest.toml \ No newline at end of file +Manifest.toml +LocalPreferences.toml diff --git a/Project.toml b/Project.toml index 080456e..427ea56 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,8 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" PtrArrays = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" StridedViews = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143" @@ -38,6 +40,8 @@ LRUCache = "1" LinearAlgebra = "1.6" Logging = "1.6" 
PackageExtensionCompat = "1" +PrecompileTools = "1.1" +Preferences = "1.4" PtrArrays = "1.2" Random = "1" Strided = "2.2" diff --git a/src/precompile.jl b/src/precompile.jl index 7716f11..305df3a 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -1,64 +1,111 @@ -const PRECOMPILE_ELTYPES = (Float64, ComplexF64) +using PrecompileTools: PrecompileTools +using Preferences: @load_preference -# tensoradd! -# ---------- -const PRECOMPILE_ADD_NDIMS = 5 +# Validate preferences input +# -------------------------- +function validate_precompile_eltypes(eltypes) + eltypes isa Vector{String} || + throw(ArgumentError("`precompile_eltypes` should be a vector of strings, got $(typeof(eltypes)) instead")) + return map(eltypes) do Tstr + T = eval(Meta.parse(Tstr)) + (T isa DataType && T <: Number) || + error("Invalid precompile_eltypes entry: `$Tstr`") + return T + end +end -for T in PRECOMPILE_ELTYPES - for N in 0:PRECOMPILE_ADD_NDIMS - C = Array{T,N} - A = Array{T,N} - pA = Index2Tuple{N,0} +function validate_add_ndims(add_ndims) + add_ndims isa Int || + throw(ArgumentError("`precompile_add_ndims` should be an `Int`, got `$add_ndims`")) + add_ndims ≥ 0 || error("Invalid precompile_add_ndims: `$add_ndims`") + return add_ndims +end - precompile(tensoradd!, (C, A, pA, Bool, One, Zero)) - precompile(tensoradd!, (C, A, pA, Bool, T, Zero)) - precompile(tensoradd!, (C, A, pA, Bool, T, T)) +function validate_trace_ndims(trace_ndims) + trace_ndims isa Vector{Int} && length(trace_ndims) == 2 || + throw(ArgumentError("`precompile_trace_ndims` should be a `Vector{Int}` of length 2, got `$trace_ndims`")) + all(≥(0), trace_ndims) || error("Invalid precompile_trace_ndims: `$trace_ndims`") + return trace_ndims +end - precompile(tensoralloc_add, (T, A, pA, Bool, Val{true})) - precompile(tensoralloc_add, (T, A, pA, Bool, Val{false})) - end +function validate_contract_ndims(contract_ndims) + contract_ndims isa Vector{Int} && length(contract_ndims) == 3 || + throw(ArgumentError("`precompile_contract_ndims` should be a `Vector{Int}` of length 3, got `$contract_ndims`")) + all(≥(0), contract_ndims) || + error("Invalid precompile_contract_ndims: `$contract_ndims`") + return contract_ndims end -# tensortrace! -# ------------ -const PRECOMPILE_TRACE_NDIMS = (4, 2) +# Static preferences +# ------------------ +const PRECOMPILE_ELTYPES = validate_precompile_eltypes(@load_preference("precompile_eltypes", + ["Float64", + "ComplexF64"])) +const PRECOMPILE_ADD_NDIMS = validate_add_ndims(@load_preference("precompile_add_ndims", 5)) +const PRECOMPILE_TRACE_NDIMS = validate_trace_ndims(@load_preference("precompile_trace_ndims", + [4, 2])) +const PRECOMPILE_CONTRACT_NDIMS = validate_contract_ndims(@load_preference("precompile_contract_ndims", + [3, 2, 3])) -for T in PRECOMPILE_ELTYPES - for N1 in 0:PRECOMPILE_TRACE_NDIMS[1], N2 in 0:PRECOMPILE_TRACE_NDIMS[2] - C = Array{T,N1} - A = Array{T,N1 + 2N2} - p = Index2Tuple{N1,0} - q = Index2Tuple{N2,N2} +# Using explicit precompile statements here instead of @compile_workload: +# Actually running the precompilation through PrecompileTools leads to longer compile times +# Keeping the workload_enabled functionality to have the option of disabling precompilation +# in a compatible manner with the rest of the ecosystem +if PrecompileTools.workload_enabled(@__MODULE__) + # tensoradd! 
+ # ---------- + for T in PRECOMPILE_ELTYPES + for N in 0:PRECOMPILE_ADD_NDIMS + C = Array{T,N} + A = Array{T,N} + pA = Index2Tuple{N,0} - precompile(tensortrace!, (C, A, p, q, Bool, One, Zero)) - precompile(tensortrace!, (C, A, p, q, Bool, T, Zero)) - precompile(tensortrace!, (C, A, p, q, Bool, T, T)) + precompile(tensoradd!, (C, A, pA, Bool, One, Zero)) + precompile(tensoradd!, (C, A, pA, Bool, T, Zero)) + precompile(tensoradd!, (C, A, pA, Bool, T, T)) - # allocation re-uses tensoralloc_add + precompile(tensoralloc_add, (T, A, pA, Bool, Val{true})) + precompile(tensoralloc_add, (T, A, pA, Bool, Val{false})) + end end -end -# tensorcontract! -# --------------- -const PRECOMPILE_CONTRACT_NDIMS = (3, 2, 3) + # tensortrace! + # ------------ + for T in PRECOMPILE_ELTYPES + for N1 in 0:PRECOMPILE_TRACE_NDIMS[1], N2 in 0:PRECOMPILE_TRACE_NDIMS[2] + C = Array{T,N1} + A = Array{T,N1 + 2N2} + p = Index2Tuple{N1,0} + q = Index2Tuple{N2,N2} + + precompile(tensortrace!, (C, A, p, q, Bool, One, Zero)) + precompile(tensortrace!, (C, A, p, q, Bool, T, Zero)) + precompile(tensortrace!, (C, A, p, q, Bool, T, T)) + + # allocation re-uses tensoralloc_add + end + end -for T in PRECOMPILE_ELTYPES - for N1 in 0:PRECOMPILE_CONTRACT_NDIMS[1], N2 in 0:PRECOMPILE_CONTRACT_NDIMS[2], - N3 in 0:PRECOMPILE_CONTRACT_NDIMS[3] + # tensorcontract! + # --------------- + for T in PRECOMPILE_ELTYPES + for N1 in 0:PRECOMPILE_CONTRACT_NDIMS[1], N2 in 0:PRECOMPILE_CONTRACT_NDIMS[2], + N3 in 0:PRECOMPILE_CONTRACT_NDIMS[3] - NA = N1 + N2 - NB = N2 + N3 - NC = N1 + N3 - C, A, B = Array{T,NC}, Array{T,NA}, Array{T,NB} - pA = Index2Tuple{N1,N2} - pB = Index2Tuple{N2,N3} - pAB = Index2Tuple{NC,0} + NA = N1 + N2 + NB = N2 + N3 + NC = N1 + N3 + C, A, B = Array{T,NC}, Array{T,NA}, Array{T,NB} + pA = Index2Tuple{N1,N2} + pB = Index2Tuple{N2,N3} + pAB = Index2Tuple{NC,0} - precompile(tensorcontract!, (C, A, pA, Bool, B, pB, Bool, pAB, One, Zero)) - precompile(tensorcontract!, (C, A, pA, Bool, B, pB, Bool, pAB, T, Zero)) - precompile(tensorcontract!, (C, A, pA, Bool, B, pB, Bool, pAB, T, T)) + precompile(tensorcontract!, (C, A, pA, Bool, B, pB, Bool, pAB, One, Zero)) + precompile(tensorcontract!, (C, A, pA, Bool, B, pB, Bool, pAB, T, Zero)) + precompile(tensorcontract!, (C, A, pA, Bool, B, pB, Bool, pAB, T, T)) - precompile(tensoralloc_contract, (T, A, pA, Bool, B, pB, Bool, pAB, Val{true})) - precompile(tensoralloc_contract, (T, A, pA, Bool, B, pB, Bool, pAB, Val{false})) + precompile(tensoralloc_contract, (T, A, pA, Bool, B, pB, Bool, pAB, Val{true})) + precompile(tensoralloc_contract, (T, A, pA, Bool, B, pB, Bool, pAB, Val{false})) + end end end From 84afe182a6af6a85c703a6971416fbe6c64f701f Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 28 Feb 2025 10:21:39 -0500 Subject: [PATCH 05/12] Add documentation --- docs/make.jl | 3 ++- docs/src/index.md | 2 +- docs/src/man/precompilation.md | 46 ++++++++++++++++++++++++++++++++++ src/precompile.jl | 8 +++--- 4 files changed, 53 insertions(+), 6 deletions(-) create mode 100644 docs/src/man/precompilation.md diff --git a/docs/make.jl b/docs/make.jl index f6670ba..4285b7b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -11,7 +11,8 @@ makedocs(; modules=[TensorOperations], "man/interface.md", "man/backends.md", "man/autodiff.md", - "man/implementation.md"], + "man/implementation.md", + "man/precompilation.md"], "Index" => "index/index.md"]) # Documenter can also automatically deploy documentation to gh-pages. 
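For context on the precompile statements introduced in patch 04 above, here is a minimal editor's sketch (not part of any patch) of a call whose argument types match one of the precompiled `tensoradd!` signatures: plain `Array{Float64,2}` tensors, an `Index2Tuple{2,0}` permutation, and `Float64` scalars, i.e. the `(T, T)` statement.

```julia
using TensorOperations

A = rand(3, 3)                    # Array{Float64,2}
C = zeros(3, 3)
# C ← β*C + α*permute(A, pA); pA = ((2, 1), ()) is an Index2Tuple{2,0}, conjA = false
TensorOperations.tensoradd!(C, A, ((2, 1), ()), false, 2.0, 0.0)
C ≈ 2 .* permutedims(A, (2, 1))   # true
```

Because the element type, rank, permutation arity, and scalar types all match a precompiled signature, a first call like this should not need to compile the kernel from scratch in a fresh session.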
diff --git a/docs/src/index.md b/docs/src/index.md
index 099a51d..ddb2a6d 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -5,7 +5,7 @@
 ## Table of contents
 
 ```@contents
-Pages = ["index.md", "man/indexnotation.md", "man/functions.md", "man/interface.md", "man/backends.md", "man/autodiff.md", "man/implementation.md"]
+Pages = ["index.md", "man/indexnotation.md", "man/functions.md", "man/interface.md", "man/backends.md", "man/autodiff.md", "man/implementation.md", "man/precompilation.md"]
 Depth = 4
 ```
 
diff --git a/docs/src/man/precompilation.md b/docs/src/man/precompilation.md
new file mode 100644
index 0000000..ae8425b
--- /dev/null
+++ b/docs/src/man/precompilation.md
@@ -0,0 +1,46 @@
+# Precompilation
+
+Since version `v5.1.5`, TensorOperations.jl has some support for precompiling commonly called functions.
+The guiding philosophy is that often, tensor contractions are (part of) the bottlenecks of typical workflows,
+and as such we want to maximize performance. As a result, we are choosing to specialize many functions which
+may lead to a rather large time-to-first-execution (TTFX). In order to mitigate this, some of that work can
+be moved to precompile-time, avoiding the need to re-compile these specializations for every fresh Julia session.
+
+Nevertheless, TensorOperations is designed to work with a large variety of input types, and simply enumerating
+all of these tends to lead to prohibitively large precompilation times, as well as large system images.
+Therefore, there is some customization possible to tweak the desired level of precompilation, trading in
+faster precompile times for fast TTFX for a wider range of inputs.
+
+## Defaults
+
+By default, precompilation is enabled for "tensors" of type `Array{T,N}`, where `T` and `N` range over the following values:
+
+* `T` is either `Float64` or `ComplexF64`
+* `tensoradd!` is precompiled up to `N = 5`
+* `tensortrace!` is precompiled up to `4` free output indices and `2` pairs of traced indices
+* `tensorcontract!` is precompiled up to `3` free output indices on both inputs, and `2` contracted indices
+
+## Custom settings
+
+The default precompilation settings can be tweaked to allow for more or less expansive coverage. This is achieved
+through a combination of `PrecompileTools`- and `Preferences`-based functionality.
+
+To disable precompilation altogether, for example during development or when you prefer to have small binaries,
+you can *locally* change the `"precompile_workload"` key in the preferences.
+
+```julia
+using TensorOperations, Preferences
+set_preferences!(TensorOperations, "precompile_workload" => false; force=true)
+```
+
+Alternatively, you can keep precompilation enabled and change the settings above through the same machinery, via:
+
+* `"precompile_eltypes"`: a `Vector{String}` whose entries evaluate to the desired values of `T<:Number`
+* `"precompile_add_ndims"`: an `Int` to specify the maximum `N` for `tensoradd!`
+* `"precompile_trace_ndims"`: a `Vector{Int}` of length 2 to specify the maximal number of free and traced indices for `tensortrace!`.
+* `"precompile_contract_ndims"`: a `Vector{Int}` of length 2 to specify the maximal number of free and contracted indices for `tensorcontract!`.
+
+!!! note "Backends"
+
+    Currently, there is no support for precompiling methods that do not use the default backend. If this is a
+    feature you would find useful, feel free to contact us or open an issue. 
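To complement the manual page above, a hedged sketch of the `Preferences`-based tuning it describes; the values below are purely illustrative, and each key mirrors a `@load_preference` call in `src/precompile.jl`.

```julia
using TensorOperations, Preferences

# Illustrative settings: restrict precompilation to Float64 and smaller ranks,
# trading coverage for shorter precompile times.
set_preferences!(TensorOperations,
                 "precompile_eltypes" => ["Float64"],      # Vector{String} of element types
                 "precompile_add_ndims" => 3,              # maximum N for tensoradd!
                 "precompile_trace_ndims" => [2, 1],       # [free, traced] for tensortrace!
                 "precompile_contract_ndims" => [3, 2];    # [free, contracted] for tensorcontract!
                 force=true)
```

As with the `precompile_workload` example in the docs, these settings are written to `LocalPreferences.toml` (which patch 04 adds to `.gitignore`) and take effect the next time TensorOperations precompiles.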
diff --git a/src/precompile.jl b/src/precompile.jl index 305df3a..1bfbf9a 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -29,8 +29,8 @@ function validate_trace_ndims(trace_ndims) end function validate_contract_ndims(contract_ndims) - contract_ndims isa Vector{Int} && length(contract_ndims) == 3 || - throw(ArgumentError("`precompile_contract_ndims` should be a `Vector{Int}` of length 3, got `$contract_ndims`")) + contract_ndims isa Vector{Int} && length(contract_ndims) == 2 || + throw(ArgumentError("`precompile_contract_ndims` should be a `Vector{Int}` of length 2, got `$contract_ndims`")) all(≥(0), contract_ndims) || error("Invalid precompile_contract_ndims: `$contract_ndims`") return contract_ndims @@ -45,7 +45,7 @@ const PRECOMPILE_ADD_NDIMS = validate_add_ndims(@load_preference("precompile_add const PRECOMPILE_TRACE_NDIMS = validate_trace_ndims(@load_preference("precompile_trace_ndims", [4, 2])) const PRECOMPILE_CONTRACT_NDIMS = validate_contract_ndims(@load_preference("precompile_contract_ndims", - [3, 2, 3])) + [4, 2])) # Using explicit precompile statements here instead of @compile_workload: # Actually running the precompilation through PrecompileTools leads to longer compile times @@ -90,7 +90,7 @@ if PrecompileTools.workload_enabled(@__MODULE__) # --------------- for T in PRECOMPILE_ELTYPES for N1 in 0:PRECOMPILE_CONTRACT_NDIMS[1], N2 in 0:PRECOMPILE_CONTRACT_NDIMS[2], - N3 in 0:PRECOMPILE_CONTRACT_NDIMS[3] + N3 in 0:PRECOMPILE_CONTRACT_NDIMS[1] NA = N1 + N2 NB = N2 + N3 From 28a38343f4f24019ffa9d000ead9c4af86936f05 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 28 Feb 2025 10:27:01 -0500 Subject: [PATCH 06/12] change to minor version --- Project.toml | 2 +- docs/src/man/precompilation.md | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 427ea56..31bfdf7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorOperations" uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" authors = ["Lukas Devos ", "Maarten Van Damme ", "Jutho Haegeman "] -version = "5.1.5" +version = "5.2.0" [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" diff --git a/docs/src/man/precompilation.md b/docs/src/man/precompilation.md index ae8425b..64557d7 100644 --- a/docs/src/man/precompilation.md +++ b/docs/src/man/precompilation.md @@ -1,6 +1,6 @@ # Precompilation -Since version `v5.1.5`, TensorOperations.jl has some support for precompiling commonly called functions. +TensorOperations.jl has some support for precompiling commonly called functions. The guiding philosophy is that often, tensor contractions are (part of) the bottlenecks of typical workflows, and as such we want to maximize performance. As a result, we are choosing to specialize many functions which may lead to a rather large time-to-first-execution (TTFX). In order to mitigate this, some of that work can @@ -11,6 +11,10 @@ all of these tends to lead to prohibitively large precompilation times, as well Therefore, there is some customization possible to tweak the desired level of precompilation, trading in faster precompile times for fast TTFX for a wider range of inputs. +!!! compat "TensorOperations v5.2.0" + + Precompilation support requires at least TensorOperations v5.2.0. 
+ ## Defaults By default, precompilation is enabled for "tensors" of type `Array{T,N}`, where `T` and `N` range over the following values: From d71968f212db3a250c2072dd6f0ca49958bf5579 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 4 Mar 2025 10:39:45 -0500 Subject: [PATCH 07/12] Rework Strided wrapping --- ext/TensorOperationsBumperExt.jl | 8 +++++ src/implementation/blascontract.jl | 18 ++++++----- src/implementation/diagonal.jl | 41 +++++++++++++++++-------- src/implementation/strided.jl | 48 +++++++++++++++++++++++------- 4 files changed, 86 insertions(+), 29 deletions(-) diff --git a/ext/TensorOperationsBumperExt.jl b/ext/TensorOperationsBumperExt.jl index 2c6db4f..0cae1ad 100644 --- a/ext/TensorOperationsBumperExt.jl +++ b/ext/TensorOperationsBumperExt.jl @@ -3,6 +3,14 @@ module TensorOperationsBumperExt using TensorOperations using Bumper +# Hack to normalize StridedView type to avoid too many specializations +# This is allowed because bumper ensures that the pointer won't be GC'd +# and we never return `parent(SV)` anyways. +function TensorOperations.wrap_stridedview(A::Bumper.UnsafeArray) + mem_A = Base.unsafe_wrap(Memory{eltype(A)}, pointer(A), length(A)) + return TensorOperations.StridedView(mem_A, size(A), strides(A), 0, identity) +end + function TensorOperations.tensoralloc(::Type{A}, structure, ::Val{istemp}, buf::Union{SlabBuffer,AllocBuffer}) where {A<:AbstractArray, istemp} diff --git a/src/implementation/blascontract.jl b/src/implementation/blascontract.jl index dbbf742..bbf0ba1 100644 --- a/src/implementation/blascontract.jl +++ b/src/implementation/blascontract.jl @@ -47,13 +47,18 @@ function _blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator) flagC = isblasdestination(C, ipAB) if flagC C_ = C - _unsafe_blas_contract!(C_, A_, pA, B_, pB, ipAB, α, β) + _unsafe_blas_contract!(wrap_stridedview(C_), + wrap_stridedview(A_), pA, + wrap_stridedview(B_), pB, + ipAB, α, β) else - C_ = SV(tensoralloc_add(TC, C, ipAB, false, Val(true), allocator)) - _unsafe_blas_contract!(C_, A_, pA, B_, pB, trivialpermutation(ipAB), - one(TC), zero(TC)) + C_ = tensoralloc_add(TC, C, ipAB, false, Val(true), allocator) + _unsafe_blas_contract!(wrap_stridedview(C_), + wrap_stridedview(A_), pA, + wrap_stridedview(B_), pB, + trivialpermutation(ipAB), one(TC), zero(TC)) tensoradd!(C, C_, pAB, false, α, β, backend, allocator) - tensorfree!(C_.parent, allocator) + tensorfree!(C_, allocator) end flagA || tensorfree!(A_.parent, allocator) flagB || tensorfree!(B_.parent, allocator) @@ -85,8 +90,7 @@ end flagA = isblascontractable(A, pA) && eltype(A) == TC if !flagA A_ = tensoralloc_add(TC, A, pA, false, Val(true), allocator) - Anew = SV(A_, size(A_), strides(A_), 0, A.op) - Anew = tensoradd!(Anew, A, pA, false, One(), Zero(), backend, allocator) + Anew = tensoradd!(A_, A, pA, false, One(), Zero(), backend, allocator) pAnew = trivialpermutation(pA) else Anew = A diff --git a/src/implementation/diagonal.jl b/src/implementation/diagonal.jl index 11f3e3c..2562b06 100644 --- a/src/implementation/diagonal.jl +++ b/src/implementation/diagonal.jl @@ -11,13 +11,22 @@ function tensorcontract!(C::AbstractArray, dimcheck_tensorcontract(C, A, pA, B, pB, pAB) if conjA && conjB - _diagtensorcontract!(SV(C), conj(SV(A)), pA, conj(SV(B.diag)), pB, pAB, α, β) + _diagtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(A)), pA, + conj(wrap_stridedview(B.diag)), pB, + pAB, α, β) elseif conjA - _diagtensorcontract!(SV(C), conj(SV(A)), pA, SV(B.diag), pB, pAB, α, β) + 
_diagtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(A)), pA, + wrap_stridedview(B.diag), + pB, pAB, α, + β) elseif conjB - _diagtensorcontract!(SV(C), SV(A), pA, conj(SV(B.diag)), pB, pAB, α, β) + _diagtensorcontract!(wrap_stridedview(C), wrap_stridedview(A), pA, + conj(wrap_stridedview(B.diag)), + pB, pAB, α, + β) else - _diagtensorcontract!(SV(C), SV(A), pA, SV(B.diag), pB, pAB, α, β) + _diagtensorcontract!(wrap_stridedview(C), wrap_stridedview(A), pA, + wrap_stridedview(B.diag), pB, pAB, α, β) end return C end @@ -41,13 +50,17 @@ function tensorcontract!(C::AbstractArray, TupleTools.getindices(indCinoBA, tpAB[2])) if conjA && conjB - _diagtensorcontract!(SV(C), conj(SV(B)), rpB, conj(SV(A.diag)), rpA, rpAB, α, β) + _diagtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(B)), rpB, + conj(wrap_stridedview(A.diag)), rpA, rpAB, α, β) elseif conjA - _diagtensorcontract!(SV(C), SV(B), rpB, conj(SV(A.diag)), rpA, rpAB, α, β) + _diagtensorcontract!(wrap_stridedview(C), wrap_stridedview(B), rpB, + conj(wrap_stridedview(A.diag)), rpA, rpAB, α, β) elseif conjB - _diagtensorcontract!(SV(C), conj(SV(B)), rpB, SV(A.diag), rpA, rpAB, α, β) + _diagtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(B)), rpB, + wrap_stridedview(A.diag), rpA, rpAB, α, β) else - _diagtensorcontract!(SV(C), SV(B), rpB, SV(A.diag), rpA, rpAB, α, β) + _diagtensorcontract!(wrap_stridedview(C), wrap_stridedview(B), rpB, + wrap_stridedview(A.diag), rpA, rpAB, α, β) end return C end @@ -62,13 +75,17 @@ function tensorcontract!(C::AbstractArray, dimcheck_tensorcontract(C, A, pA, B, pB, pAB) if conjA && conjB - _diagdiagcontract!(SV(C), conj(SV(A.diag)), pA, conj(SV(B.diag)), pB, pAB, α, β) + _diagdiagcontract!(wrap_stridedview(C), conj(wrap_stridedview(A.diag)), pA, + conj(wrap_stridedview(B.diag)), pB, pAB, α, β) elseif conjA - _diagdiagcontract!(SV(C), conj(SV(A.diag)), pA, SV(B.diag), pB, pAB, α, β) + _diagdiagcontract!(wrap_stridedview(C), conj(wrap_stridedview(A.diag)), pA, + wrap_stridedview(B.diag), pB, pAB, α, β) elseif conjB - _diagdiagcontract!(SV(C), SV(A.diag), pA, conj(SV(B.diag)), pB, pAB, α, β) + _diagdiagcontract!(wrap_stridedview(C), wrap_stridedview(A.diag), pA, + conj(wrap_stridedview(B.diag)), pB, pAB, α, β) else - _diagdiagcontract!(SV(C), SV(A.diag), pA, SV(B.diag), pB, pAB, α, β) + _diagdiagcontract!(wrap_stridedview(C), wrap_stridedview(A.diag), pA, + wrap_stridedview(B.diag), pB, pAB, α, β) end return C end diff --git a/src/implementation/strided.jl b/src/implementation/strided.jl index 4ead15e..dec0de4 100644 --- a/src/implementation/strided.jl +++ b/src/implementation/strided.jl @@ -38,16 +38,38 @@ end #------------------------------------------------------------------------------------------- # Force strided implementation on AbstractArray instances with Strided backend #------------------------------------------------------------------------------------------- -const SV = StridedView + +# we normalize the parent types here to avoid too many specializations +# this is allowed because we never return `parent(SV)`, so we can safely wrap anything +# that represents the same data +""" + wrap_stridedview(A::AbstractArray) + +Wrap any compatible array into a `StridedView` for the implementation. +Additionally, we normalize the parent types to avoid having to have too many specializations. +This is allowed because we never return `parent(SV)`, so we can safely wrap anything +that represents the same data. 
+""" +wrap_stridedview(A::AbstractArray) = StridedView(reshape(A, length(A)), + size(A), strides(A), 0, identity) +wrap_stridedview(A::StridedView) = A +@static if isdefined(Core, :Memory) + # For Arrays: we simply use the memory directly + # TODO: can we also do this for views? + wrap_stridedview(A::Array) = StridedView(A.ref.mem, size(A), strides(A), 0, identity) +end + function tensoradd!(C::AbstractArray, A::AbstractArray, pA::Index2Tuple, conjA::Bool, α::Number, β::Number, backend::StridedBackend, allocator=DefaultAllocator()) # resolve conj flags and absorb into StridedView constructor to avoid type instabilities later on if conjA - stridedtensoradd!(SV(C), conj(SV(A)), pA, α, β, backend, allocator) + stridedtensoradd!(wrap_stridedview(C), conj(wrap_stridedview(A)), pA, α, β, backend, + allocator) else - stridedtensoradd!(SV(C), SV(A), pA, α, β, backend, allocator) + stridedtensoradd!(wrap_stridedview(C), wrap_stridedview(A), pA, α, β, backend, + allocator) end return C end @@ -58,9 +80,11 @@ function tensortrace!(C::AbstractArray, backend::StridedBackend, allocator=DefaultAllocator()) # resolve conj flags and absorb into StridedView constructor to avoid type instabilities later on if conjA - stridedtensortrace!(SV(C), conj(SV(A)), p, q, α, β, backend, allocator) + stridedtensortrace!(wrap_stridedview(C), conj(wrap_stridedview(A)), p, q, α, β, + backend, allocator) else - stridedtensortrace!(SV(C), SV(A), p, q, α, β, backend, allocator) + stridedtensortrace!(wrap_stridedview(C), wrap_stridedview(A), p, q, α, β, backend, + allocator) end return C end @@ -73,16 +97,20 @@ function tensorcontract!(C::AbstractArray, backend::StridedBackend, allocator=DefaultAllocator()) # resolve conj flags and absorb into StridedView constructor to avoid type instabilities later on if conjA && conjB - stridedtensorcontract!(SV(C), conj(SV(A)), pA, conj(SV(B)), pB, pAB, α, β, + stridedtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(A)), pA, + conj(wrap_stridedview(B)), pB, pAB, α, β, backend, allocator) elseif conjA - stridedtensorcontract!(SV(C), conj(SV(A)), pA, SV(B), pB, pAB, α, β, + stridedtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(A)), pA, + wrap_stridedview(B), pB, pAB, α, β, backend, allocator) elseif conjB - stridedtensorcontract!(SV(C), SV(A), pA, conj(SV(B)), pB, pAB, α, β, + stridedtensorcontract!(wrap_stridedview(C), wrap_stridedview(A), pA, + conj(wrap_stridedview(B)), pB, pAB, α, β, backend, allocator) else - stridedtensorcontract!(SV(C), SV(A), pA, SV(B), pB, pAB, α, β, + stridedtensorcontract!(wrap_stridedview(C), wrap_stridedview(A), pA, + wrap_stridedview(B), pB, pAB, α, β, backend, allocator) end return C @@ -130,7 +158,7 @@ function stridedtensortrace!(C::StridedView, newstrides = (strideA.(linearize(p))..., (strideA.(q[1]) .+ strideA.(q[2]))...) newsize = (size(C)..., tracesize...) 
- A′ = SV(A.parent, newsize, newstrides, A.offset, A.op) + A′ = StridedView(A.parent, newsize, newstrides, A.offset, A.op) Strided._mapreducedim!(Scaler(α), Adder(), Scaler(β), newsize, (C, A′)) return C end From 53a6a70b78cbebc4c04f1cd3de25be728201064a Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 4 Mar 2025 12:38:37 -0500 Subject: [PATCH 08/12] small fixes --- src/implementation/blascontract.jl | 4 ++-- src/implementation/strided.jl | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/implementation/blascontract.jl b/src/implementation/blascontract.jl index bbf0ba1..271413e 100644 --- a/src/implementation/blascontract.jl +++ b/src/implementation/blascontract.jl @@ -60,8 +60,8 @@ function _blas_contract!(C, A, pA, B, pB, pAB, α, β, backend, allocator) tensoradd!(C, C_, pAB, false, α, β, backend, allocator) tensorfree!(C_, allocator) end - flagA || tensorfree!(A_.parent, allocator) - flagB || tensorfree!(B_.parent, allocator) + flagA || tensorfree!(A_, allocator) + flagB || tensorfree!(B_, allocator) return C end diff --git a/src/implementation/strided.jl b/src/implementation/strided.jl index dec0de4..a1a9abd 100644 --- a/src/implementation/strided.jl +++ b/src/implementation/strided.jl @@ -50,9 +50,7 @@ Additionally, we normalize the parent types to avoid having to have too many spe This is allowed because we never return `parent(SV)`, so we can safely wrap anything that represents the same data. """ -wrap_stridedview(A::AbstractArray) = StridedView(reshape(A, length(A)), - size(A), strides(A), 0, identity) -wrap_stridedview(A::StridedView) = A +wrap_stridedview(A::AbstractArray) = StridedView(A) @static if isdefined(Core, :Memory) # For Arrays: we simply use the memory directly # TODO: can we also do this for views? From f389aa331912e8c43304c9975355fd2e0789da73 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 4 Mar 2025 12:45:35 -0500 Subject: [PATCH 09/12] Remove `@inline` for `makeblascontractable` --- src/implementation/blascontract.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/implementation/blascontract.jl b/src/implementation/blascontract.jl index 271413e..c1f549c 100644 --- a/src/implementation/blascontract.jl +++ b/src/implementation/blascontract.jl @@ -86,7 +86,7 @@ function _unsafe_blas_contract!(C::StridedView{T}, return C end -@inline function makeblascontractable(A, pA, TC, backend, allocator) +function makeblascontractable(A, pA, TC, backend, allocator) flagA = isblascontractable(A, pA) && eltype(A) == TC if !flagA A_ = tensoralloc_add(TC, A, pA, false, Val(true), allocator) From 72f9dd95775295f8789ae8a7e76da1826c1560b5 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 4 Mar 2025 12:49:53 -0500 Subject: [PATCH 10/12] Sprinkle `@constprop :none` --- src/implementation/strided.jl | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/implementation/strided.jl b/src/implementation/strided.jl index a1a9abd..eed2452 100644 --- a/src/implementation/strided.jl +++ b/src/implementation/strided.jl @@ -57,10 +57,11 @@ wrap_stridedview(A::AbstractArray) = StridedView(A) wrap_stridedview(A::Array) = StridedView(A.ref.mem, size(A), strides(A), 0, identity) end -function tensoradd!(C::AbstractArray, - A::AbstractArray, pA::Index2Tuple, conjA::Bool, - α::Number, β::Number, - backend::StridedBackend, allocator=DefaultAllocator()) +Base.@constprop :none function tensoradd!(C::AbstractArray, + A::AbstractArray, pA::Index2Tuple, conjA::Bool, + α::Number, β::Number, + 
backend::StridedBackend, + allocator=DefaultAllocator()) # resolve conj flags and absorb into StridedView constructor to avoid type instabilities later on if conjA stridedtensoradd!(wrap_stridedview(C), conj(wrap_stridedview(A)), pA, α, β, backend, @@ -72,10 +73,12 @@ function tensoradd!(C::AbstractArray, return C end -function tensortrace!(C::AbstractArray, - A::AbstractArray, p::Index2Tuple, q::Index2Tuple, conjA::Bool, - α::Number, β::Number, - backend::StridedBackend, allocator=DefaultAllocator()) +Base.@constprop :none function tensortrace!(C::AbstractArray, + A::AbstractArray, p::Index2Tuple, + q::Index2Tuple, conjA::Bool, + α::Number, β::Number, + backend::StridedBackend, + allocator=DefaultAllocator()) # resolve conj flags and absorb into StridedView constructor to avoid type instabilities later on if conjA stridedtensortrace!(wrap_stridedview(C), conj(wrap_stridedview(A)), p, q, α, β, @@ -87,12 +90,15 @@ function tensortrace!(C::AbstractArray, return C end -function tensorcontract!(C::AbstractArray, - A::AbstractArray, pA::Index2Tuple, conjA::Bool, - B::AbstractArray, pB::Index2Tuple, conjB::Bool, - pAB::Index2Tuple, - α::Number, β::Number, - backend::StridedBackend, allocator=DefaultAllocator()) +Base.@constprop :none function tensorcontract!(C::AbstractArray, + A::AbstractArray, pA::Index2Tuple, + conjA::Bool, + B::AbstractArray, pB::Index2Tuple, + conjB::Bool, + pAB::Index2Tuple, + α::Number, β::Number, + backend::StridedBackend, + allocator=DefaultAllocator()) # resolve conj flags and absorb into StridedView constructor to avoid type instabilities later on if conjA && conjB stridedtensorcontract!(wrap_stridedview(C), conj(wrap_stridedview(A)), pA, From e1ac592af0422c5780f8eac645c52643961c1e46 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 4 Mar 2025 14:43:02 -0500 Subject: [PATCH 11/12] Make bumper extension LTS compatible --- ext/TensorOperationsBumperExt.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ext/TensorOperationsBumperExt.jl b/ext/TensorOperationsBumperExt.jl index 0cae1ad..fbc8cef 100644 --- a/ext/TensorOperationsBumperExt.jl +++ b/ext/TensorOperationsBumperExt.jl @@ -6,9 +6,11 @@ using Bumper # Hack to normalize StridedView type to avoid too many specializations # This is allowed because bumper ensures that the pointer won't be GC'd # and we never return `parent(SV)` anyways. 
-function TensorOperations.wrap_stridedview(A::Bumper.UnsafeArray) - mem_A = Base.unsafe_wrap(Memory{eltype(A)}, pointer(A), length(A)) - return TensorOperations.StridedView(mem_A, size(A), strides(A), 0, identity) +@static if isdefined(Core, :Memory) + function TensorOperations.wrap_stridedview(A::Bumper.UnsafeArray) + mem_A = Base.unsafe_wrap(Memory{eltype(A)}, pointer(A), length(A)) + return TensorOperations.StridedView(mem_A, size(A), strides(A), 0, identity) + end end function TensorOperations.tensoralloc(::Type{A}, structure, ::Val{istemp}, From 7721be6e04c68bc5b2859429f49d13acce0a2967 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 4 Mar 2025 14:46:01 -0500 Subject: [PATCH 12/12] Remove docstring --- src/implementation/strided.jl | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/src/implementation/strided.jl b/src/implementation/strided.jl index eed2452..3aa347a 100644 --- a/src/implementation/strided.jl +++ b/src/implementation/strided.jl @@ -39,17 +39,10 @@ end # Force strided implementation on AbstractArray instances with Strided backend #------------------------------------------------------------------------------------------- -# we normalize the parent types here to avoid too many specializations -# this is allowed because we never return `parent(SV)`, so we can safely wrap anything -# that represents the same data -""" - wrap_stridedview(A::AbstractArray) - -Wrap any compatible array into a `StridedView` for the implementation. -Additionally, we normalize the parent types to avoid having to have too many specializations. -This is allowed because we never return `parent(SV)`, so we can safely wrap anything -that represents the same data. -""" +# Wrap any compatible array into a `StridedView` for the implementation. +# Additionally, we normalize the parent types to avoid having to have too many specializations. +# This is allowed because we never return `parent(SV)`, so we can safely wrap anything +# that represents the same data. wrap_stridedview(A::AbstractArray) = StridedView(A) @static if isdefined(Core, :Memory) # For Arrays: we simply use the memory directly
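To make the effect of the `wrap_stridedview` normalization in patches 07–12 concrete, here is a small editor's sketch; `wrap_stridedview` is an internal, unexported helper introduced by these patches, so this is for exposition only and not public API.

```julia
using TensorOperations

A = rand(4, 4)
sv = TensorOperations.wrap_stridedview(A)   # internal helper from the patches above

# Same data as A, but on Julia 1.11+ (where Core.Memory exists) the parent is the
# Array's Memory buffer rather than the Array itself, so kernels compiled for this
# StridedView type can be reused across many different wrapper types.
@show typeof(sv)            # e.g. StridedView{Float64, 2, Memory{Float64}, typeof(identity)}
@show sv[2, 3] == A[2, 3]   # true
```

The same reasoning underlies the Bumper extension hack in patches 07 and 11: the allocator guarantees the pointer outlives the operation and `parent(SV)` is never returned, so an `UnsafeArray` can safely be rewrapped as a `Memory`-backed `StridedView`.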