Skip to content
This repository has been archived by the owner on Jan 21, 2025. It is now read-only.

Commit

Permalink
Option to use mtf.Print to log which tokens are sent to which experts…
Browse files Browse the repository at this point in the history
… when run on CPU.

PiperOrigin-RevId: 368137313
  • Loading branch information
William Fedus authored and Mesh TensorFlow Team committed Nov 9, 2021
1 parent 57ed401 commit da43c27
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 11 deletions.
72 changes: 66 additions & 6 deletions mesh_tensorflow/transformer/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def __init__(self,
word_embed_mode=None,
use_second_place_expert_prob=None,
use_second_place_expert_prob_temp=None,
top_n_num_experts_per_token=3):
top_n_num_experts_per_token=3,
token_logging=False):
self._hparams = HParams(
moe_gating=moe_gating,
moe_num_experts=num_experts,
Expand Down Expand Up @@ -97,6 +98,7 @@ def __init__(self,
use_second_place_expert_prob_temp),
moe_top_n_num_experts_per_token=top_n_num_experts_per_token)
self._activation = activation
self.token_logging = token_logging

def call(self, context, x, losses=None):
"""Call the layer."""
Expand All @@ -116,7 +118,13 @@ def call(self, context, x, losses=None):
output_dim = self._hparams.moe_output_dim
else:
output_dim = context.model.model_dim
y, loss = transformer_moe_layer_v1(
if self.token_logging:
tokens = _detokenize(context.inputs, context.model.vocabulary)
x = mtf.Print(x, [tokens], "tokens:", summarize=1000)
extras = _windows(context.inputs, context.length_dim)
else:
extras = None
y, loss, extras = transformer_moe_layer_v1(
x,
output_dim,
self._hparams,
Expand All @@ -127,7 +135,16 @@ def call(self, context, x, losses=None):
nonpadding=context.nonpadding,
activation=self._activation,
num_microbatches=context.num_microbatches,
token_embeddings=context.input_embeddings)
token_embeddings=context.input_embeddings,
extras=extras)

if extras:
extras = _detokenize(extras, context.model.vocabulary)
experts_dim = mtf.Dimension("experts", self._hparams.moe_num_experts)
extras = mtf.unstack(extras, experts_dim)
for i, t in enumerate(extras):
y = mtf.Print(y, [t], "EXPERT %s:" % i, summarize=1000)

if context.losses is not None:
context.losses.append(loss)
if not has_length_dim:
Expand All @@ -139,6 +156,23 @@ def call(self, context, x, losses=None):
return y


@gin.configurable
def _windows(ids, length_dim, window_start=0, window_end=0):
  """Stacks shifted copies of `ids` along a new trailing "window" dimension.

  Args:
    ids: a mtf.Tensor; presumably integer token ids (confirm against caller).
    length_dim: the mtf.Dimension of `ids` along which to shift.
    window_start: an integer, first offset of the window (inclusive).
    window_end: an integer, last offset of the window (inclusive).

  Returns:
    a mtf.Tensor with the dimensions of `ids` followed by a "window"
    dimension holding one entry per offset in [window_start, window_end].
  """
  # Each offset is negated before being handed to mtf.shift, and wrap=False
  # is passed so that the shift does not wrap around length_dim.
  shifted = [
      mtf.shift(ids, -delta, length_dim, wrap=False)
      for delta in range(window_start, window_end + 1)
  ]
  # axis == ndims appends the new dimension after all existing ones.
  return mtf.stack(shifted, "window", axis=ids.shape.ndims)


def _detokenize(ids, vocabulary):
  """Decodes `ids` to strings by applying `vocabulary.decode_tf` slicewise.

  Args:
    ids: a mtf.Tensor; its last dimension is absent from the output
      (presumably that dimension indexes the ids of one string - confirm).
    vocabulary: an object providing a `decode_tf` method.

  Returns:
    a mtf.Tensor of dtype tf.string whose shape is that of `ids` without its
    last dimension.
  """
  # All dimensions but the last are kept in the output and are declared
  # splittable for the slicewise op.
  leading_dims = ids.shape.dims[:-1]
  return mtf.slicewise(
      vocabulary.decode_tf,
      [ids],
      output_shape=mtf.Shape(leading_dims),
      output_dtype=tf.string,
      splittable_dims=leading_dims)


class MoE2D(transformer.TransformerLayer):
"""Mixture of Experts Layer."""

Expand Down Expand Up @@ -202,7 +236,7 @@ def call(self, context, x, losses=None):
def transformer_moe_layer_v1(
inputs, output_dim, hparams, train, variable_dtype,
layout=None, mesh_shape=None, nonpadding=None, activation=mtf.relu,
num_microbatches=None, token_embeddings=None):
num_microbatches=None, token_embeddings=None, extras=None):
"""Local mixture of experts that works well on TPU.
Adapted from the paper https://arxiv.org/abs/1701.06538
Expand Down Expand Up @@ -281,6 +315,7 @@ def transformer_moe_layer_v1(
[batch_dim(s), length_dim, input_dim]. These are the word embeddings
that correspond to the inputs. These can optionally be used to make
routing decisions.
extras: a tensor to dispatch (for debugging purposes)
Returns:
outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
Expand Down Expand Up @@ -344,6 +379,10 @@ def transformer_moe_layer_v1(
# over which those groups are split.
batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
orig_inputs.shape.dims[-1])

if extras:
extras_dims = extras.shape.dims[len(batch_and_length_dims):]

# Hack: we assume that
# "outer_batch" == replication of experts
# mesh_dim_size can be derived from mesh_shape and orig_batch_dim
Expand Down Expand Up @@ -381,6 +420,11 @@ def transformer_moe_layer_v1(
token_embeddings = mtf.cast(
mtf.reshape(token_embeddings, moe_input_dims), inputs.dtype)

if extras:
extras = mtf.reshape(
extras,
[outer_batch_dim, num_groups_dim, group_size_dim] + extras_dims)

# Each sequence sends expert_capacity positions to each expert.
if train:
capacity_factor = hparams.moe_capacity_factor_train
Expand Down Expand Up @@ -503,6 +547,17 @@ def transformer_moe_layer_v1(
input_dim
]))

if extras:
extras = mtf.einsum([extras, mtf.cast(dispatch_tensor, extras.dtype)],
mtf.Shape([
outer_batch_dim, experts_dim_unsplit,
num_groups_dim, expert_capacity_dim] + extras_dims))
extras = mtf.reshape(
extras,
mtf.Shape([
outer_batch_dim, experts_dim, batch_dim_unsplit,
expert_capacity_dim] + extras_dims))

# Now feed the expert inputs through the experts.
h = mtf.layers.dense_product(
expert_inputs,
Expand Down Expand Up @@ -559,10 +614,15 @@ def _compute_output(hidden, layer_name):
k = _compute_output(k_h, layer_name="k_wo")
outputs.append(q)
outputs.append(k)
return outputs, loss * hparams.moe_loss_coef
return outputs, loss * hparams.moe_loss_coef, None
else:
output = _compute_output(h, layer_name="wo")
return output, loss * hparams.moe_loss_coef
loss *= hparams.moe_loss_coef

if extras:
return output, loss, extras
else:
return output, loss, None


def transformer_moe_layer_v2(
Expand Down
5 changes: 4 additions & 1 deletion mesh_tensorflow/transformer/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,8 @@ def __init__(self,
input_full_attention=False,
loss_on_targets_only=False,
loss_denominator=None,
token_dropout_rate=0.0):
token_dropout_rate=0.0,
vocabulary=None):
"""Create a Unitransformer.
Args:
Expand Down Expand Up @@ -767,6 +768,7 @@ def __init__(self,
same denominator as was used for the pretraining. This complication
might be avoided by always using loss_denominator = 1.0.
token_dropout_rate: an optional floating point value
vocabulary: an optional vocabularies.Vocabulary
"""
self.layer_stack = layer_stack
self.model_dim = mtf.Dimension("d_model", d_model)
Expand Down Expand Up @@ -807,6 +809,7 @@ def __init__(self,
raise ValueError(
"input_full_attention only makes sense with autoregressive")
self.token_dropout_rate = token_dropout_rate
self.vocabulary = vocabulary

@property
def fully_autoregressive(self):
Expand Down
19 changes: 15 additions & 4 deletions mesh_tensorflow/transformer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ def build_model(model_type="bitransformer",
input_vocab_size=gin.REQUIRED,
output_vocab_size=gin.REQUIRED,
layout_rules=None,
mesh_shape=None):
mesh_shape=None,
input_vocabulary=None,
target_vocabulary=None):
"""Build a transformer model.
Currently, four types of models are supported:
Expand Down Expand Up @@ -214,15 +216,21 @@ def build_model(model_type="bitransformer",
output_vocab_size: an integer
layout_rules: optional, input to mtf.convert_to_layout_rules
mesh_shape: optional, an input to mtf.convert_to_shape()
input_vocabulary: optional, a vocabularies.Vocabulary
target_vocabulary: optional, a vocabularies.Vocabulary
Returns:
a Unitransformer or Bitransformer
"""
if model_type == "bitransformer":
return transformer.make_bitransformer(
ret = transformer.make_bitransformer(
input_vocab_size=input_vocab_size,
output_vocab_size=output_vocab_size,
mesh_shape=mesh_shape,
layout=layout_rules)
ret.encoder.vocabulary = input_vocabulary
ret.decoder.vocabulary = target_vocabulary
return ret
elif model_type == "bi_student_teacher":
return transformer.make_bi_student_teacher(
input_vocab_size=input_vocab_size,
Expand All @@ -236,7 +244,8 @@ def build_model(model_type="bitransformer",
input_vocab_size=input_vocab_size,
output_vocab_size=output_vocab_size,
mesh_shape=mesh_shape,
layout=layout_rules)
layout=layout_rules,
vocabulary=input_vocabulary)
else:
raise ValueError("unknown model_type")

Expand Down Expand Up @@ -2067,7 +2076,9 @@ def get_estimator(model_type, vocabulary, mesh_shape,
input_vocab_size=inputs_vocabulary(vocabulary).vocab_size,
output_vocab_size=targets_vocabulary(vocabulary).vocab_size,
layout_rules=layout_rules,
mesh_shape=mesh_shape)
mesh_shape=mesh_shape,
input_vocabulary=inputs_vocabulary(vocabulary),
target_vocabulary=targets_vocabulary(vocabulary))

model_fn = tpu_estimator_model_fn(
model_type=model_type,
Expand Down

0 comments on commit da43c27

Please sign in to comment.