This release mainly includes the following improvements:
- Fix several fidelity issues.
- Refactor schedule primitives and add the `.fork_rng()`, `.annotate()`, and `.replace_all()` primitives.
- Other bug fixes.
If any of the following cases apply to your existing schedule based on v0.0.2, you need to update it to work with v0.0.3.
- Tagging parameters for the DeepSpeed pipeline runtime to perform an additional all-reduce over the TP (tensor parallel) group. For example, you may have the following code snippet that tags LayerNorm parameters:
```python
def tag_layernorm(sch):
    for m in sch.mod.modules():
        if isinstance(m, nn.LayerNorm):
            for p in m.parameters(recurse=False):
                p.replicated_param = True
```
This can be changed to the following in v0.0.3:
```python
def annotate_layernorm_and_bias(sch):
    for sub_sch in sch.child.values():
        if isinstance(sub_sch.mod, nn.LayerNorm):
            for name, _ in sub_sch.mod.named_parameters(recurse=False):
                sub_sch.annotate(name, "replicated_param", True)
        if issubclass(sub_sch.mod.__class__, LinearWithSyncFunc):
            sub_sch.annotate("bias", "replicated_param", True)
        annotate_layernorm_and_bias(sub_sch)
```
Reference: https://github.com/awslabs/slapo/blob/main/slapo/model_schedule/gpt2.py#L529
- RNG control can now be done easily with the newly introduced schedule primitive `.fork_rng()`. Accordingly, the old `slapo.op.AttentionOpWithRNG` is removed. If you have the following code snippet:
```python
new_op = AttentionOpWithRNG(
    sub_sch["module"]["attn_op"].mod.attn_op_name,
    sub_sch["module"]["attn_op"].mod.apply_causal_mask,
    sub_sch["module"]["attn_op"].mod.scale,
)
sub_sch["module"]["attn_op"].replace(new_op)
```
It has to be changed to:
```python
sub_sch["module"]["attn_op"].fork_rng()
```
- The primitive `.trace_for_pipeline()` has been renamed to `.trace_until()`. Since the arguments remain the same, you can simply replace all occurrences; see the sketch below.
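A minimal before/after sketch (the call site below is illustrative, not taken from the codebase; only the primitive name changes):
```python
# v0.0.2
sch.trace_for_pipeline(*args, **kwargs)

# v0.0.3: same arguments, new name
sch.trace_until(*args, **kwargs)
```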
- If you use `slapo.op.FusedMLP` with sharding, you need to change your schedule to reflect the new FusedMLP implementation. For example:
fc_names = ["fc_in", "act", "fc_out"]
sub_sch[fc_names[0]].shard("weight", axis=0)
sub_sch[fc_names[1]].shard("bias", axis=0)
sub_sch[fc_names[2]].shard("weight", axis=1)
sub_sch[fc_names[0]].sync(mode="bwd_post", sync_op_or_fn="all_reduce")
sub_sch[fc_names[2]].sync(mode="fwd_post", sync_op_or_fn="all_reduce")
changes to
fc_names = ["fc_in", "fc_out"]
sub_sch[fc_names[0]].shard("weight", axis=0)
sub_sch[fc_names[0]].shard("bias", axis=0)
sub_sch[fc_names[1]].shard("weight", axis=1)
sub_sch[fc_names[0]].sync(mode="bwd_post", sync_op_or_fn="all_reduce")
sub_sch[fc_names[1]].sync(mode="fwd_post", sync_op_or_fn="all_reduce")
What's Changed
- [Action] Fix release flow by @comaniac in #69
- [Refactor] Schedule primitives by @comaniac in #68
- [Primitive] .fork_rng() by @comaniac in #70
- [Primitive] .annotate() and .trace_until() by @comaniac in #71
- [CI] Update CI rules for docs by @chhzh123 in #72
- [Op] Fuse bias+dropout in FusedMLP by @comaniac in #73
- [Refactor] Modulize sharding methods by @comaniac in #74
- [CI] Quick fix by @chhzh123 in #75
- [Primitive][fork_rng] Do not replace module by @comaniac in #76
- [Bugfix] Include other custom LinearWithXX by @comaniac in #77
- [Primitive] Add fallback fusion by @chhzh123 in #78
- [examples] Refactor dataloader to support BERT by @chhzh123 in #79
- [Bugfix] Shard embedding hooks by @comaniac in #80
- [Version] Refactor version updating logic by @comaniac in #82
- [Op] Print by @comaniac in #81
- [Primitive] Add .replace_all() by @chhzh123 in #85
- [Version] Update version to v0.0.3 by @chhzh123 in #84
Full Changelog: v0.0.2...v0.0.3