Skip to content

Commit

Permalink
Iterate on multi-gpu design
Browse files Browse the repository at this point in the history
Another iteration on the multi-gpu prototype and testing script.  Still
seems to be some problem running clfft on multiple threads.  Perhaps the
contexts are not set up properly in the low level code.  Going to push
this though and retry on other machines.
  • Loading branch information
bnorthan committed Dec 25, 2023
1 parent c772db6 commit 8f55124
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 14 deletions.
35 changes: 24 additions & 11 deletions python/clij2fft/richardson_lucy_dask_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pyopencl as cl
from clij2fft.pad import get_next_smooth
from clij2fft.libs import getlib

bytes_per_gb = 1024 * 1024 * 1024

Expand Down Expand Up @@ -148,9 +149,13 @@ def richardson_lucy_dask(img, psf, numiterations, regularizationfactor, non_circ
from multiprocessing import Pool, current_process, Queue
queue = Queue()

for i in range(1):
num_gpus = 1
for i in range(num_gpus):
queue.put(i)


lib = getlib()

#from dask.distributed import get_worker
if non_circulant:

Expand All @@ -161,17 +166,18 @@ def richardson_lucy_nc_dask_task(img, psf, numiterations, regularizationfactor=0
gpu_num=queue.get()
print('start rlnc')
print('gpu num is', gpu_num)
print('block id is', block_id)
print('block info is', block_info)
print('thread id is', thread_id)
#print('block id is', block_id)
#print('block info is', block_info)
#print('thread id is', thread_id)
#print('worker is ', get_worker())
if block_id is None:
print('returning block id is None')
return None
result=richardson_lucy_nc(img, psf, numiterations, regularizationfactor=regularizationfactor, lib=lib)
#if block_id is None:
# print('returning block id is None')
# return None
result=richardson_lucy_nc(img, psf, numiterations, regularizationfactor=regularizationfactor, lib=lib, platform = 0, device = gpu_num)
print('end rlnc')
return result
except:
except Exception as e:
print("EXCEPTION",e)
pass
finally:
print('putting gpu num back', gpu_num)
Expand All @@ -186,9 +192,16 @@ def richardson_lucy_dask_task(img, psf, numiterations, regularizationfactor=0, l
return img#richardson_lucy(img, psf, numiterations, regularizationfactor=regularizationfactor, lib=lib)
rl_func = richardson_lucy_dask_task

import time

out = dimg.map_overlap(rl_func, depth={0:0, 1:overlap, 2:overlap}, dtype=np.float32, psf=psf, numiterations=numiterations, regularizationfactor=regularizationfactor)
return out.compute(num_workers=4)
start_time = time.time()
out = dimg.map_overlap(rl_func, depth={0:0, 1:overlap, 2:overlap}, dtype=np.float32, psf=psf, numiterations=numiterations, regularizationfactor=regularizationfactor, lib=lib)
out_img = out.compute(num_workers=num_gpus)
end_time = time.time()
execution_time = end_time - start_time
print(f"Execution time of rl dask multi gpu: {execution_time} seconds")

return out_img



Expand Down
4 changes: 2 additions & 2 deletions python/clij2fft/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
img= np.ones((256, 256, 128), dtype=np.float32)
psf = np.ones((128, 128, 64), dtype=np.float32)

result = richardson_lucy(img, psf, 100, 0, platform=0, device=0)
result = richardson_lucy(img, psf, 100, 0, platform=1, device=0)
result = richardson_lucy(img, psf, 100, 0, platform=0, device=3)
#result = richardson_lucy(img, psf, 100, 0, platform=1, device=0)

print()
print(result.shape, result.mean())
6 changes: 5 additions & 1 deletion python/clij2fft/test_richardson_lucy_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
import numpy as np
from matplotlib import pyplot as plt


img_name=r'D:\\images/images/Bars-G10-P15-stack-cropped.tif'
psf_name=r'D:\\images/images/PSF-Bars-stack-cropped.tif'

img_name=r'/home/bnorthan/images/deconvolution/Bars-G10-P15-stack.tif'
psf_name=r'/home/bnorthan/images/deconvolution/PSF-Bars-stack.tif'

img_name = r'C:\Users\Administrator\data\Bars-G10-P15-stack.tif'
psf_name = r'C:\Users\Administrator\data\PSF-Bars-stack.tif'

img=imread(img_name)
print('image shape is',img.shape)

Expand All @@ -21,7 +25,7 @@
print('image shape is',img.shape)
psf=imread(psf_name)

decon=richardson_lucy_dask(img, psf, 10, 0.0001, mem_to_use=mem_to_use)
decon=richardson_lucy_dask(img, psf, 50, 0.0001, mem_to_use=mem_to_use)

fig, ax = plt.subplots(1,2)
ax[0].imshow(img.max(axis=0))
Expand Down

0 comments on commit 8f55124

Please sign in to comment.