Skip to content

Commit

Permalink
remove unnecessary else
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinsung committed Aug 1, 2024
1 parent 61fa15b commit cb3c51f
Showing 1 changed file with 34 additions and 38 deletions.
72 changes: 34 additions & 38 deletions python/ffsim/states/slater.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def sample_slater_determinant(
rng = np.random.default_rng(seed)

if isinstance(nelec, int):
# Spinless case.
# Spinless case
rdm = cast(np.ndarray, rdm)
norb, _ = rdm.shape
if orbs is None:
Expand All @@ -87,51 +87,47 @@ def sample_slater_determinant(
bitstring_type,
length=len(orbs),
)
else:
# Spinful case
rdm_a, rdm_b = rdm
n_a, n_b = nelec
norb, _ = rdm_a.shape
if orbs is None:
orbs = (range(norb), range(norb))
orbs_a, orbs_b = orbs
orbs_a = cast(Sequence[int], orbs_a)
orbs_b = cast(Sequence[int], orbs_b)
strings_a = _sample_slater_spinless(rdm_a, n_a, shots, rng)
strings_b = _sample_slater_spinless(rdm_b, n_b, shots, rng)
strings_a = restrict_bitstrings(
strings_a, orbs_a, bitstring_type=BitstringType.INT
)
strings_b = restrict_bitstrings(
strings_b, orbs_b, bitstring_type=BitstringType.INT
)

if concatenate:
strings = concatenate_bitstrings(
strings_a,
strings_b,
BitstringType.INT,
length=len(orbs_a),
)
return convert_bitstring_type(
strings,
BitstringType.INT,
bitstring_type,
length=len(orbs_a) + len(orbs_b),
)

return convert_bitstring_type(
# Spinful case
rdm_a, rdm_b = rdm
n_a, n_b = nelec
norb, _ = rdm_a.shape
if orbs is None:
orbs = (range(norb), range(norb))
orbs_a, orbs_b = orbs
orbs_a = cast(Sequence[int], orbs_a)
orbs_b = cast(Sequence[int], orbs_b)
strings_a = _sample_slater_spinless(rdm_a, n_a, shots, rng)
strings_b = _sample_slater_spinless(rdm_b, n_b, shots, rng)
strings_a = restrict_bitstrings(strings_a, orbs_a, bitstring_type=BitstringType.INT)
strings_b = restrict_bitstrings(strings_b, orbs_b, bitstring_type=BitstringType.INT)

if concatenate:
strings = concatenate_bitstrings(
strings_a,
strings_b,
BitstringType.INT,
bitstring_type,
length=len(orbs_a),
), convert_bitstring_type(
strings_b,
)
return convert_bitstring_type(
strings,
BitstringType.INT,
bitstring_type,
length=len(orbs_b),
length=len(orbs_a) + len(orbs_b),
)

return convert_bitstring_type(
strings_a,
BitstringType.INT,
bitstring_type,
length=len(orbs_a),
), convert_bitstring_type(
strings_b,
BitstringType.INT,
bitstring_type,
length=len(orbs_b),
)


def _sample_slater_spinless(
rdm: np.ndarray,
Expand Down

0 comments on commit cb3c51f

Please sign in to comment.