Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Network modelling prototype 1 #9

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
Manifest.toml
test_*.png
test/data/*
9 changes: 7 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@ authors = ["dhairyagandhi <[email protected]>"]
version = "0.1.0"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

Expand All @@ -19,7 +25,6 @@ FileIO = "1"
Flux = "0.10, 0.11"
ImageCore = "0.8"
ImageTransformations = "0.8"
Reexport = "0"
StatsBase = "0"
julia = "1.3"

Expand Down
20 changes: 13 additions & 7 deletions src/UNet.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
module UNet

export Unet, bce, load_img, load_batch

using Reexport

using StatsBase
using Flux
using Flux: @functor
using Flux.Data: DataLoader
using Flux: logitcrossentropy, dice_coeff_loss

using Images
using ImageCore
using ImageTransformations: imresize
using FileIO
using Distributions: Normal

@reexport using Statistics
@reexport using Flux, Flux.Zygote, Flux.Optimise
using Serialization
using ForwardDiff
using Parameters: @with_kw
using CUDAapi
using CUDA

include("defaults.jl")
include("utils.jl")
include("dataloader.jl")
include("model.jl")
include("train.jl")

export Unet, train

end # module
232 changes: 162 additions & 70 deletions src/dataloader.jl
Original file line number Diff line number Diff line change
@@ -1,75 +1,167 @@
using StatsBase: sample, shuffle

"""
stub -> <one of the dirs>/<img path> - _im/cr.tif
"""
function load_img(base::String, stub::String; rsize = (256,256))
im = joinpath(base, stub * "_im.tif")
cr = joinpath(base, stub * "_cr.tif")
x, y = load(im), load(cr)
x = imresize(x, rsize...)
y = imresize(y, rsize...)
x, y = channelview(x), channelview(y)
x = reshape(x, rsize..., 1, 1)
y = reshape(y, rsize..., 1, 1)
x, y

"""
    Dataset

Mutable description of a paired image dataset: an input directory and a
target directory whose files correspond one-to-one. Files are indexed by
dense integer keys so they can be sampled and dropped as they are processed.

Populated by [`initialize_dataset`](@ref); the `ini` field stores that
initializer so callers can re-run it.
"""
@with_kw mutable struct Dataset

    name::String = ""

    # Input images.
    input_directory::String = "/input"
    input_prefix::String = "input_"              # filename prefix stripped when matching against targets
    input_files::Dict{Int,String} = Dict{Int,String}()
    input_files_keys::Vector{Int} = Int[]        # sorted keys of `input_files`, used as the sampling pool
    input_num_files::Int = 0

    # Target (label) images; mirrors the input fields.
    target_directory::String = "/target"
    target_prefix::String = "target_"
    target_files::Dict{Int,String} = Dict{Int,String}()
    target_files_keys::Vector{Int} = Int[]
    target_num_files::Int = 0

    # Initializer used to populate the fields above.
    ini::Function = initialize_dataset

end

"""
`load_batch(base::String, name_template, target_template; n = 10, dir = nothing, rsize = (500,500))`

`base`: Path to the base directory where the training dataset is stored

templates: `name_template`(::String) filters the images in the path based on whether they contain the `name_template`
`target_template`(::String) replaces the `name_template` with the `target_template` to look for corresponding masks

`dir`: if `nothing`, all the directories will be sampled at random, else will pick up images from the specified directory. Currently supports taking only one directory as a `String`.

Returns a batch of `n` randomly sampled images
and corresponding masks.
"""
function load_batch(base::String, name_template = "im",
target_template = "cr";
n = 10, channels = 1,
dir=nothing, rsize=(500,500))

# TODO: templates should support regexes
if dir isa Nothing
dir = setdiff(readdir(base), ["zips"])
dir = sample(dir)
end

imgs = filter(x -> occursin(name_template, x),
readdir(joinpath(base, dir)))
imgs = sample(imgs, n)
masks = map(x -> replace(x, name_template=>target_template), imgs)

# Ensure all the files exist, else try again, and error
if !all(isfile.(masks))
imgs = sample(imgs, n)
masks = map(x -> replace(x, name_template=>target_template), imgs)
end

batch = zip(imgs, masks)

x = zeros(Float32, rsize..., channels, n) # []
y = zeros(Float32, rsize..., channels, n) # []
for (i,(img, mask)) in enumerate(batch)
img = load(joinpath(base, dir, img))
mask = load(joinpath(base, dir, mask))
img = imresize(img, rsize...)
mask = imresize(mask, rsize...)
img = channelview(img)
mask = channelview(mask)
img = reshape(img, rsize..., channels)
mask = reshape(mask, rsize..., channels)

x′ = @view x[:,:,:,i]
x′ .= img
"""
    initialize_dataset(name; input_dir, target_dir)

Build a `Dataset` named `name` from the files found in `input_dir` and
`target_dir`. Every file in each directory is indexed by a dense 1-based
integer key, and the two directories are checked for consistency: they must
contain the same number of files, and after stripping the configured
prefixes every input file name must have a matching target file name.

Both keyword arguments are required (the original self-referential defaults
errored when omitted anyway).
"""
function initialize_dataset(name; input_dir::String, target_dir::String)

    dataset = Dataset()
    dataset.name = name

    # Index input files by a dense 1-based integer key.
    dataset.input_directory = input_dir
    input_files = readdir(dataset.input_directory)
    dataset.input_num_files = length(input_files)
    for (i, f) in enumerate(input_files)
        dataset.input_files[i] = joinpath(dataset.input_directory, f)
    end
    dataset.input_files_keys = sort(collect(keys(dataset.input_files)))

    # Index target files the same way.
    dataset.target_directory = target_dir
    target_files = readdir(dataset.target_directory)
    dataset.target_num_files = length(target_files)
    for (i, f) in enumerate(target_files)
        dataset.target_files[i] = joinpath(dataset.target_directory, f)
    end
    dataset.target_files_keys = sort(collect(keys(dataset.target_files)))

    # Check if the directories are in sync: same file count ...
    @assert dataset.input_num_files == dataset.target_num_files "Number of input files is different from number of target files! Input: $(dataset.input_num_files) Target:$(dataset.target_num_files)"

    # ... and, after stripping the prefixes, matching file names.
    input_files = string.(map(v -> replace(v, dataset.input_prefix => ""), input_files))
    target_files = string.(map(v -> replace(v, dataset.target_prefix => ""), target_files))
    @assert insync(input_files, target_files) "Input and target directories are not in sync. Make sure each input image has corresponding target image!"

    return dataset
end

"""
    insync(input_files, target_files) -> Bool

Return `true` when every entry of `input_files` also appears in
`target_files`, i.e. the input list is a subset of the target list.
An empty `input_files` is trivially in sync.
"""
function insync(input_files, target_files)
    # The original hand-rolled this membership loop; `all` expresses it directly.
    return all(f -> f in target_files, input_files)
end

"""
    grab_random_files(dataset::Dataset, num_files::Int; drop_processed = true)

Sample up to `num_files` (input, target) file-path pairs from `dataset`.

Sampling is with replacement, so the returned lists may contain duplicates
(as the original author noted). When `drop_processed` is `true`, every
sampled key is removed from the dataset so later calls cannot return the
same files again, and the bookkeeping fields (`*_num_files`,
`*_files_keys`) are refreshed.

Returns `(input_files, target_files)` — two equal-length `Vector{String}`s.
"""
function grab_random_files(dataset::Dataset, num_files::Int; drop_processed = true)

    # NOTE(review): `sample` draws with replacement here, so keys can repeat.
    idx = sample(dataset.input_files_keys, min(dataset.input_num_files, num_files))

    # Typed accumulators instead of Vector{Any}; the dicts hold String paths.
    input_files = String[]
    target_files = String[]
    for k in idx
        push!(input_files, dataset.input_files[k])
        push!(target_files, dataset.target_files[k])
    end

    if drop_processed
        # `unique` guards against popping a duplicated key twice.
        for k in unique(idx)
            pop!(dataset.input_files, k)
            pop!(dataset.target_files, k)
        end
    end

    dataset.input_num_files = length(dataset.input_files)
    dataset.input_files_keys = collect(keys(dataset.input_files))

    dataset.target_num_files = length(dataset.target_files)
    dataset.target_files_keys = collect(keys(dataset.target_files))

    return input_files, target_files

end

"""
    load_files(input_files::Array, target_files::Array)

Load input images and target masks from disk into dense arrays.

Returns `(data, onehotlabels, itargets)` where
- `data` is `(H, W, C, n_inputs)` pixel data in HWC-per-image layout,
- `onehotlabels` is `(H, W, nfeatures, n_targets)` weighted one-hot labels
  from [`target_to_onehot`](@ref),
- `itargets` is `(H, W, n_targets)` integer pixel intensities.

All input images must share one size, and likewise all target images.
`nfeatures` comes from defaults.jl. Per-file class weights are computed but
not returned — TODO(review) confirm that is intended.
"""
function load_files(input_files::Array, target_files::Array)

    nfiles = length(input_files)

    # Move the channel dimension last (CHW -> HWC) and size everything off
    # the first image.
    s = size(permutedims(channelview(load(input_files[1])), [2, 3, 1]))
    data = zeros(s[1], s[2], s[3], nfiles)

    for (i, file) in enumerate(input_files)
        img = permutedims(channelview(load(file)), [2, 3, 1])
        @assert s == size(img) "Input images are not of the same size. Please check!"
        data[:, :, :, i] = img
    end

    nfiles = length(target_files)
    s = size(channelview(load(target_files[1])))
    onehotlabels = zeros(Int8, s[1], s[2], nfeatures, nfiles)
    itargets = zeros(Int16, s[1], s[2], nfiles)
    weights = zeros(nfeatures, nfiles)
    for (i, file) in enumerate(target_files)
        target = channelview(load(file))
        # Validate the size BEFORE using the data (the original filled
        # `itargets` first), and report it as a target-image problem.
        @assert s == size(target) "Target images are not of the same size. Please check!"
        itargets[:, :, i] = Int16.(get_integer_intensity.(target))
        onehotlabels[:, :, :, i], weights[:, i] = target_to_onehot(target, nfeatures)
    end

    return data, onehotlabels, itargets
end

"""
    get_integer_intensity(value::Normed{UInt8,8})

Return the raw `UInt8` backing an 8-bit normed pixel value — the stored
integer intensity (0–255) rather than the normalized 0–1 view.
"""
get_integer_intensity(value::Normed{UInt8,8}) = value.i

"""
    target_to_onehot(target, nfeatures)

Convert a single-channel target mask into a weighted one-hot array.

Pixel intensities are mapped to class ids as `1 + intensity ÷ 30` —
assumes intensities are multiples of 30; TODO(review) confirm the label
encoding. Class weights are derived from inverse pixel frequency and
written into the one-hot slots (instead of 1) to balance rare classes.

Returns `(onehottarget, weights)`: `(H, W, nfeatures)` `Int32` labels and a
length-`nfeatures` weight vector.
"""
function target_to_onehot(target, nfeatures)

    s = size(target)
    onehottarget = zeros(Int32, s[1], s[2], nfeatures)
    itarget = 1 .+ Int8.(get_integer_intensity.(target) ./ 30)

    # Count pixels per class. BUG FIX: the inner loops previously iterated
    # `for j = 1:s[1]` (rows again) instead of `1:s[2]` (columns), which was
    # only correct for square images.
    weights = zeros(nfeatures)
    for i = 1:s[1]
        for j = 1:s[2]
            k = itarget[i, j, 1]
            weights[k] += 1
        end
    end

    # Inverse-frequency weighting, scaled to small integers.
    weights = weights / maximum(weights)
    weights = 1 .+ (1 .- weights) .* 100
    weights = Int8.(round.(weights))

    for i = 1:s[1]
        for j = 1:s[2]
            k = itarget[i, j, 1]
            # A plain one-hot would store 1 here; the class weight is stored
            # instead to balance things out downstream.
            onehottarget[i, j, k] = weights[k]
        end
    end

    return onehottarget, weights

end

2 changes: 2 additions & 0 deletions src/defaults.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Number of label classes — one-hot depth used by `target_to_onehot` and
# `load_files` in src/dataloader.jl.
const nfeatures = 7
# Number of image channels — presumably RGB; TODO(review) confirm against
# the data loader.
const nchannels = 3
Loading