Skip to content

Commit

Permalink
Updates model test.
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebgorman committed Oct 8, 2024
1 parent 01138bc commit f41d06b
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions tests/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,22 @@
@pytest.mark.parametrize(
["arch", "expected_cls"],
[
("attentive_lstm", models.AttentiveLSTMEncoderDecoder),
("hard_attention_lstm", models.HardAttentionLSTM),
("lstm", models.LSTMEncoderDecoder),
("attentive_gru", models.AttentiveGRUModel),
("attentive_lstm", models.AttentiveLSTMModel),
("gru", models.GRUModel),
("hard_attention_lstm", models.HardAttentionLSTMModel),
("lstm", models.LSTMModel),
(
"pointer_generator_gru",
models.PointerGeneratorGRUModel,
),
(
"pointer_generator_lstm",
models.PointerGeneratorLSTMEncoderDecoder,
models.PointerGeneratorLSTMModel,
),
("transducer", models.TransducerEncoderDecoder),
("transformer", models.TransformerEncoderDecoder),
("transducer_gru", models.TransducerGRUModel),
("transducer_lstm", models.TransducerLSTMModel),
("transformer", models.TransformerModel),
],
)
def test_get_model_cls(arch, expected_cls):
Expand Down

0 comments on commit f41d06b

Please sign in to comment.