diff --git a/tests/conftest.py b/tests/conftest.py index 0bf8219..d673482 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,19 @@ +import random + import numpy as np import pytest import torch from scipy.special import erfinv +@pytest.fixture(autouse=True) +def seed_everything(): + seed = 101589 + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + + @pytest.fixture def compare_against_numpy(): """