forked from quantumlib/Qualtran
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move drawing code to
drawing
(quantumlib#380)
* Move drawing code to drawing * pylint * fight between linter and tests
- Loading branch information
1 parent
cf5222b
commit 03198a9
Showing
15 changed files
with
463 additions
and
217 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Convenience functions for showing rich displays in Jupyter notebook.""" | ||
|
||
from typing import Dict, Sequence, TYPE_CHECKING, Union | ||
|
||
import IPython.display | ||
import ipywidgets | ||
|
||
from .bloq_counts_graph import format_counts_sigma, GraphvizCounts | ||
from .graphviz import PrettyGraphDrawer | ||
|
||
if TYPE_CHECKING: | ||
import networkx as nx | ||
import sympy | ||
|
||
from qualtran import Bloq | ||
|
||
|
||
def show_bloq(bloq: 'Bloq'): | ||
"""Display a graph representation of the bloq in IPython.""" | ||
IPython.display.display(PrettyGraphDrawer(bloq).get_svg()) | ||
|
||
|
||
def show_bloqs(bloqs: Sequence['Bloq'], labels: Sequence[str] = None): | ||
"""Display multiple bloqs side-by-side in IPython.""" | ||
n = len(bloqs) | ||
if labels is not None: | ||
assert len(labels) == n, 'Must provide exactly as many labels as bloqs' | ||
else: | ||
labels = [None] * n | ||
|
||
outs = [ipywidgets.Output() for _ in range(n)] | ||
box = ipywidgets.HBox(outs) | ||
|
||
for i, (bloq, label) in enumerate(zip(bloqs, labels)): | ||
if label: | ||
outs[i].append_display_data(IPython.display.Markdown(label)) | ||
outs[i].append_display_data(PrettyGraphDrawer(bloq).get_svg()) | ||
|
||
IPython.display.display(box) | ||
|
||
|
||
def show_counts_graph(g: 'nx.DiGraph') -> None: | ||
"""Display a graph representation of the counts graph `g`.""" | ||
IPython.display.display(GraphvizCounts(g).get_svg()) | ||
|
||
|
||
def show_counts_sigma(sigma: Dict['Bloq', Union[int, 'sympy.Expr']]): | ||
"""Display nicely formatted bloq counts sums `sigma`.""" | ||
IPython.display.display(IPython.display.Markdown(format_counts_sigma(sigma))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Classes for drawing bloq counts graphs with Graphviz.""" | ||
|
||
from typing import Dict, Union | ||
|
||
import IPython.display | ||
import networkx as nx | ||
import pydot | ||
import sympy | ||
|
||
from qualtran import Bloq, CompositeBloq | ||
|
||
|
||
class GraphvizCounts: | ||
"""This class turns a bloqs count graph into Graphviz objects and drawings. | ||
Args: | ||
g: The counts graph. | ||
""" | ||
|
||
def __init__(self, g: nx.DiGraph): | ||
self.g = g | ||
self._ids: Dict[Bloq, str] = {} | ||
self._i = 0 | ||
|
||
def get_id(self, b: Bloq) -> str: | ||
if b in self._ids: | ||
return self._ids[b] | ||
new_id = f'b{self._i}' | ||
self._i += 1 | ||
self._ids[b] = new_id | ||
return new_id | ||
|
||
def get_node_properties(self, b: Bloq): | ||
"""Get graphviz properties for a bloq node representing `b`.""" | ||
if isinstance(b, CompositeBloq): | ||
details = f'{len(b.bloq_instances)} bloqs...' | ||
else: | ||
details = repr(b) | ||
|
||
label = [ | ||
'<', | ||
f'{b.pretty_name().replace("<", "<").replace(">", ">")}<br />', | ||
f'<font face="monospace" point-size="10">{details}</font><br/>', | ||
'>', | ||
] | ||
return {'label': ''.join(label), 'shape': 'rect'} | ||
|
||
def add_nodes(self, graph: pydot.Graph): | ||
"""Helper function to add nodes to the pydot graph.""" | ||
b: Bloq | ||
for b in nx.topological_sort(self.g): | ||
graph.add_node(pydot.Node(self.get_id(b), **self.get_node_properties(b))) | ||
|
||
def add_edges(self, graph: pydot.Graph): | ||
"""Helper function to add edges to the pydot graph.""" | ||
for b1, b2 in self.g.edges: | ||
n = self.g.edges[b1, b2]['n'] | ||
label = sympy.printing.pretty(n) | ||
graph.add_edge(pydot.Edge(self.get_id(b1), self.get_id(b2), label=label)) | ||
|
||
def get_graph(self): | ||
"""Get the pydot graph.""" | ||
graph = pydot.Dot('counts', graph_type='digraph', rankdir='TB') | ||
self.add_nodes(graph) | ||
self.add_edges(graph) | ||
return graph | ||
|
||
def get_svg_bytes(self) -> bytes: | ||
"""Get the SVG code (as bytes) for drawing the graph.""" | ||
return self.get_graph().create_svg() | ||
|
||
def get_svg(self) -> IPython.display.SVG: | ||
"""Get an IPython SVG object displaying the graph.""" | ||
return IPython.display.SVG(self.get_svg_bytes()) | ||
|
||
|
||
def _format_bloq_expr_markdown(bloq: Bloq, expr: Union[int, sympy.Expr]) -> str: | ||
"""Return "`bloq`: expr" as markdown.""" | ||
try: | ||
expr = expr._repr_latex_() | ||
except AttributeError: | ||
expr = f'{expr}' | ||
|
||
return f'`{bloq}`: {expr}' | ||
|
||
|
||
def format_counts_graph_markdown(graph: nx.DiGraph) -> str: | ||
"""Format a text version of `graph` as markdown.""" | ||
m = "" | ||
for bloq in nx.topological_sort(graph): | ||
if not graph.succ[bloq]: | ||
continue | ||
m += f' - `{bloq}`\n' | ||
|
||
succ_lines = [] | ||
for succ in graph.succ[bloq]: | ||
expr = sympy.sympify(graph.edges[bloq, succ]['n']) | ||
succ_lines.append(f' - {_format_bloq_expr_markdown(succ, expr)}\n') | ||
succ_lines.sort() | ||
m += ''.join(succ_lines) | ||
|
||
return m | ||
|
||
|
||
def format_counts_sigma(sigma: Dict[Bloq, Union[int, sympy.Expr]]) -> str: | ||
"""Format `sigma` as markdown.""" | ||
lines = [f' - {_format_bloq_expr_markdown(bloq, expr)}' for bloq, expr in sigma.items()] | ||
lines.sort() | ||
lines.insert(0, '#### Counts totals:') | ||
return '\n'.join(lines) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright 2023 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import re | ||
|
||
from qualtran.bloqs.and_bloq import MultiAnd | ||
from qualtran.drawing import format_counts_graph_markdown, format_counts_sigma, GraphvizCounts | ||
from qualtran.resource_counting import get_bloq_counts_graph | ||
|
||
|
||
def test_format_counts_sigma(): | ||
graph, sigma = get_bloq_counts_graph(MultiAnd(cvs=(1,) * 6)) | ||
ret = format_counts_sigma(sigma) | ||
assert ( | ||
ret | ||
== """\ | ||
#### Counts totals: | ||
- `ArbitraryClifford(n=2)`: 45 | ||
- `TGate()`: 20""" | ||
) | ||
|
||
|
||
def test_format_counts_graph_markdown(): | ||
graph, sigma = get_bloq_counts_graph(MultiAnd(cvs=(1,) * 6)) | ||
ret = format_counts_graph_markdown(graph) | ||
assert ( | ||
ret | ||
== r""" - `MultiAnd(cvs=(1, 1, 1, 1, 1, 1), adjoint=False)` | ||
- `And(cv1=1, cv2=1, adjoint=False)`: $\displaystyle 5$ | ||
- `And(cv1=1, cv2=1, adjoint=False)` | ||
- `ArbitraryClifford(n=2)`: $\displaystyle 9$ | ||
- `TGate()`: $\displaystyle 4$ | ||
""" | ||
) | ||
|
||
|
||
def test_graphviz_counts(): | ||
graph, sigma = get_bloq_counts_graph(MultiAnd(cvs=(1,) * 6)) | ||
gvc = GraphvizCounts(graph) | ||
|
||
# The main test is in the drawing notebook, so please spot check that. | ||
# Here: we make sure the edge labels are 5, 9 or 4 (see above) | ||
dot_lines = gvc.get_graph().to_string().splitlines() | ||
edge_lines = [line for line in dot_lines if '->' in line] | ||
for line in edge_lines: | ||
ma = re.search(r'label=(\w+)', line) | ||
assert ma is not None, line | ||
i = int(ma.group(1)) | ||
assert i in [5, 9, 4] |
Oops, something went wrong.