Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Unet configurable #26

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
c8c8fc4
Added version compatibility
RainerHeintzmann Dec 15, 2021
5b5c6e1
WIP configurable UNet
neptunes5thmoon Dec 16, 2021
ffd9a82
add example with configured UNet
neptunes5thmoon Dec 16, 2021
bd60834
bug fixes
RainerHeintzmann Dec 16, 2021
a4cfa13
bug fixes
RainerHeintzmann Dec 16, 2021
8b91545
added pretty print and removed debug prints.
RainerHeintzmann Dec 17, 2021
91cfb2d
bug fixes in show(). removed flux_test.jl
RainerHeintzmann Dec 17, 2021
74ffb5b
bug fixes in noise2noise and deconvolve
RainerHeintzmann Dec 20, 2021
82a736b
reduced iterations in noise2noise example
RainerHeintzmann Dec 22, 2021
338797f
clarify UNet pretty print
neptunes5thmoon Jan 4, 2022
3c65ea1
emphasize U-structure in pretty print
neptunes5thmoon Jan 5, 2022
f15f8c8
feat: add implementation for valid padding
neptunes5thmoon Mar 15, 2022
0b92b57
Merge branch 'master' of github.com:DhairyaLGandhi/UNet.jl
neptunes5thmoon Apr 1, 2022
d5d4c8e
keep examples dependencies separate
neptunes5thmoon Apr 13, 2022
467d5bf
add Flux back
neptunes5thmoon Apr 13, 2022
6179fcf
add parametrization
neptunes5thmoon Apr 21, 2022
fc194a5
declutter print
neptunes5thmoon Apr 21, 2022
1bc3d04
only conv chains should share type
neptunes5thmoon Apr 21, 2022
e547dfc
separate conv chain types
neptunes5thmoon Apr 21, 2022
a6b516f
Merge branch 'DhairyaLGandhi:master' into master
RainerHeintzmann Jun 13, 2024
c74ec86
updated version numbers of packages. Adapted noise2noise exampel
RainerHeintzmann Jun 13, 2024
4fdd9ba
updated deconvoltion.jl
RainerHeintzmann Jun 13, 2024
597cb4f
bug fixes and chages according to comments
RainerHeintzmann Jun 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
FourierTools = "b18b359b-aebc-45ac-a139-9c0ccbb2871e"
ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
IndexFunArrays = "613c443e-d742-454e-bfc6-1d7f8dd76566"
Noise = "81d43f40-5267-43b7-ae1c-8b967f377efa"
NDTools = "98581153-e998-4eef-8d0d-5ec2c052313d"
TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"
View5D = "90d841e0-6953-4e90-9f3a-43681da8e949"
42 changes: 42 additions & 0 deletions examples/deconvolve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Example using U-net to deconvolve an image

using UNet, Flux, TestImages, View5D, Noise, NDTools, FourierTools, IndexFunArrays

img = 100f0 .* Float32.(testimage("resolution_test_512"))

u = Unet();

u = gpu(u);
function loss(x, y)
return Flux.mse(u(x),y)
end
opt = Momentum()

# selects a tile at a random (default) or predifined (ctr) position returning tile and center.
function get_tile(img, tile_size=(128,128), ctr = (rand(tile_size[1]÷2:size(img,1)-tile_size[1]÷2),rand(tile_size[2]÷2:size(img,2)-tile_size[2]÷2)) )
return select_region(img,new_size=tile_size, center=ctr), ctr
end

R_max = 70;
sz = size(img); psf = abs2.(ift(disc(Float32, sz, R_max))); psf ./= sum(psf); conv_img = conv_psf(img,psf);

scale = 0.5f0/maximum(conv_img)
patch = (128,128)
for n in 1:2000
println("Iteration: $n")
myimg, pos = get_tile(conv_img,patch)
# image to denoise
# nimg1 = gpu(reshape(scale .* myimg,(size(myimg)...,1,1))); # gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
nimg1 = gpu(scale.*reshape(poisson(Float64.(myimg)),(size(myimg)...,1,1)))
# goal image (with noise)
pimg, pos = get_tile(img,patch,pos)
pimg = gpu(scale.*reshape(pimg,(size(myimg)...,1,1)))
rep = Iterators.repeated((nimg1, pimg), 4);
Flux.train!(loss, Flux.params(u), rep, opt)
end

# apply the net to the whole image instead:
nimg = gpu(scale .* reshape(conv_img,(size(conv_img)...,1,1))); # gpu(scale.*reshape(poisson(conv_img),(size(conv_img)...,1,1)))
nimg2 = gpu(scale.*reshape(poisson(Float64.(conv_img)),(size(conv_img)...,1,1)))
# display the images using View5D
@ve img nimg u(nimg) nimg2 u(nimg2)
38 changes: 38 additions & 0 deletions examples/noise2noise.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Example using U-net for a noise2noise problem

using UNet, Flux, TestImages, View5D, Noise, NDTools

img = 10.0 .* Float32.(testimage("resolution_test_512"))

u = Unet();

u = gpu(u);
function loss(x, y)
# return mean(abs2.(u(x) .-y))
return Flux.mse(u(x),y)
end
opt = Momentum()

# selects a tile at a random (default) or predifined (ctr) position returning tile and center.
function get_tile(img, tile_size=(128,128), ctr = (rand(tile_size[1]÷2:size(img,1)-tile_size[1]÷2),rand(tile_size[2]÷2:size(img,2)-tile_size[2]÷2)) )
return select_region(img,new_size=tile_size, center=ctr), ctr
end

sz = size(img);
scale = 0.5/maximum(img)
patch = (128,128)
for n in 1:100
println("Iteration: $n")
myimg, pos = get_tile(img,patch)
# image to denoise
nimg1 = gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
# goal image (with noise)
nimg2 = gpu(scale.*reshape(poisson(myimg),(size(myimg)...,1,1)))
rep = Iterators.repeated((nimg1, nimg2), 1);
Flux.train!(loss, Flux.params(u), rep, opt)
end

# apply the net to the whole image instead:
nimg = gpu(scale.*reshape(poisson(img),(size(img)...,1,1)));
# display the images using View5D
@ve img nimg u(nimg)
Loading