diff --git a/HISTORY.rst b/HISTORY.rst index caacc03..e8b7a29 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -233,7 +233,8 @@ History * bug fixed in plt_utils * plt_imshow added to plt_utils -0.10.10 (2024-01-25) +0.10.10 (2024-01-30) ------------------ * rgb2hsv is added -* plt_imshow supports complex color map and is bug free \ No newline at end of file +* plt_imshow supports complex color map and is bug free +* added printprogress to loopprocessor \ No newline at end of file diff --git a/lognflow/loopprocessor.py b/lognflow/loopprocessor.py index 32abe74..88a374d 100644 --- a/lognflow/loopprocessor.py +++ b/lognflow/loopprocessor.py @@ -23,7 +23,7 @@ def _loopprocessor_function( class loopprocessor(): def __init__(self, targetFunction, n_cpu = None, test_mode = False, logger = print, - concatenate_outputs = True, verbose = True): + concatenate_outputs = True, verbose = True, n_processes = 0): self.targetFunction = targetFunction self.test_mode = test_mode self.aQ = Queue() @@ -33,9 +33,15 @@ def __init__(self, else: self.n_cpu = n_cpu self.verbose = verbose - if(self.verbose): + self.n_processes = n_processes + if self.verbose: self.logger = logger self.logger(f'lognflow loopprocessor initialized with {self.n_cpu} CPUs.') + if self.n_processes: + assert self.n_processes > 0 + assert self.n_processes == int(self.n_processes) + from .printprogress import printprogress + self.pBar = printprogress(self.n_processes) self.outputs_is_given = False self.outputs = [] @@ -68,6 +74,8 @@ def __call__(self, *args, **kwargs): ret_result = aQElement[1] if ((not self.any_error) & aQElement[2]): self.any_error = True + if self.n_processes: + del self.pBar self.empty_queue = True error_ret_procID = ret_procID_range.copy() self.logger('') @@ -78,6 +86,8 @@ def __call__(self, *args, **kwargs): for ret_procID, result in zip(ret_procID_range, ret_result): self.Q_procID.append(ret_procID) self.outputs.append(result) + if self.n_processes: + self.pBar() elif(self.numBusyCores): self.logger(f'Number of busy cores: {self.numBusyCores}') diff --git a/tests/test_multiprocessor.py b/tests/test_multiprocessor.py index 71fa687..da19ae6 100644 --- a/tests/test_multiprocessor.py +++ b/tests/test_multiprocessor.py @@ -168,7 +168,7 @@ def test_loopprocessor(): print('-'*80, '\n', inspect.stack()[0][3], '\n', '-'*80) N = 16 - D = 100000 + D = 1000000 data = (100+10*np.random.randn(N,D)).astype('int') mask = (2*np.random.rand(N,D)).astype('int')