diff --git a/python/oneflow/nn/graph/graph_config.py b/python/oneflow/nn/graph/graph_config.py index ff5b044b779..5e291f12b1f 100644 --- a/python/oneflow/nn/graph/graph_config.py +++ b/python/oneflow/nn/graph/graph_config.py @@ -116,7 +116,7 @@ def build(self, x): Args: mode (bool): if set to true, optimizer states of Data Parallel will be sharded across devices. - stage (int): optimization stage, range from 1 to 3. + stage (int): optimization stage, range from 1 to 3. shard_min_size (int): min size (element count) of a shard of an optimizer state. shard_restore_level (int): level to restore sharded parameter to whole parameter for consumer operators, level 0 is no restore, level 1 is soft restore, level 2 is hard restore. Note that this paremeter is at pre-alpha stage. """ @@ -178,7 +178,7 @@ def __init__(self): self.bn1 = flow.nn.BatchNorm1d(100) self.config.allow_fuse_add_to_output(True) def build(self, x): - bn = self.bn1(x) + bn = self.bn1(x) out = bn + x return out @@ -191,7 +191,7 @@ def build(self, x): def allow_fuse_cast_scale(self, mode: bool = True): r"""If set to true, try to fuse cast and scalar_mul_by_tensor to improve performance. - + For example: .. code-block:: python @@ -240,6 +240,11 @@ def build(self, x): value (int): num of steps. """ self.proto.num_gradient_accumulation_steps = value + if value > 1: + # NOTE(chengcheng): when using gradient accumulation, the optimizer's nccl allreduce can NOT + # overlap with backward, so letting nccl use the compute stream is an optimization without + # negative effects. + nccl_config.enable_use_compute_stream(True) def set_outputs_buffer_size(self, value: int = 2): r"""Set the outputs buffer size of ``nn.Graph``. @@ -278,7 +283,7 @@ def build(self, x): return self.m(x) graph = Graph() - + Args: mode (bool, optional): The default vaule is True. """ @@ -289,7 +294,7 @@ def enable_straighten_algorithm(self, mode: bool = True): If using nccl compute stream, turning it on might not speed up the training. 
If not using nccl compute stream, turning it on might slow down data parallelism by 0.6% and slow down model parallelism by 6%. - Considering memory, enabling the straighten algorithm is forbidden with one machine/device only, and not recommended under pipeline parallelism. + Considering memory, enabling the straighten algorithm is forbidden with one machine/device only, and not recommended under pipeline parallelism. """ self.proto.enable_straighten_algorithm_in_task_graph = mode