-
Notifications
You must be signed in to change notification settings - Fork 217
/
Copy pathArchitectures.jl
33 lines (22 loc) · 1.11 KB
/
Architectures.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
module Architectures
using Reactant
using Oceananigans
import Oceananigans.Architectures: device, architecture, array_type, on_architecture
import Oceananigans.Architectures: unified_array, ReactantState, device_copy_to!
const ReactantKernelAbstractionsExt = Base.get_extension(
Reactant, :ReactantKernelAbstractionsExt
)
const ReactantBackend = ReactantKernelAbstractionsExt.ReactantBackend
device(::ReactantState) = ReactantBackend()
architecture(::Reactant.AnyConcreteRArray) = ReactantState
architecture(::Reactant.AnyTracedRArray) = ReactantState
array_type(::ReactantState) = ConcreteRArray
on_architecture(::ReactantState, a::Array) = ConcreteRArray(a)
on_architecture(::ReactantState, a::Reactant.AnyConcreteRArray) = a
on_architecture(::ReactantState, a::Reactant.AnyTracedRArray) = a
on_architecture(::ReactantState, a::BitArray) = ConcreteRArray(a)
on_architecture(::ReactantState, a::SubArray{<:Any, <:Any, <:Array}) = ConcreteRArray(a)
unified_array(::ReactantState, a) = a
@inline device_copy_to!(dst::Reactant.AnyConcreteRArray, src::Reactant.AnyConcreteRArray; kw...) =
Base.copyto!(dst, src)
end # module