Skip to content

Commit

Permalink
Move drawing code to drawing (quantumlib#380)
Browse files Browse the repository at this point in the history
* Move drawing code to drawing

* pylint

* fight between linter and tests
  • Loading branch information
mpharrigan authored Oct 11, 2023
1 parent cf5222b commit 03198a9
Show file tree
Hide file tree
Showing 15 changed files with 463 additions and 217 deletions.
7 changes: 6 additions & 1 deletion dev_tools/execute-notebooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,13 @@ def parse_args():
p = argparse.ArgumentParser()
p.add_argument('--output-nbs', action=argparse.BooleanOptionalAction, default=True)
p.add_argument('--output-html', action=argparse.BooleanOptionalAction, default=False)
p.add_argument('--only-out-of-date', action=argparse.BooleanOptionalAction, default=True)
args = p.parse_args()
execute_and_export_notebooks(output_nbs=args.output_nbs, output_html=args.output_html)
execute_and_export_notebooks(
output_nbs=args.output_nbs,
output_html=args.output_html,
only_out_of_date=args.only_out_of_date,
)


if __name__ == '__main__':
Expand Down
7 changes: 4 additions & 3 deletions qualtran/bloqs/and_bloq.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,12 @@
"metadata": {},
"outputs": [],
"source": [
"from qualtran.resource_counting import get_bloq_counts_graph, GraphvizCounts, SympySymbolAllocator\n",
"from qualtran.resource_counting import get_bloq_counts_graph\n",
"from qualtran.drawing import show_counts_graph\n",
"import attrs\n",
"\n",
"graph, sigma = get_bloq_counts_graph(bloq)\n",
"GraphvizCounts(graph).get_svg()"
"show_counts_graph(graph)"
]
},
{
Expand Down Expand Up @@ -166,7 +167,7 @@
"outputs": [],
"source": [
"graph, sigma = get_bloq_counts_graph(bloq)\n",
"GraphvizCounts(graph).get_svg()"
"show_counts_graph(graph)"
]
},
{
Expand Down
7 changes: 5 additions & 2 deletions qualtran/bloqs/basic_gates.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,12 @@
"metadata": {},
"outputs": [],
"source": [
"from qualtran.resource_counting import get_bloq_counts_graph, GraphvizCounts\n",
"from qualtran.resource_counting import get_bloq_counts_graph\n",
"from qualtran.drawing import show_counts_graph, show_counts_sigma\n",
"\n",
"g, sigma = get_bloq_counts_graph(bloq)\n",
"GraphvizCounts(g).get_svg()"
"show_counts_graph(g)\n",
"show_counts_sigma(sigma)"
]
},
{
Expand Down
17 changes: 5 additions & 12 deletions qualtran/bloqs/factoring/ref-factoring.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -207,19 +207,12 @@
},
"outputs": [],
"source": [
"from qualtran.resource_counting import get_bloq_counts_graph, GraphvizCounts, markdown_counts_sigma\n",
"from qualtran.resource_counting import get_bloq_counts_graph\n",
"from qualtran.drawing import show_counts_graph, show_counts_sigma\n",
"\n",
"g, sigma = get_bloq_counts_graph(bloq)\n",
"GraphvizCounts(g).get_svg()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18e83743",
"metadata": {},
"outputs": [],
"source": [
"markdown_counts_sigma(sigma)"
"show_counts_graph(g)\n",
"show_counts_sigma(sigma)"
]
},
{
Expand Down
9 changes: 7 additions & 2 deletions qualtran/drawing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
isort:skip_file
"""


from .graphviz import GraphDrawer, PrettyGraphDrawer, ClassicalSimGraphDrawer, show_bloq
from .graphviz import GraphDrawer, PrettyGraphDrawer
from .musical_score import (
RegPosition,
HLine,
Expand All @@ -37,3 +36,9 @@
draw_musical_score,
dump_musical_score,
)

from .classical_sim_graph import ClassicalSimGraphDrawer

from .bloq_counts_graph import GraphvizCounts, format_counts_sigma, format_counts_graph_markdown

from ._show_funcs import show_bloq, show_bloqs, show_counts_graph, show_counts_sigma
63 changes: 63 additions & 0 deletions qualtran/drawing/_show_funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Convenience functions for showing rich displays in Jupyter notebook."""

from typing import Dict, Sequence, TYPE_CHECKING, Union

import IPython.display
import ipywidgets

from .bloq_counts_graph import format_counts_sigma, GraphvizCounts
from .graphviz import PrettyGraphDrawer

if TYPE_CHECKING:
import networkx as nx
import sympy

from qualtran import Bloq


def show_bloq(bloq: 'Bloq'):
"""Display a graph representation of the bloq in IPython."""
IPython.display.display(PrettyGraphDrawer(bloq).get_svg())


def show_bloqs(bloqs: Sequence['Bloq'], labels: Sequence[str] = None):
"""Display multiple bloqs side-by-side in IPython."""
n = len(bloqs)
if labels is not None:
assert len(labels) == n, 'Must provide exactly as many labels as bloqs'
else:
labels = [None] * n

outs = [ipywidgets.Output() for _ in range(n)]
box = ipywidgets.HBox(outs)

for i, (bloq, label) in enumerate(zip(bloqs, labels)):
if label:
outs[i].append_display_data(IPython.display.Markdown(label))
outs[i].append_display_data(PrettyGraphDrawer(bloq).get_svg())

IPython.display.display(box)


def show_counts_graph(g: 'nx.DiGraph') -> None:
"""Display a graph representation of the counts graph `g`."""
IPython.display.display(GraphvizCounts(g).get_svg())


def show_counts_sigma(sigma: Dict['Bloq', Union[int, 'sympy.Expr']]):
"""Display nicely formatted bloq counts sums `sigma`."""
IPython.display.display(IPython.display.Markdown(format_counts_sigma(sigma)))
124 changes: 124 additions & 0 deletions qualtran/drawing/bloq_counts_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classes for drawing bloq counts graphs with Graphviz."""

from typing import Dict, Union

import IPython.display
import networkx as nx
import pydot
import sympy

from qualtran import Bloq, CompositeBloq


class GraphvizCounts:
"""This class turns a bloqs count graph into Graphviz objects and drawings.
Args:
g: The counts graph.
"""

def __init__(self, g: nx.DiGraph):
self.g = g
self._ids: Dict[Bloq, str] = {}
self._i = 0

def get_id(self, b: Bloq) -> str:
if b in self._ids:
return self._ids[b]
new_id = f'b{self._i}'
self._i += 1
self._ids[b] = new_id
return new_id

def get_node_properties(self, b: Bloq):
"""Get graphviz properties for a bloq node representing `b`."""
if isinstance(b, CompositeBloq):
details = f'{len(b.bloq_instances)} bloqs...'
else:
details = repr(b)

label = [
'<',
f'{b.pretty_name().replace("<", "&lt;").replace(">", "&gt;")}<br />',
f'<font face="monospace" point-size="10">{details}</font><br/>',
'>',
]
return {'label': ''.join(label), 'shape': 'rect'}

def add_nodes(self, graph: pydot.Graph):
"""Helper function to add nodes to the pydot graph."""
b: Bloq
for b in nx.topological_sort(self.g):
graph.add_node(pydot.Node(self.get_id(b), **self.get_node_properties(b)))

def add_edges(self, graph: pydot.Graph):
"""Helper function to add edges to the pydot graph."""
for b1, b2 in self.g.edges:
n = self.g.edges[b1, b2]['n']
label = sympy.printing.pretty(n)
graph.add_edge(pydot.Edge(self.get_id(b1), self.get_id(b2), label=label))

def get_graph(self):
"""Get the pydot graph."""
graph = pydot.Dot('counts', graph_type='digraph', rankdir='TB')
self.add_nodes(graph)
self.add_edges(graph)
return graph

def get_svg_bytes(self) -> bytes:
"""Get the SVG code (as bytes) for drawing the graph."""
return self.get_graph().create_svg()

def get_svg(self) -> IPython.display.SVG:
"""Get an IPython SVG object displaying the graph."""
return IPython.display.SVG(self.get_svg_bytes())


def _format_bloq_expr_markdown(bloq: Bloq, expr: Union[int, sympy.Expr]) -> str:
"""Return "`bloq`: expr" as markdown."""
try:
expr = expr._repr_latex_()
except AttributeError:
expr = f'{expr}'

return f'`{bloq}`: {expr}'


def format_counts_graph_markdown(graph: nx.DiGraph) -> str:
"""Format a text version of `graph` as markdown."""
m = ""
for bloq in nx.topological_sort(graph):
if not graph.succ[bloq]:
continue
m += f' - `{bloq}`\n'

succ_lines = []
for succ in graph.succ[bloq]:
expr = sympy.sympify(graph.edges[bloq, succ]['n'])
succ_lines.append(f' - {_format_bloq_expr_markdown(succ, expr)}\n')
succ_lines.sort()
m += ''.join(succ_lines)

return m


def format_counts_sigma(sigma: Dict[Bloq, Union[int, sympy.Expr]]) -> str:
"""Format `sigma` as markdown."""
lines = [f' - {_format_bloq_expr_markdown(bloq, expr)}' for bloq, expr in sigma.items()]
lines.sort()
lines.insert(0, '#### Counts totals:')
return '\n'.join(lines)
59 changes: 59 additions & 0 deletions qualtran/drawing/bloq_counts_graph_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

from qualtran.bloqs.and_bloq import MultiAnd
from qualtran.drawing import format_counts_graph_markdown, format_counts_sigma, GraphvizCounts
from qualtran.resource_counting import get_bloq_counts_graph


def test_format_counts_sigma():
graph, sigma = get_bloq_counts_graph(MultiAnd(cvs=(1,) * 6))
ret = format_counts_sigma(sigma)
assert (
ret
== """\
#### Counts totals:
- `ArbitraryClifford(n=2)`: 45
- `TGate()`: 20"""
)


def test_format_counts_graph_markdown():
graph, sigma = get_bloq_counts_graph(MultiAnd(cvs=(1,) * 6))
ret = format_counts_graph_markdown(graph)
assert (
ret
== r""" - `MultiAnd(cvs=(1, 1, 1, 1, 1, 1), adjoint=False)`
- `And(cv1=1, cv2=1, adjoint=False)`: $\displaystyle 5$
- `And(cv1=1, cv2=1, adjoint=False)`
- `ArbitraryClifford(n=2)`: $\displaystyle 9$
- `TGate()`: $\displaystyle 4$
"""
)


def test_graphviz_counts():
graph, sigma = get_bloq_counts_graph(MultiAnd(cvs=(1,) * 6))
gvc = GraphvizCounts(graph)

# The main test is in the drawing notebook, so please spot check that.
# Here: we make sure the edge labels are 5, 9 or 4 (see above)
dot_lines = gvc.get_graph().to_string().splitlines()
edge_lines = [line for line in dot_lines if '->' in line]
for line in edge_lines:
ma = re.search(r'label=(\w+)', line)
assert ma is not None, line
i = int(ma.group(1))
assert i in [5, 9, 4]
Loading

0 comments on commit 03198a9

Please sign in to comment.