diff --git a/tests/unit/model/test_mblm.py b/tests/unit/model/test_mblm.py index 1250a4b..536fd88 100644 --- a/tests/unit/model/test_mblm.py +++ b/tests/unit/model/test_mblm.py @@ -6,6 +6,7 @@ import torch from mblm import MBLM, MBLMModelConfig, MBLMReturnType, TransformerBlock +from mblm.utils.seed import seed_everything from mblm.utils.stream import ByteStreamer @@ -68,6 +69,8 @@ def test_masked_loss( def test_generate(self): ctx_windows = [12, 4] + + seed_everything(8) mblm = MBLM( MBLMModelConfig( num_tokens=self.num_tokens, @@ -102,6 +105,6 @@ def test_generate(self): buff = io.BytesIO() with ByteStreamer(buff) as stream: - generate(stream=stream) + generate(stream=stream, filter_thres=1) assert len(buff.getbuffer()) == total_generation_len