feat(predict_deepspeed): --cpu_offload + help docs for args #89
jyaacoub committed May 6, 2024
1 parent 9a0c2e6 commit b93e289
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions predict_deepspeed.py
@@ -17,15 +17,26 @@
 parser.add_argument('--runtime_json', type=str, default=None)
 parser.add_argument('--no_overwrite', action='store_true', default=False)
 
-attn_method = parser.add_mutually_exclusive_group()
-attn_method.add_argument('--lma', action='store_true', default=False, help="Uses bfloat16 and LMA")
-attn_method.add_argument('--flash', action='store_true', default=False, help="Uses bfloat16 and flash attention (requires CUDA >= 11.6 and torch >= 1.12)")
-parser.add_argument('--chunk_size', type=int, default=None,
+inf_opt = parser.add_argument_group('Inference args',
+            description="When suffering from OOM, setting --chunk_size to 4 is usually a good start, "+\
+                "next would be to set --low_pres, and finally --cpu_offload. "+\
+                "This is the order of least to most performance impact, with minimal impact on inference times. "+\
+                "In rare cases, we might need to use --lma but this is not recommended unless absolutely necessary "+\
+                "due to the immense increase in time complexity.")
+inf_opt.add_argument('--chunk_size', type=int, default=None,
                      help="chunk size for reducing memory overhead (lower=less mem; 4 is usually good)")
-parser.add_argument('--low_pres', action='store_true', default=False, help="Use low precision")
+inf_opt.add_argument('--low_pres', action='store_true', default=False,
+                     help="Use low precision")
+inf_opt.add_argument('--cpu_offload', action='store_true', default=False,
+                     help="Offload params to cpu when not in use (decreases memory usage but increases inference time)")
+
+attn_method = inf_opt.add_mutually_exclusive_group()
+attn_method.add_argument('--flash', action='store_true', default=False, help="Uses bfloat16 and flash attention (requires CUDA >= 11.6 and torch >= 1.12)")
+attn_method.add_argument('--lma', action='store_true', default=False, help="Uses bfloat16 and LMA")
 
-parser.add_argument('--local_rank', type=int, default=0)
-parser.add_argument('--world_size', type=int, default=2)
+ds_opt = parser.add_argument_group('Deepspeed init args')
+ds_opt.add_argument('--local_rank', type=int, default=0)
+ds_opt.add_argument('--world_size', type=int, default=2)
 
 args = parser.parse_args()
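For reference, a minimal standalone sketch of the grouping pattern introduced in this hunk (not the repository's script itself; defaults and help strings are abbreviated, and the parse_args call below is illustrative only): an "Inference args" group owning the OOM-mitigation flags, --flash/--lma kept mutually exclusive inside it, plus a separate "Deepspeed init args" group.

# Standalone sketch mirroring the argparse layout above (hypothetical, abbreviated).
import argparse

parser = argparse.ArgumentParser()

inf_opt = parser.add_argument_group(
    'Inference args',
    description="OOM mitigations, from least to most impact on inference time: "
                "--chunk_size 4, then --low_pres, then --cpu_offload (--lma only as a last resort).")
inf_opt.add_argument('--chunk_size', type=int, default=None)
inf_opt.add_argument('--low_pres', action='store_true')
inf_opt.add_argument('--cpu_offload', action='store_true')

# --flash and --lma cannot be passed together; argparse errors out if both are given.
attn_method = inf_opt.add_mutually_exclusive_group()
attn_method.add_argument('--flash', action='store_true')
attn_method.add_argument('--lma', action='store_true')

ds_opt = parser.add_argument_group('Deepspeed init args')
ds_opt.add_argument('--local_rank', type=int, default=0)
ds_opt.add_argument('--world_size', type=int, default=2)

# Escalating mitigations for an OOM run:
args = parser.parse_args(['--chunk_size', '4', '--low_pres', '--cpu_offload'])
print(args.chunk_size, args.low_pres, args.cpu_offload)  # 4 True True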

@@ -112,7 +123,7 @@ def main():
 
     c = ckpt['hyper_parameters']['config']
     c.globals.chunk_size = args.chunk_size # setting to None means no chunking
-    if args.chunk_size is not None:
+    if args.cpu_offload:
        c.globals.offload_inference = True
        c.model.template.offload_inference = True
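The second hunk fixes the gating of CPU offload: previously offload_inference was switched on whenever --chunk_size was set, whereas it is now tied to the new --cpu_offload flag. A minimal sketch of that flag-to-config mapping, using SimpleNamespace stand-ins for the checkpoint's config object and for the parsed args, and a hypothetical helper name (apply_inference_args) that is not part of the repository:

# Sketch only: SimpleNamespace stands in for ckpt['hyper_parameters']['config'];
# apply_inference_args is a hypothetical helper, not a function from the repo.
from types import SimpleNamespace

def apply_inference_args(c, args):
    c.globals.chunk_size = args.chunk_size   # None means no chunking
    if args.cpu_offload:                     # was keyed off chunk_size before this commit
        c.globals.offload_inference = True
        c.model.template.offload_inference = True
    return c

c = SimpleNamespace(
    globals=SimpleNamespace(chunk_size=None, offload_inference=False),
    model=SimpleNamespace(template=SimpleNamespace(offload_inference=False)),
)
args = SimpleNamespace(chunk_size=4, cpu_offload=False)

apply_inference_args(c, args)
# Chunking alone no longer turns offloading on:
assert c.globals.chunk_size == 4 and c.globals.offload_inference is False

With --cpu_offload passed, both offload_inference fields flip to True, trading some inference speed for lower memory use, as described in the flag's help text.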

1 comment on commit b93e289

@jyaacoub (Owner, Author)

related to jyaacoub/MutDTA#89
