Skip to content

Commit

Permalink
Report Plot issues. (#851)
Browse files Browse the repository at this point in the history
* fine tuning visualization of lattice in report.

* fixing filter to work on specific task number

* add `NaN` when no shots exists for a task becaue of perfect sorting issues.

* fixing bug in dataframe indexing.

* fixing bug in `_filter`.
  • Loading branch information
weinbe58 authored Jan 4, 2024
1 parent 5b879f3 commit 467cc56
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 22 deletions.
27 changes: 22 additions & 5 deletions src/bloqade/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,21 @@ def markdown(self) -> str:

def _filter(
self,
*,
task_number: Optional[int] = None,
filter_perfect_filling: bool = True,
clusters: Union[tuple[int, int], Sequence[Tuple[int, int]]] = [],
):
mask = np.ones(len(self.dataframe), dtype=bool)

if task_number is not None:
task_numbers = self.dataframe.index.get_level_values("task_number")
np.logical_and(task_numbers == task_number, mask, out=mask)

if filter_perfect_filling:
perfect_sorting = self.dataframe.index.get_level_values("perfect_sorting")
pre_sequence = self.dataframe.index.get_level_values("pre_sequence")

np.logical_and(perfect_sorting == pre_sequence, mask, out=mask)

clusters = [clusters] if isinstance(clusters, tuple) else clusters
Expand Down Expand Up @@ -169,14 +176,22 @@ def bitstrings(
is set to True.
"""
mask = self._filter(filter_perfect_filling, clusters)
df = self.dataframe[mask]

task_numbers = df.index.get_level_values("task_number").unique()
task_numbers = self.dataframe.index.get_level_values("task_number").unique()

bitstrings = []
for task_number in task_numbers:
bitstrings.append(df.loc[task_number, ...].to_numpy())
mask = self._filter(
task_number=task_number,
filter_perfect_filling=filter_perfect_filling,
clusters=clusters,
)
if np.any(mask):
bitstrings.append(self.dataframe.loc[mask].to_numpy())
else:
bitstrings.append(
np.zeros((0, self.dataframe.shape[1]), dtype=np.uint8)
)

return bitstrings

Expand Down Expand Up @@ -239,7 +254,9 @@ def rydberg_densities(
per-site rydberg density for each task as a pandas DataFrame or Series.
"""
mask = self._filter(filter_perfect_filling, clusters)
mask = self._filter(
filter_perfect_filling=filter_perfect_filling, clusters=clusters
)
df = self.dataframe[mask]
return 1 - (df.groupby("task_number").mean())

Expand Down
55 changes: 38 additions & 17 deletions src/bloqade/visualization/report_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,14 @@ def format_report_data(report: "Report"):
task_tid = list(task_tid)

counts = report.counts()
ryds = report.rydberg_densities()
bitstrings = report.bitstrings()
ryds = [np.mean(ele, axis=0) for ele in bitstrings]

assert len(task_tid) == len(counts)

cnt_sources = []
ryd_sources = []
for i, cnt_data in enumerate(counts):
# bitstrings = list(
# "\n".join(textwrap.wrap(bitstring, 32)) for bitstring in cnt_data.keys()
# )
bitstrings = list(cnt_data.keys())
bit_id = [f"[{x}]" for x in range(len(bitstrings))]

Expand All @@ -60,7 +58,7 @@ def format_report_data(report: "Report"):
data=dict(tid=tid, bitstrings=bitstrings, cnts=cnts, bit_id=bit_id)
)

rydens = list(ryds.iloc[i])
rydens = list(ryds[i])
tid = [i] * len(rydens)

rsrc = ColumnDataSource(
Expand Down Expand Up @@ -147,6 +145,13 @@ def mock_data():
return cnt_sources, ryd_sources, metas, geos, "Mock"


def get_radius(length_scale, x_min, x_max, y_min, y_max):
global_length_scale = max(x_max - x_min, y_max - y_min)
coeff = 0.15 + 0.3 * np.tanh(global_length_scale / 10)
radius = coeff * length_scale
return radius


def plot_register_ryd_dense(geo, ryds):
"""obtain a figure object from the atom arrangement."""
xs_filled, ys_filled, labels_filled, density_filled = [], [], [], []
Expand Down Expand Up @@ -175,7 +180,11 @@ def plot_register_ryd_dense(geo, ryds):
density_vacant.append(density)

if len(geo.sites) > 0:
length_scale = max(y_max - y_min, x_max - x_min, 1)
length_scale = np.inf
for i, site_i in enumerate(geo.sites):
for site_j in geo.sites[i + 1 :]:
dist = np.linalg.norm(np.array(site_i) - np.array(site_j)) / 1e-6
length_scale = min(length_scale, dist)
else:
length_scale = 1

Expand All @@ -196,7 +205,6 @@ def plot_register_ryd_dense(geo, ryds):
("index: ", "@_labels"),
("ryd density: ", "@_ryd"),
]

color_mapper = LinearColorMapper(palette="Magma256", low=min(ryds), high=max(ryds))

# specify that we want to map the colors to the y values,
Expand All @@ -212,28 +220,33 @@ def plot_register_ryd_dense(geo, ryds):
toolbar_location="above",
title="rydberg density",
)
p.x_range = Range1d(x_min - 1, x_min + length_scale + 1)
p.y_range = Range1d(y_min - 1, y_min + length_scale + 1)
p.x_range = Range1d(x_min - length_scale, x_max + length_scale)
p.y_range = Range1d(y_min - length_scale, y_max + length_scale)

# interpolate between a scale for small lattices
# and a scale for larger lattices
radius = get_radius(length_scale, x_min, x_max, y_min, y_max)

p.circle(
"_x",
"_y",
source=source_filled,
radius=0.035 * length_scale,
radius=radius,
fill_alpha=1,
line_color="black",
color={"field": "_ryd", "transform": color_mapper},
)

p.circle(
"_x",
"_y",
source=source_vacant,
radius=0.035 * length_scale,
radius=radius,
fill_alpha=1,
# color="grey",
line_color="black",
color={"field": "_ryd", "transform": color_mapper},
line_width=0.2 * length_scale,
line_width=0.01 * length_scale,
)

color_bar = ColorBar(
Expand Down Expand Up @@ -272,13 +285,17 @@ def plot_register_bits(geo):
x_max = max(x, x_max)
y_max = max(y, y_max)

xs.append(x)
ys.append(y)
xs.append(x)
bits.append(0)
labels.append(idx)

if len(geo.sites) > 0:
length_scale = max(y_max - y_min, x_max - x_min, 1)
length_scale = np.inf
for i, site_i in enumerate(geo.sites):
for site_j in geo.sites[i + 1 :]:
dist = np.linalg.norm(np.array(site_i) - np.array(site_j)) / 1e-6
length_scale = min(length_scale, dist)
else:
length_scale = 1

Expand All @@ -297,6 +314,10 @@ def plot_register_bits(geo):
# this could be replaced with a list of colors
##p.scatter(x,y,color={'field': 'y', 'transform': color_mapper})

# interpolate between a scale for small lattices
# and a scale for larger lattices
radius = get_radius(length_scale, x_min, x_max, y_min, y_max)

## remove box_zoom since we don't want to change the scale

p = figure(
Expand All @@ -306,14 +327,14 @@ def plot_register_bits(geo):
toolbar_location="above",
title="reg state",
)
p.x_range = Range1d(x_min - 1, x_min + length_scale + 1)
p.y_range = Range1d(y_min - 1, y_min + length_scale + 1)
p.x_range = Range1d(x_min - length_scale, x_max + length_scale)
p.y_range = Range1d(y_min - length_scale, y_max + length_scale)

p.circle(
"_x",
"_y",
source=source,
radius=0.035 * length_scale,
radius=radius,
fill_alpha=1,
line_color="black",
color={"field": "_bits", "transform": color_mapper},
Expand Down

0 comments on commit 467cc56

Please sign in to comment.