Skip to content

Commit

Permalink
Fixed docstrings for GAT internal functions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 488212547
  • Loading branch information
samihaija authored and tensorflower-gardener committed Nov 13, 2022
1 parent ff5e05e commit 1a6f7ad
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions tensorflow_gnn/models/gat_v2/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,10 @@ def _split_heads(self, tensor):
where `z` is output of this `_split_heads`.
Args:
tensor: with shape [..., num_heads, channels_per_head].
tensor: with shape `[..., num_heads * channels_per_head]`.
Returns:
Tensor with shape [..., num_heads, channels_per_head] that reconstructs
Tensor with shape `[..., num_heads, channels_per_head]` that reconstructs
`z` from `y = _merge_heads(z, "concat")`.
"""
extra_dims = tensor.shape[1:-1] # Possibly empty.
Expand All @@ -328,19 +328,19 @@ def _merge_heads( # pylint: disable=invalid-name.
"""Combines output of attention heads by concatenation or mean.
If merge_type is "concat", then:
it converts tensor from shape [..., num_heads, channels_per_head], to
tensor of shape [..., num_heads * channels_per_head], by concatenation
it converts tensor from shape `[..., num_heads, channels_per_head]`, to
tensor of shape `[..., num_heads * channels_per_head]`, by concatenation
along the last axis.
Otherwise, if merge_type "mean", then:
it converts tensor from shape [..., num_heads, channels_per_head], to
tensor of shape [..., channels_per_head], by reduce_mean(axis=-2).
Args:
tensor: of shape [..., num_heads, channels_per_head].
merge_type: str. Must be one of {"mean", "concat"}.
tensor: of shape `[..., num_heads, channels_per_head]`.
merge_type: str. Must be one of `{"mean", "concat"}`.
Returns:
Tensor, with num_heads dimension removed (either averaged over, or
Tensor, with `num_heads` dimension removed (either averaged over, or
concatenated).
"""
if merge_type == "concat":
Expand Down

0 comments on commit 1a6f7ad

Please sign in to comment.