Skip to content

Commit

Permalink
Use symbolics for computing register bitsizes (#1353)
Browse files Browse the repository at this point in the history
* use symbolics for computing register bitsizes

* symbolic `n` in `Partition`
  • Loading branch information
anurudhp authored Aug 28, 2024
1 parent 862e7e7 commit bfa5bc1
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
9 changes: 4 additions & 5 deletions qualtran/_infra/registers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@
from typing import cast, Dict, Iterable, Iterator, List, overload, Tuple, Union

import attrs
import numpy as np
import sympy
from attrs import field, frozen

from qualtran.symbolics import is_symbolic, smax, SymbolicInt
from qualtran.symbolics import is_symbolic, prod, smax, ssum, SymbolicInt

from .data_types import QAny, QBit, QDType

Expand Down Expand Up @@ -99,7 +98,7 @@ def total_bits(self) -> int:
This is the product of bitsize and each of the dimensions in `shape`.
"""
return self.bitsize * int(np.prod(self.shape))
return self.bitsize * prod(self.shape_symbolic)

def adjoint(self) -> 'Register':
"""Return the 'adjoint' of this register by switching RIGHT and LEFT registers."""
Expand Down Expand Up @@ -202,8 +201,8 @@ def n_qubits(self) -> int:
is taken to be the greater of the number of left or right qubits. A bloq with this
signature uses at least this many qubits.
"""
left_size = sum(reg.total_bits() for reg in self.lefts())
right_size = sum(reg.total_bits() for reg in self.rights())
left_size = ssum(reg.total_bits() for reg in self.lefts())
right_size = ssum(reg.total_bits() for reg in self.rights())
return smax(left_size, right_size)

def __repr__(self):
Expand Down
8 changes: 6 additions & 2 deletions qualtran/bloqs/bookkeeping/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from qualtran.bloqs.bookkeeping._bookkeeping_bloq import _BookkeepingBloq
from qualtran.drawing import directional_text_box, Text, WireSymbol
from qualtran.symbolics import is_symbolic, ssum, SymbolicInt

if TYPE_CHECKING:
import quimb.tensor as qtn
Expand All @@ -54,14 +55,14 @@ class Partition(_BookkeepingBloq):
[user spec]: The registers provided by the `regs` argument. RIGHT by default.
"""

n: int
n: SymbolicInt
regs: Tuple[Register, ...] = field(
converter=lambda x: x if isinstance(x, tuple) else tuple(x), validator=validators.min_len(1)
)
partition: bool = True

def __attrs_post_init__(self):
if self.n != sum(r.total_bits() for r in self.regs):
if self.n != ssum(r.total_bits() for r in self.regs):
raise ValueError("Total bitsize not equal to sum of registers to partition into")
if len(set(r.name for r in self.regs)) != len(self.regs):
raise ValueError("Duplicate register names")
Expand Down Expand Up @@ -104,6 +105,9 @@ def my_tensors(
) -> List['qtn.Tensor']:
import quimb.tensor as qtn

if is_symbolic(self.n):
raise DecomposeTypeError(f"cannot compute tensors for symbolic {self}")

grouped = incoming['x'] if self.partition else outgoing['x']
partitioned = outgoing if self.partition else incoming

Expand Down

0 comments on commit bfa5bc1

Please sign in to comment.