Add naive Base julia AbstractArray implementation #171

Merged
merged 13 commits into from
Jun 23, 2024

lkdvos
Collaborator

@lkdvos lkdvos commented Apr 20, 2024

This is meant to avoid using the Strided-based backends for types that do not support them. For now, it only dispatches to Strided for StridedArrays, and uses a naive version based on permutedims and VectorInterface for the other AbstractArray subtypes.
My preliminary tests indicate that this works for things like FillArrays, so this should at least solve the long-outstanding annoyance where Zygote generates these array types when calling sum on an array, which would then break in TensorOperations. (See also #169.)

TODO

  • tensortrace implementation
  • add some decent testing
  • add some documentation
  • think about centralizing backend selection procedure in separate functions
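
The dispatch idea described in this comment can be sketched roughly as follows. This is a minimal illustration using only Base, not the code from this PR: the names `StridedBackendSketch`, `BaseBackendSketch`, `select_backend_sketch` and `naive_add_sketch!` are hypothetical, and a range-backed reshaped array stands in for a non-strided type such as those from FillArrays.

```julia
# Hypothetical backend markers; the names are illustrative only.
struct StridedBackendSketch end
struct BaseBackendSketch end

# Only members of Base's StridedArray union take the strided path.
select_backend_sketch(::StridedArray) = StridedBackendSketch()
select_backend_sketch(::AbstractArray) = BaseBackendSketch()

# Naive fallback: C = β * C + α * permutedims(A, perm), using only Base.
function naive_add_sketch!(C::AbstractArray, A::AbstractArray, perm,
                           α::Number, β::Number)
    Ap = permutedims(A, perm)      # allocates a permuted copy of A
    if iszero(β)
        C .= α .* Ap               # do not read the old contents of C
    else
        C .= β .* C .+ α .* Ap
    end
    return C
end

A = reshape(1:6, 2, 3)             # range-backed, so not a StridedArray
C = zeros(3, 2)
naive_add_sketch!(C, A, (2, 1), 2.0, 0.0)
```

The fallback trades performance for generality: it only assumes that `permutedims` and broadcasting work for the input type.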

@Jutho
Owner

Jutho commented Apr 24, 2024

I guess that this will have the effect that certain Julia arrays which are strided but are not part of the StridedArray union will no longer go via the StridedBLAS/StridedNative implementation?

Strided(Views).jl should perhaps have an isstrided check that can be applied to AbstractArray and works in a recursive manner, in the same way that the StridedView constructor works.
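
A recursive check of the kind suggested here might look like the sketch below. It is an assumption about the shape such a function could take, not an existing Strided.jl or StridedViews.jl API; the name `isstrided_sketch` is hypothetical and only a few Base wrapper types are covered.

```julia
# Hypothetical recursive check; not the actual StridedViews.jl API.
isstrided_sketch(::AbstractArray) = false      # conservative default
isstrided_sketch(::Array) = true               # dense storage is strided

# Wrappers delegate to their parent array.
isstrided_sketch(A::PermutedDimsArray) = isstrided_sketch(parent(A))

function isstrided_sketch(A::SubArray)
    # Integer and range indices keep constant strides; index vectors do not.
    ok = all(i -> i isa Union{Integer,AbstractRange{<:Integer}}, parentindices(A))
    return ok && isstrided_sketch(parent(A))
end

M = rand(4, 5, 6)
nested = PermutedDimsArray(view(M, 1:2:4, :, 2), (2, 1))  # wrapper of a wrapper
gathered = view(M, [1, 3], :, :)                          # vector index
```

Here `nested` is recognized as strided even though its type is not part of the StridedArray union, which is the case the comment above is concerned with.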

@Jutho
Owner

Jutho commented Jun 13, 2024

Ok I've changed and finished the base implementations. Had to do some force pushing to fix a divergence :-).

Probably the base implementations need a test. I would also prefer not relying on the StridedArray union to decide on which implementation to use, as it excludes a number of cases. I remember discussing adding an isstrided check in StridedViews.jl that follows the same recursive logic as the StridedView constructor. Then we can use that to decide on which implementation to use.

@lkdvos
Collaborator Author

lkdvos commented Jun 13, 2024

Maybe we can just add a trait for this in Strided.jl, define some defaults for reasonable types and allow people to opt in?
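
One way to picture such an opt-in trait is the sketch below. All names (`StridedTraitSketch`, `IsStridedSketch`, `NotStridedSketch`, `stridedtrait_sketch`, `MyDenseWrapper`, `describe_path`) are hypothetical and chosen for illustration; nothing here is taken from Strided.jl.

```julia
# Hypothetical trait types; not part of Strided.jl.
abstract type StridedTraitSketch end
struct IsStridedSketch <: StridedTraitSketch end
struct NotStridedSketch <: StridedTraitSketch end

# Defaults: conservative for unknown types, opt-in for known dense ones.
stridedtrait_sketch(::Type{<:AbstractArray}) = NotStridedSketch()
stridedtrait_sketch(::Type{<:Array}) = IsStridedSketch()
stridedtrait_sketch(A::AbstractArray) = stridedtrait_sketch(typeof(A))

# A user-defined array type backed by dense storage.
struct MyDenseWrapper{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(A::MyDenseWrapper) = size(A.data)
Base.getindex(A::MyDenseWrapper, I::Int...) = A.data[I...]

# The package author opts in with a single method.
stridedtrait_sketch(::Type{<:MyDenseWrapper}) = IsStridedSketch()

# Implementation selection dispatches on the trait value.
describe_path(A::AbstractArray) = describe_path(stridedtrait_sketch(A), A)
describe_path(::IsStridedSketch, A) = :strided
describe_path(::NotStridedSketch, A) = :base
```

With this layout, unknown types fall through to the base path unless their author explicitly declares them strided.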

 function tensoradd!(C::AbstractArray,
                     A::AbstractArray, pA::Index2Tuple, conjA::Bool,
-                    α::Number, β::Number)
+                    α::Number, β::Number, ::BaseCopy)
     argcheck_tensoradd(C, A, pA)
     dimcheck_tensoradd(C, A, pA)

     # can we assume that C is mutable?
     # is there more functionality in base that we can use?
Collaborator Author

I think there is an in-place permutedims!, so we could also hook into the allocation interface to allocate these temporary arrays.
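
As a rough illustration of that suggestion, Base's `permutedims!` writes into a destination supplied by the caller. In the sketch below `similar` stands in for whatever allocator would provide the temporary; the variable names are illustrative and this is not the code from the PR.

```julia
A = reshape(1:24, 2, 3, 4)        # range-backed source, not a StridedArray
perm = (3, 1, 2)

# The temporary could come from any allocator; `similar` stands in for it here.
buf = similar(A, map(i -> size(A, i), perm))

# Fill the preallocated buffer instead of letting permutedims allocate its own result.
permutedims!(buf, A, perm)
```

The buffer can then be reused across calls, which is the part an allocation interface would manage.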

Owner

That might be possible indeed. The question is whether it is worth it. This will probably only be used for types which are very different from strided arrays, e.g. sparse arrays.

Collaborator Author

I think that's a very valid point. It's probably fair to assume that we cannot guarantee optimal performance without extra information about the specific type anyway, and the system makes it easy to implement custom backends if necessary. The base backend should serve mostly as a catch-all implementation that ensures things work for most types.

Comment on lines +202 to +209
if iszero(β)
    C .= α .* conj.(reshape(view(Ã, :, 1, 1), so))
else
    C .= β .* C .+ α .* conj.(reshape(view(Ã, :, 1, 1), so))
end
for i in 2:st
    C .+= α .* conj.(reshape(view(Ã, :, i, i), so))
end
Collaborator Author

can we rewrite this with sum?

Owner

I don't know? In a way that does not cause additional allocations? Is there an issue with the current approach?

Collaborator Author

I wrote this in reply to the code comment "# is there more functionality in base that we can use?". I don't think there is any issue with the current approach.
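
For comparison, a `sum`-based variant of the loop discussed in this thread could be sketched as below. The function names `trace_loop_sketch!` and `trace_sum_sketch!`, the sizes, and the test data are all hypothetical. Note that this `sum` formulation builds intermediate arrays for the diagonal slices, so it does not avoid the extra allocations raised above.

```julia
# Loop version, following the snippet under review.
function trace_loop_sketch!(C, Ã, so, st, α, β)
    if iszero(β)
        C .= α .* conj.(reshape(view(Ã, :, 1, 1), so))
    else
        C .= β .* C .+ α .* conj.(reshape(view(Ã, :, 1, 1), so))
    end
    for i in 2:st
        C .+= α .* conj.(reshape(view(Ã, :, i, i), so))
    end
    return C
end

# sum-based version: shorter, but allocates temporaries for each slice.
function trace_sum_sketch!(C, Ã, so, st, α, β)
    diagsum = sum(i -> conj.(reshape(view(Ã, :, i, i), so)), 1:st)
    if iszero(β)
        C .= α .* diagsum
    else
        C .= β .* C .+ α .* diagsum
    end
    return C
end

so, st = (4,), 3                                     # hypothetical open size and traced dimension
Ã = reshape(complex.(1.0:36.0, 36.0:-1.0:1.0), 4, st, st)
C1 = ones(ComplexF64, 4)
C2 = copy(C1)
trace_loop_sketch!(C1, Ã, so, st, 2.0, 0.5)
trace_sum_sketch!(C2, Ã, so, st, 2.0, 0.5)
```

Both functions compute the same result; the loop updates `C` in place slice by slice, which is why it can avoid the intermediate array.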

@Jutho
Copy link
Owner

Jutho commented Jun 21, 2024

Ok, with these tests included, I think this can be merged, unless there are further comments or suggestions. I will include testing the different backends also in the macro-based tests, but this will first require the other backend specifying syntax.

@Jutho Jutho marked this pull request as ready for review June 22, 2024 12:18
Collaborator Author

@lkdvos lkdvos left a comment

Good to go for me!
(can't approve my own PR)

@Jutho Jutho merged commit fabfb08 into master Jun 23, 2024
15 checks passed
@lkdvos lkdvos deleted the base-backend branch June 23, 2024 22:20