Days' work: need to re-enable derived architectures.
kylebgorman committed Oct 8, 2024
1 parent a914362 commit 71d5652
Showing 23 changed files with 519 additions and 269 deletions.
36 changes: 20 additions & 16 deletions README.md
@@ -26,11 +26,10 @@ Yoyodyne is inspired by [FairSeq](https://github.com/facebookresearch/fairseq)
- It is for small-vocabulary sequence-to-sequence generation, and therefore
includes no affordances for machine translation or language modeling.
Because of this:
- It has no plugin interface and the architectures provided are intended
to be reasonably exhaustive.
- The architectures provided are intended to be reasonably exhaustive.
- There is little need for data preprocessing; it works with TSV files.
- It has support for using features to condition decoding, with
architecture-specific code handling feature information.
architecture-specific code for handling feature information.
- It supports the use of validation accuracy (not loss) for model selection
and early stopping.
- Releases are made regularly.
@@ -185,9 +184,15 @@ not an attention mechanism is present. This flag also specifies a default
architecture for the encoder(s), but it is possible to override this with
additional flags. Supported values for `--arch` are:

- `attentive_lstm`: This is an LSTM decoder with LSTM encoders (by default)
- `attentive_gru`: This is a GRU decoder with GRU encoders (by default)
and an attention mechanism. The initial hidden state is treated as a learned
parameter.
- `attentive_lstm`: This is an LSTM decoder with LSTM encoders (by default)
and an attention mechanism. The initial hidden and cell states are treated
as learned parameters.
- `gru`: This is a GRU decoder with GRU encoders (by default); in lieu of
an attention mechanism, the last non-padding hidden state of the encoder is
concatenated with the decoder hidden state (see the sketch below).
- `hard_attention_lstm`: This is an LSTM encoder/decoder modeling generation
as a Markov process. By default, it assumes a non-monotonic progression over
the source string, but with `--enforce_monotonic` the model must progress
@@ -221,18 +226,17 @@ additional flags. Supported values for `--arch` are:
The user may wish to specify the number of attention heads (with
`--source_attention_heads`; default: `4`).
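
For the non-attentive `gru` and `lstm` architectures, the "last non-padding hidden state of the encoder" summary works roughly as sketched below. This is a minimal illustration under assumed tensor shapes, not Yoyodyne's actual implementation; the helper name `last_unpadded_state` is hypothetical.

```python
import torch


def last_unpadded_state(
    encoded: torch.Tensor,  # (batch, src_len, hidden_size)
    lengths: torch.Tensor,  # (batch,): number of non-padding source symbols
) -> torch.Tensor:
    """Gathers the encoder state at the last non-padding position."""
    last = (lengths - 1).clamp(min=0)
    batch = torch.arange(encoded.size(0), device=encoded.device)
    return encoded[batch, last]  # (batch, hidden_size)


# At each decoding step this summary stands in for an attention-weighted
# context vector, e.g.:
#     context = last_unpadded_state(encoded, lengths)          # (batch, H)
#     step_in = torch.cat((decoder_hidden, context), dim=-1)   # (batch, 2 * H)
```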

The user can override the default encoder architectures. One can override the
source encoder using the `--source_encoder_arch` flag:
The `--arch` flag specifies the decoder type; the user can override default
encoder types using the `--source_encoder_arch` flag and, when features are
present, the `--features_encoder_arch` flag. Valid values are:

- `feature_invariant_transformer`: This is a variant of the transformer
encoder used with features; it concatenates source and features and uses a
- `feature_invariant_transformer`: a variant of the transformer encoder
used with features; it concatenates source and features and uses a
learned embedding to distinguish between source and features symbols (see the sketch after this list).
- `linear`: This is a linear encoder.
- `lstm`: This is a LSTM encoder.
- `transformer`: This is a transformer encoder.

When using features, the user can also specify a non-default features encoder
using the `--features_encoder_arch` flag (`linear`, `lstm`, `transformer`).
- `linear`: a linear encoder.
- `gru`: a GRU encoder.
- `lstm`: an LSTM encoder.
- `transformer`: a transformer encoder.
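
As a rough illustration of the `feature_invariant_transformer` idea (concatenate source and feature symbols, then tag each position with a learned source-vs-features embedding), here is a minimal sketch; the class name, shapes, and module attributes are assumptions rather than Yoyodyne's actual code.

```python
import torch
from torch import nn


class FeatureInvariantEmbedding(nn.Module):
    """Embeds source + feature symbols plus a learned position-type tag."""

    def __init__(self, vocab_size: int, embedding_size: int):
        super().__init__()
        self.symbol_embedding = nn.Embedding(vocab_size, embedding_size)
        # Type 0 marks source positions; type 1 marks feature positions.
        self.type_embedding = nn.Embedding(2, embedding_size)

    def forward(
        self, source: torch.Tensor, features: torch.Tensor
    ) -> torch.Tensor:
        # source: (batch, src_len); features: (batch, feat_len); both int64.
        symbols = torch.cat((source, features), dim=1)
        type_ids = torch.cat(
            (torch.zeros_like(source), torch.ones_like(features)), dim=1
        )
        # The transformer encoder then runs over this combined sequence.
        return self.symbol_embedding(symbols) + self.type_embedding(type_ids)
```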

For all models, the user may also wish to specify:

@@ -241,8 +245,8 @@ For all models, the user may also wish to specify:
- `--encoder_layers` (default: `1`): number of encoder layers
- `--hidden_size` (default: `512`): hidden layer size

By default, LSTM encoders are bidirectional. One can disable this with the
`--no_bidirectional` flag.
By default, RNN (i.e., GRU and LSTM) encoders are bidirectional. One can
disable this with the `--no_bidirectional` flag.

## Training options

4 changes: 3 additions & 1 deletion examples/wandb_sweeps/best_hyperparameters.py
@@ -54,7 +54,9 @@ def main(args: argparse.Namespace) -> None:
logging.basicConfig(format="%(levelname)s: %(message)s", level="INFO")
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--entity", required=True, help="The entity scope for the project."
"--entity",
required=True,
help="The entity scope for the project.",
)
parser.add_argument(
"--project", required=True, help="The project of the sweep."
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -11,7 +11,7 @@ exclude = ["examples*"]

[project]
name = "yoyodyne"
version = "0.2.12"
version = "0.2.13"
description = "Small-vocabulary neural sequence-to-sequence models"
readme = "README.md"
requires-python = ">= 3.9"
6 changes: 4 additions & 2 deletions yoyodyne/__init__.py
@@ -4,9 +4,11 @@

# Silences irrelevant warnings; these are more like "Did you know?"s.
warnings.filterwarnings(
"ignore", ".*does not have many workers which may be a bottleneck.*"
"ignore",
".*does not have many workers which may be a bottleneck.*",
)
warnings.filterwarnings(
"ignore", ".*option adds dropout after all but last recurrent layer.*"
"ignore",
".*option adds dropout after all but last recurrent layer.*",
)
warnings.filterwarnings("ignore", ".*is a wandb run already in progress.*")
7 changes: 6 additions & 1 deletion yoyodyne/data/indexes.py
@@ -184,4 +184,9 @@ def unk_idx(self) -> int:

@property
def special_idx(self) -> Set[int]:
return {self.unk_idx, self.pad_idx, self.start_idx, self.end_idx}
return {
self.unk_idx,
self.pad_idx,
self.start_idx,
self.end_idx,
}
1 change: 1 addition & 0 deletions yoyodyne/defaults.py
@@ -36,6 +36,7 @@
HIDDEN_SIZE = 512
MAX_SOURCE_LENGTH = 128
MAX_TARGET_LENGTH = 128
RNN_TYPE = "lstm"
SOURCE_ATTENTION_HEADS = 4
TIE_EMBEDDINGS = True

34 changes: 21 additions & 13 deletions yoyodyne/models/__init__.py
@@ -4,22 +4,30 @@

from .. import defaults
from .base import BaseEncoderDecoder
from .hard_attention import HardAttentionRNN
from .pointer_generator import (
PointerGeneratorRNNEncoderDecoder,
PointerGeneratorTransformerEncoderDecoder,
)
from .rnn import AttentiveRNNEncoderDecoder, RNNEncoderDecoder
from .transducer import TransducerEncoderDecoder

# from .hard_attention import HardAttentionGRU
# from .pointer_generator import (
# PointerGeneratorGRUEncoderDecoder,
# PointerGeneratorTransformerEncoderDecoder,
# )
from .rnn import AttentiveGRUEncoderDecoder # noqa: F401
from .rnn import AttentiveLSTMEncoderDecoder # noqa: F401
from .rnn import GRUEncoderDecoder # noqa: F401
from .rnn import LSTMEncoderDecoder # noqa: F401
from .rnn import RNNEncoderDecoder # noqa: F401

# from .transducer import TransducerEncoderDecoder
from .transformer import TransformerEncoderDecoder

_model_fac = {
"attentive_rnn": AttentiveRNNEncoderDecoder,
"hard_attention_rnn": HardAttentionRNN,
"pointer_generator_rnn": PointerGeneratorRNNEncoderDecoder,
"pointer_generator_transformer": PointerGeneratorTransformerEncoderDecoder, # noqa: 501
"rnn": RNNEncoderDecoder,
"transducer": TransducerEncoderDecoder,
"attentive_gru": AttentiveGRUEncoderDecoder,
"attentive_lstm": AttentiveLSTMEncoderDecoder,
# "hard_attention_gru": HardAttentionGRU,
# "pointer_generator_gru": PointerGeneratorGRUEncoderDecoder,
# "pointer_generator_transformer": PointerGeneratorTransformerEncoderDecoder, # noqa: E501
"gru": GRUEncoderDecoder,
"lstm": LSTMEncoderDecoder,
# "transducer": TransducerEncoderDecoder,
"transformer": TransformerEncoderDecoder,
}

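The helper that consumes `_model_fac` is below the fold in this hunk. A minimal sketch of how such a factory is typically used follows; the function name `get_model_cls` and the error message are assumptions. Note that the commented-out entries (the derived architectures named in the commit message) are temporarily unavailable via `--arch`.

```python
def get_model_cls(arch: str) -> type:
    """Maps an `--arch` value (e.g. "attentive_lstm") to its model class."""
    try:
        return _model_fac[arch]
    except KeyError:
        raise NotImplementedError(f"Architecture {arch} not found")
```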
74 changes: 41 additions & 33 deletions yoyodyne/models/base.py
@@ -63,32 +63,32 @@ class BaseEncoderDecoder(lightning.LightningModule):
def __init__(
self,
*,
pad_idx,
start_idx,
end_idx,
vocab_size,
features_vocab_size,
target_vocab_size,
pad_idx,
source_encoder_cls,
eval_metrics=defaults.EVAL_METRICS,
features_encoder_cls=None,
start_idx,
target_vocab_size,
vocab_size,
beam_width=defaults.BEAM_WIDTH,
beta1=defaults.BETA1,
beta2=defaults.BETA2,
decoder_layers=defaults.DECODER_LAYERS,
dropout=defaults.DROPOUT,
embedding_size=defaults.EMBEDDING_SIZE,
encoder_layers=defaults.ENCODER_LAYERS,
eval_metrics=defaults.EVAL_METRICS,
features_encoder_cls=None,
hidden_size=defaults.HIDDEN_SIZE,
label_smoothing=defaults.LABEL_SMOOTHING,
learning_rate=defaults.LEARNING_RATE,
max_source_length=defaults.MAX_SOURCE_LENGTH,
max_target_length=defaults.MAX_TARGET_LENGTH,
optimizer=defaults.OPTIMIZER,
scheduler=None,
scheduler_kwargs=None,
dropout=defaults.DROPOUT,
label_smoothing=defaults.LABEL_SMOOTHING,
teacher_forcing=defaults.TEACHER_FORCING,
beam_width=defaults.BEAM_WIDTH,
max_source_length=defaults.MAX_SOURCE_LENGTH,
max_target_length=defaults.MAX_TARGET_LENGTH,
encoder_layers=defaults.ENCODER_LAYERS,
decoder_layers=defaults.DECODER_LAYERS,
embedding_size=defaults.EMBEDDING_SIZE,
hidden_size=defaults.HIDDEN_SIZE,
**kwargs, # Ignored.
**kwargs,
):
super().__init__()
# Symbol processing.
@@ -133,31 +133,31 @@ def __init__(
)
# Instantiates encoders class.
self.source_encoder = source_encoder_cls(
pad_idx=self.pad_idx,
start_idx=self.start_idx,
end_idx=self.end_idx,
embeddings=self.embeddings,
embedding_size=self.embedding_size,
num_embeddings=self.vocab_size,
dropout=self.dropout,
layers=self.encoder_layers,
hidden_size=self.hidden_size,
embedding_size=self.embedding_size,
embeddings=self.embeddings,
end_idx=self.end_idx,
features_vocab_size=features_vocab_size,
hidden_size=self.hidden_size,
layers=self.encoder_layers,
max_source_length=max_source_length,
num_embeddings=self.vocab_size,
pad_idx=self.pad_idx,
start_idx=self.start_idx,
**kwargs,
)
self.features_encoder = (
features_encoder_cls(
pad_idx=self.pad_idx,
start_idx=self.start_idx,
end_idx=self.end_idx,
embeddings=self.embeddings,
embedding_size=self.embedding_size,
num_embeddings=self.vocab_size,
dropout=self.dropout,
layers=self.encoder_layers,
embedding_size=self.embedding_size,
embeddings=self.embeddings,
hidden_size=self.hidden_size,
layers=self.encoder_layers,
end_idx=self.end_idx,
max_source_length=max_source_length,
num_embeddings=self.vocab_size,
pad_idx=self.pad_idx,
start_idx=self.start_idx,
**kwargs,
)
if features_encoder_cls is not None
@@ -166,7 +166,12 @@ def __init__(
self.decoder = self.get_decoder()
# Saves hyperparameters for PL checkpointing.
self.save_hyperparameters(
ignore=["source_encoder", "decoder", "expert", "features_encoder"]
ignore=[
"source_encoder",
"decoder",
"expert",
"features_encoder",
]
)
# Logs the module names.
util.log_info(f"Model: {self.name}")
@@ -257,7 +262,10 @@ def validation_step(
# Gets a dict of all eval metrics for this batch.
val_eval_items_dict = {
evaluator.name: evaluator.evaluate(
greedy_predictions, target_padded, self.end_idx, self.pad_idx
greedy_predictions,
target_padded,
self.end_idx,
self.pad_idx,
)
for evaluator in self.evaluators
}
15 changes: 13 additions & 2 deletions yoyodyne/models/expert.py
@@ -10,7 +10,16 @@
import abc
import argparse
import dataclasses
from typing import Any, Dict, Iterable, Iterator, List, Sequence, Set, Tuple
from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
Sequence,
Set,
Tuple,
)

import numpy
from maxwell import actions, sed
@@ -118,7 +127,9 @@ def to_i2w(self) -> List[Any]:
return self.i2w[len(self.start_vocab_idx) :] # noqa: E203

@property
def substitutions(self) -> List[Tuple[int, actions.ConditionalEdit]]:
def substitutions(
self,
) -> List[Tuple[int, actions.ConditionalEdit]]:
return [
i
for i, a in enumerate(self.i2w)
12 changes: 9 additions & 3 deletions yoyodyne/models/hard_attention.py
@@ -274,7 +274,9 @@ def greedy_decode(
break
# Pads if finished decoding.
pred = torch.where(
~finished, pred, torch.tensor(self.end_idx, device=self.device)
~finished,
pred,
torch.tensor(self.end_idx, device=self.device),
)
predictions = torch.cat((predictions, pred), dim=-1)
# Updates likelihood emissions.
@@ -334,7 +336,9 @@ def _gather_at_idx(
return output * pad_mask

@staticmethod
def _apply_mono_mask(transition_prob: torch.Tensor) -> torch.Tensor:
def _apply_mono_mask(
transition_prob: torch.Tensor,
) -> torch.Tensor:
"""Applies monotonic attention mask to transition probabilities.
Enforces 0 log-probability values for all non-monotonic relations
@@ -472,7 +476,9 @@ def _loss(
fwd = fwd + transition_probs[tgt_char_idx].transpose(1, 2)
fwd = fwd.logsumexp(dim=-1, keepdim=True).transpose(1, 2)
fwd = fwd + self._gather_at_idx(
log_probs[tgt_char_idx], target[:, tgt_char_idx], self.pad_idx
log_probs[tgt_char_idx],
target[:, tgt_char_idx],
self.pad_idx,
)
loss = -torch.logsumexp(fwd, dim=-1).mean() / target.shape[1]
return loss
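The `_apply_mono_mask` docstring above describes ruling out non-monotonic transitions; the method body is hidden in this hunk. A minimal sketch of the idea, assuming the transition matrix is a `(batch, src_len, src_len)` tensor of log-probabilities indexed as (from, to); the names and shapes here are assumptions:

```python
import torch


def apply_mono_mask(transition_prob: torch.Tensor) -> torch.Tensor:
    """Masks transitions that move backwards over the source string."""
    src_len = transition_prob.size(-1)
    # Entry [b, j, i] is the log-probability of moving from source position
    # j to source position i; i < j violates monotonicity.
    backward = (
        torch.ones(src_len, src_len, device=transition_prob.device)
        .tril(diagonal=-1)
        .bool()
    )
    masked = transition_prob.masked_fill(backward, float("-inf"))
    # Renormalizes so each row is still a distribution in log space.
    return torch.log_softmax(masked, dim=-1)
```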