-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmnist_test.jl
39 lines (28 loc) · 1 KB
/
mnist_test.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
using Flux, Flux.Data.MNIST, Statistics
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated, partition
using NNlib
# using CuArrays
# Classify MNIST digits with a convolutional network.

# Load the default MNIST image set and one-hot encode its labels over 0:9.
imgs = MNIST.images()
labels = onehotbatch(MNIST.labels(), 0:9)

# Walk the indices 1:60_000 in chunks of 1,000. Each chunk becomes one
# (images, labels) tuple: the chunk's images are converted with `float` and
# concatenated along the 4th dimension, and the matching label columns are
# sliced out of `labels`.
train = map(partition(1:60_000, 1000)) do idx
    (cat(float.(imgs[idx])..., dims = 4), labels[:, idx])
end
train = map(gpu, train)

# Held-out evaluation data: the first 1,000 samples of the `:test` split,
# prepared the same way as a training batch.
tX = gpu(cat(float.(MNIST.images(:test)[1:1000])..., dims = 4))
tY = gpu(onehotbatch(MNIST.labels(:test)[1:1000], 0:9))
# Network: two conv + max-pool stages, then a dense softmax classifier over
# 10 outputs.
m = gpu(Chain(
    Conv((2, 2), 1 => 16, relu),
    a -> maxpool(a, (2, 2)),
    Conv((2, 2), 16 => 8, relu),
    a -> maxpool(a, (2, 2)),
    # Collapse everything except the 4th (batch) dimension into one axis so
    # the result is a (features, batch) matrix for the dense layer.
    a -> reshape(a, :, size(a, 4)),
    # NOTE(review): 288 = 6 * 6 * 8, which is what the layers above yield for
    # 28x28 single-channel input if Conv defaults to stride 1 / no padding
    # -- confirm against the installed Flux version.
    Dense(288, 10),
    softmax,
))

# Push the first training batch through the model once; the output is
# discarded. Presumably a shape sanity check / warm-up before training.
m(train[1][1])
# Cross-entropy between the model's output for `x` and the targets `y`.
function loss(x, y)
    return crossentropy(m(x), y)
end

# Fraction of samples whose decoded (`onecold`) prediction equals the decoded
# target.
function accuracy(x, y)
    predicted = onecold(m(x))
    expected = onecold(y)
    return mean(predicted .== expected)
end

# Callback that prints the accuracy on (tX, tY); wrapped in `throttle` with a
# timeout argument of 10.
evalcb = throttle(10) do
    @show(accuracy(tX, tY))
end

# Optimiser constructed over the model's parameters.
opt = ADAM(params(m))

# Single pass over the batches in `train`, invoking `evalcb` as the callback.
Flux.train!(loss, train, opt; cb = evalcb)