Skip to content

Commit

Permalink
For issue # 962: switch from threading to multiprocessing for log_pro…
Browse files Browse the repository at this point in the history
…bability and predict_proba, and added kwargs for parallelised function. (#978)

Co-authored-by: Yeh ML EC2 <[email protected]>
  • Loading branch information
joy13975 and c-yeh authored Jul 4, 2022
1 parent 0652e95 commit f115a24
Showing 1 changed file with 26 additions and 6 deletions.
32 changes: 26 additions & 6 deletions pomegranate/BayesianNetwork.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,19 @@ cdef class BayesianNetwork(GraphModel):
with open(fn, 'w') as outfile:
outfile.write(self.to_json())

with Parallel(n_jobs=n_jobs, backend='threading') as parallel:
with Parallel(n_jobs=n_jobs, backend='multiprocessing') as parallel:
f = delayed(parallelize_function)
logp_array = parallel(f(batch[0], BayesianNetwork, 'log_probability',
fn) for batch in data_generator.batches())
logp_array = parallel(
f(
batch[0],
BayesianNetwork,
'log_probability',
fn,
check_input=check_input,
n_jobs=1
)
for batch in data_generator.batches()
)

os.remove(fn)
return numpy.concatenate(logp_array)
Expand Down Expand Up @@ -610,10 +619,21 @@ cdef class BayesianNetwork(GraphModel):
with open(fn, 'w') as outfile:
outfile.write(self.to_json())

with Parallel(n_jobs=n_jobs, backend='threading') as parallel:

with Parallel(n_jobs=n_jobs, backend='multiprocessing') as parallel:
f = delayed(parallelize_function)
logp_array = parallel(f(batch[0], BayesianNetwork, 'predict_proba',
fn) for batch in data_generator.batches())
logp_array = parallel(
f(
batch[0],
self.__class__,
'predict_proba',
fn,
max_iterations=max_iterations,
check_input=check_input,
n_jobs=1
)
for batch in data_generator.batches()
)

os.remove(fn)
return numpy.concatenate(logp_array)
Expand Down

0 comments on commit f115a24

Please sign in to comment.