Merge pull request #78 from AllenInstitute/fix/gpu_memory
move to predict_on_batch
jeromelecoq authored Nov 16, 2022
2 parents 2f583c4 + 57ca75a commit 2533346
Showing 5 changed files with 44 additions and 5 deletions.
2 changes: 1 addition & 1 deletion deepinterpolation/__init__.py
@@ -1,2 +1,2 @@
# Version of deepinterpolation package. Should match the version in setup.py
-__version__ = "0.1.4"
+__version__ = "0.1.5"
40 changes: 40 additions & 0 deletions deepinterpolation/cli/schemas.py
@@ -468,6 +468,26 @@ class TrainingSchema(argschema.schemas.DefaultSchema):
        checkpoints.",
    )

+    use_multiprocessing = argschema.fields.Bool(
+        required=False,
+        default=True,
+        description="Whether to use a multiprocessing pool to fetch batch \
+            samples. Setting this to True will increase data generation \
+            speed if the generator is limited by read speed, but it will \
+            also increase RAM usage. Set to False if your hardware runs \
+            into RAM errors during training.",
+    )

+    nb_workers = argschema.fields.Int(
+        required=False,
+        default=16,
+        description="Number of CPU workers used to fetch batch samples. \
+            This parameter is only relevant if use_multiprocessing is set \
+            to True. A larger nb_workers will increase memory usage; \
+            increase this number until your training becomes limited by \
+            RAM or CPU usage.",
+    )


class FineTuningSchema(argschema.schemas.DefaultSchema):
    name = argschema.fields.String(
@@ -574,6 +594,26 @@ class FineTuningSchema(argschema.schemas.DefaultSchema):
        checkpoints.",
    )

+    use_multiprocessing = argschema.fields.Bool(
+        required=False,
+        default=True,
+        description="Whether to use a multiprocessing pool to fetch batch \
+            samples. Setting this to True will increase data generation \
+            speed if the generator is limited by read speed, but it will \
+            also increase RAM usage. Set to False if your hardware runs \
+            into RAM errors during training.",
+    )

+    nb_workers = argschema.fields.Int(
+        required=False,
+        default=16,
+        description="Number of CPU workers used to fetch batch samples. \
+            This parameter is only relevant if use_multiprocessing is set \
+            to True. A larger nb_workers will increase memory usage; \
+            increase this number until your training becomes limited by \
+            RAM or CPU usage.",
+    )


class NetworkSchema(argschema.schemas.DefaultSchema):
    name = argschema.fields.String(
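For reference, these two new fields correspond to the use_multiprocessing and workers arguments that tf.keras.Model.fit accepts for Sequence-based generators in TensorFlow 2.x. A minimal sketch of where the values end up, with a toy model and generator standing in for the package's own objects (the wiring inside deepinterpolation's trainer is an assumption, not shown in this diff):

import numpy as np
import tensorflow as tf

# Minimal sketch, assuming the two schema values are forwarded to Model.fit.
# ToyBatches stands in for deepinterpolation's data generators.
class ToyBatches(tf.keras.utils.Sequence):
    def __len__(self):
        return 8  # batches per epoch

    def __getitem__(self, idx):
        x = np.random.rand(4, 16).astype("float32")
        return x, x  # denoising-style target: reconstruct the input

model = tf.keras.Sequential([tf.keras.layers.Dense(16)])
model.compile(optimizer="adam", loss="mse")

model.fit(
    ToyBatches(),
    epochs=1,
    use_multiprocessing=True,  # schema field: use_multiprocessing
    workers=16,                # schema field: nb_workers
)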
2 changes: 1 addition & 1 deletion deepinterpolation/inferrence_collection.py
@@ -243,7 +243,7 @@ def run(self):
        for index_dataset in np.arange(0, self.nb_datasets, 1):
            local_data = self.generator_obj.__getitem__(index_dataset)

-            predictions_data = self.model.predict(local_data[0])
+            predictions_data = self.model.predict_on_batch(local_data[0])

            local_mean, local_std = \
                self.generator_obj.__get_norm_parameters__(index_dataset)
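The one-line swap above is the GPU-memory fix named in the commit title. tf.keras's Model.predict wraps each call in its own tf.data pipeline and cached prediction function, which can accumulate memory when it is invoked once per dataset inside a loop; predict_on_batch performs a single forward pass on a batch that is already in memory. A minimal sketch of the resulting loop pattern, with a toy model and random batches in place of the package's generator:

import numpy as np
import tensorflow as tf

# Minimal sketch of the post-commit inference pattern: push each
# pre-assembled batch straight through the network in one forward pass.
model = tf.keras.Sequential([tf.keras.layers.Dense(16)])

for index_dataset in range(4):
    batch = np.random.rand(8, 16).astype("float32")
    # predict_on_batch avoids building a per-call tf.data pipeline,
    # helping keep memory flat across iterations of the loop.
    predictions = model.predict_on_batch(batch)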
3 changes: 1 addition & 2 deletions requirements.txt
@@ -1,6 +1,5 @@
-tensorflow==2.4.4
+tensorflow==2.7
nibabel
h5py==2.10.0
matplotlib
numpy
python-dateutil
2 changes: 1 addition & 1 deletion setup.py
@@ -14,7 +14,7 @@

setup(
    name="deepinterpolation",
-    version="0.1.4",
+    version="0.1.5",
description="Implemenent DeepInterpolation to denoise data by removing \
independent noise",
long_description_content_type='text/x-rst',
