Skip to content

Commit

Permalink
Merge pull request #1 from taraspiotr/tensor
Browse files Browse the repository at this point in the history
tensor
  • Loading branch information
taraspiotr authored Oct 25, 2020
2 parents a9930a7 + 2c8f70c commit f58a49e
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 15 deletions.
9 changes: 7 additions & 2 deletions src/DeepJulia.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ module DeepJulia
export Device,
gpu,
cpu,
ArrayOrCuArray,

# tensor

Tensor,
to,
device,

# variable

Expand Down Expand Up @@ -44,7 +48,8 @@ batchify,
FashionMNIST,
shuffle!

include("cuda.jl")
include("device.jl")
include("tensor.jl")
include("variable.jl")
include("loss.jl")
include("modules.jl")
Expand Down
4 changes: 4 additions & 0 deletions src/device.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
@enum Device begin
cpu
gpu
end
4 changes: 2 additions & 2 deletions src/modules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ mutable struct LinearLayer <: NNModule
W::Variable
b::Variable

input::ArrayOrCuArray
input::Tensor

LinearLayer(W::ArrayOrCuArray, b::ArrayOrCuArray) = new(
LinearLayer(W::Tensor, b::Tensor) = new(
Variable(W),
Variable(b),
Matrix{Real}(undef, 0, 0),
Expand Down
2 changes: 1 addition & 1 deletion src/optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ struct SGD <: Optimizer
params::Vector{Variable}
lr::Real
momentum::Real
velocities::Vector{ArrayOrCuArray}
velocities::Vector{Tensor}

SGD(params, lr) = new(params, lr, 0.0, Vector{Matrix{Real}}(undef, size(params, 1)))
SGD(params, lr, momentum) = new(params, lr, momentum, [to(zeros(size(p.values)), device(p.values)) for p params])
Expand Down
11 changes: 3 additions & 8 deletions src/cuda.jl → src/tensor.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
using CUDA: CuArray

@enum Device begin
cpu
gpu
end

ArrayOrCuArray = Union{Array,CuArray}
Tensor = Union{Array,CuArray}

function to(A::ArrayOrCuArray, device::Device)
function to(A::Tensor, device::Device)
if device == cpu
return Array(A)
elseif device == gpu
Expand All @@ -17,4 +12,4 @@ function to(A::ArrayOrCuArray, device::Device)
end
end

device(A::ArrayOrCuArray) = isa(A, CuArray) ? gpu : cpu
device(A::Tensor) = isa(A, CuArray) ? gpu : cpu
4 changes: 2 additions & 2 deletions src/variable.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
mutable struct Variable
values::ArrayOrCuArray
grad::ArrayOrCuArray
values::Tensor
grad::Tensor

Variable(values) = new(values, to(zeros(size(values)), device(values)))
Variable(values, grad) = new(values, grad)
Expand Down

0 comments on commit f58a49e

Please sign in to comment.