Skip to content

Commit

Permalink
default nccl use compute stream in grad acc
Browse files Browse the repository at this point in the history
  • Loading branch information
chengtbf committed Aug 11, 2022
1 parent 6b20fce commit 972dfee
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions python/oneflow/nn/graph/graph_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def build(self, x):
Args:
mode (bool): if set to true, optimizer states of Data Parallel will be sharded across devices.
stage (int): optimization stage, range from 1 to 3.
stage (int): optimization stage, range from 1 to 3.
shard_min_size (int): min size (element count) of a shard of an optimizer state.
shard_restore_level (int): level to restore sharded parameter to whole parameter for consumer operators, level 0 is no restore, level 1 is soft restore, level 2 is hard restore. Note that this parameter is at pre-alpha stage.
"""
Expand Down Expand Up @@ -178,7 +178,7 @@ def __init__(self):
self.bn1 = flow.nn.BatchNorm1d(100)
self.config.allow_fuse_add_to_output(True)
def build(self, x):
bn = self.bn1(x)
bn = self.bn1(x)
out = bn + x
return out
Expand All @@ -191,7 +191,7 @@ def build(self, x):

def allow_fuse_cast_scale(self, mode: bool = True):
r"""If set to true, try to fuse cast and scalar_mul_by_tensor to improve performance.
For example:
.. code-block:: python
Expand Down Expand Up @@ -240,6 +240,11 @@ def build(self, x):
value (int): num of steps.
"""
self.proto.num_gradient_accumulation_steps = value
if value > 1:

This comment has been minimized.

Copy link
@yuanms2

yuanms2 Sep 16, 2022

Contributor

Suggestion: on the line above, use the variable name `self.proto.num_gradient_accumulation_steps` directly instead of `value`, so the intent is clearer.

# NOTE(chengcheng): when use gradient accumulation, optimizer nccl allreduce can NOT
# overlap with backward, so nccl use compute stream is optimization without negative
# effects.
nccl_config.enable_use_compute_stream(True)

def set_outputs_buffer_size(self, value: int = 2):
r"""Set the outputs buffer size of ``nn.Graph``.
Expand Down Expand Up @@ -278,7 +283,7 @@ def build(self, x):
return self.m(x)
graph = Graph()
Args:
mode (bool, optional): The default value is True.
"""
Expand All @@ -289,7 +294,7 @@ def enable_straighten_algorithm(self, mode: bool = True):
If using nccl compute stream, turning it on might not speed up the training.
If not using nccl compute stream, turning it on might slow down data parallelism by 0.6% and slow down model parallelism by 6%.
Considering memory, enabling the straighten algorithm is forbidden with one machine/device only, and not recommended under pipeline parallelism.
Considering memory, enabling the straighten algorithm is forbidden with one machine/device only, and not recommended under pipeline parallelism.
"""
self.proto.enable_straighten_algorithm_in_task_graph = mode

Expand Down

0 comments on commit 972dfee

Please sign in to comment.