Skip to content

Commit

Permalink
fix MPS sample handling of RNG seed
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinsung committed Jul 19, 2024
1 parent 1960e05 commit 1769759
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
6 changes: 4 additions & 2 deletions quimb/tensor/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4043,7 +4043,8 @@ def sample(
seed : None, int, or generator, optional
A random seed or generator to use for reproducibility.
"""
for config, _ in self._psi.sample(C, seed=seed):
rng = np.random.default_rng(seed)
for config, _ in self._psi.sample(C, seed=rng):
yield "".join(map(str, config))

def fidelity_estimate(self):
Expand Down Expand Up @@ -4154,9 +4155,10 @@ def sample(self, C, seed=None):
str
The next sample bitstring.
"""
rng = np.random.default_rng(seed)
# configuring is in physical order, so need to reorder for sampling
ordering = self.calc_qubit_ordering()
for config, _ in self._psi.sample(C, seed=seed):
for config, _ in self._psi.sample(C, seed=rng):
yield "".join(str(config[i]) for i in ordering)

@property
Expand Down
14 changes: 14 additions & 0 deletions tests/test_tensor/test_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,13 @@ def test_mps_sampling(self):
for x in circ.sample(10):
assert x in {"000010", "111101"}

def test_mps_sampling_seed(self):
N = 1
circ = qtn.CircuitMPS(N)
circ.h(0)
samples = list(circ.sample(100, seed=1234))
assert len(set(samples)) == 2

def test_permmps_sampling(self):
N = 6
circ = qtn.CircuitPermMPS(N)
Expand All @@ -710,6 +717,13 @@ def test_permmps_sampling(self):
for x in circ.sample(10):
assert x in {"000010", "111101"}

def test_permmps_sampling_seed(self):
N = 1
circ = qtn.CircuitPermMPS(N)
circ.h(0)
samples = list(circ.sample(100, seed=1234))
assert len(set(samples)) == 2


class TestCircuitGen:
@pytest.mark.parametrize(
Expand Down

0 comments on commit 1769759

Please sign in to comment.