Skip to content

Commit

Permalink
Add custom samplers, better collection, and better plotting to sinter (
Browse files Browse the repository at this point in the history
…#804)

- Add `sinter.Sampler` and `sinter.CompiledSampler` classes
- They can go anywhere a Decoder would go, but they are responsible for
all parts of the sampling instead of only prediction
- Add a new default sampler `perfectionist`, which discards anything
with detection events and predicts the observables are not flipped
- Improved layout of the progress printouts when collect is running
- Sinter decoders can now flag that they want to discard shots by adding
an extra byte to the returned observable data, with 0 meaning keep and
not-0 meaning discard
- Change how `sinter collect` distributes work
- Workers are now distributed as widely as possible, instead of all on
one task
- Workers are now never switched between tasks until their current task
is done
- Add `sinter plot --point_label_func` argument for drawing text next to
data points
- Augment `sinter plot --group_func` to support dictionaries with
special keys controlling precise grouping behaviors
- If group_func returns a dict with a `"color"` key, all items with the
same `"color"` value are drawn with the same color
- If group_func returns a dict with a `"linestyle"` key, all items with
the same `"linestyle"` value are drawn with the same linestyle
- If group_func returns a dict with a `"marker"` key, all items with the
same `"marker"` value are drawn with the same marker
- If group_func returns a dict with a `"label"` key, this forces the
label shown in the legend
- If group_func returns a dict with an `"order"` key, this takes
priority for ordering the legend
- `sinter collect --processes` is no longer required (defaults to
`"auto"`)
- `sinter plot --show` is no longer required (defaults to showing,
unless `--out` is specified, unless `--show` is specified)
- Group some of sinter's code into private subpackages
- Show traditional error bars instead of a filled region for high/low
fit when only one data point is present
- Add `sinter plot --preprocess_stats_func`
- Add `sinter.TaskStats.with_edits`
- Add safety error when adding stats that have equal strong ids but
differing identifying information (json_metadata or decoder)

Some of the sampler design is adapted from @inmzhang's design in
#735

Fixes #774

Fixes #682

Fixes #392

---------

Co-authored-by: Matt McEwen <[email protected]>
  • Loading branch information
Strilanc and m-mcewen committed Sep 10, 2024
1 parent 64cf7e1 commit 7c26e1c
Show file tree
Hide file tree
Showing 65 changed files with 3,000 additions and 1,188 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ jobs:
- run: mv dist/* output/stim
- run: mv glue/cirq/dist/* output/stimcirq
- run: mv glue/sample/dist/* output/sinter
- uses: actions/upload-artifact@v3
- uses: actions/upload-artifact@v4.4.0
with:
name: dist
path: |
Expand All @@ -185,7 +185,7 @@ jobs:
if: github.ref == 'refs/heads/main'
runs-on: ubuntu-latest
steps:
- uses: actions/download-artifact@v2
- uses: actions/download-artifact@v4.1.7
with:
name: dist
path: dist
Expand Down
24 changes: 23 additions & 1 deletion dev/gen_sinter_api_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,25 @@ def main():
```
'''.strip())

replace_rules = []
for package in ['stim', 'sinter']:
p = __import__(package)
for name in dir(p):
x = getattr(p, name)
if isinstance(x, type) and 'class' in str(x):
desired_name = f'{package}.{name}'
if '._' in str(x):
bad_name = str(x).split("'")[1]
replace_rules.append((bad_name, desired_name))
lonely_name = desired_name.split(".")[-1]
for q in ['"', "'"]:
replace_rules.append(('ForwardRef(' + q + lonely_name + q + ')', desired_name))
replace_rules.append(('ForwardRef(' + q + desired_name + q + ')', desired_name))
replace_rules.append((q + desired_name + q, desired_name))
replace_rules.append((q + lonely_name + q, desired_name))
replace_rules.append(('ForwardRef(' + desired_name + ')', desired_name))
replace_rules.append(('ForwardRef(' + lonely_name + ')', desired_name))

for obj in objects:
print()
print(f'<a name="{obj.full_name}"></a>')
Expand All @@ -58,7 +77,10 @@ def main():
print(f'# (in class {".".join(obj.full_name.split(".")[:-1])})')
else:
print(f'# (at top-level in the sinter module)')
print('\n'.join(obj.lines))
for line in obj.lines:
for a, b in replace_rules:
line = line.replace(a, b)
print(line)
print("```")


Expand Down
13 changes: 1 addition & 12 deletions dev/util_gen_stub_file.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import dataclasses
import sys
import types
from typing import Any
from typing import Optional, Iterator, List
Expand All @@ -9,6 +8,7 @@

keep = {
"__add__",
"__radd__",
"__eq__",
"__call__",
"__ge__",
Expand Down Expand Up @@ -224,17 +224,6 @@ def print_doc(*, full_name: str, parent: object, obj: object, level: int) -> Opt
text += '@abc.abstractmethod\n'
sig_name = f'{term_name}{inspect.signature(obj)}'
text += "\n".join(splay_signature(f"def {sig_name}:"))
text = text.replace('''ForwardRef('sinter.TaskStats')''', 'sinter.TaskStats')
text = text.replace('''ForwardRef('sinter.Task')''', 'sinter.Task')
text = text.replace('''ForwardRef('sinter.Progress')''', 'sinter.Progress')
text = text.replace('''ForwardRef('sinter.Decoder')''', 'sinter.Decoder')
text = text.replace("'AnonTaskStats'", "sinter.AnonTaskStats")
text = text.replace('sinter._decoding_decoder_class.CompiledDecoder', 'sinter.CompiledDecoder')
text = text.replace("'AnonTaskStats'", "sinter.AnonTaskStats")
text = text.replace("'stim.Circuit'", "stim.Circuit")
text = text.replace("'stim.DetectorErrorModel'", "stim.DetectorErrorModel")
text = text.replace("'sinter.CollectionOptions'", "sinter.CollectionOptions")
text = text.replace("'sinter.Fit'", 'sinter.Fit')

# Replace default value lambdas with their source.
if 'lambda' in str(text):
Expand Down
1 change: 1 addition & 0 deletions doc/python_api_reference_vDev.md
Original file line number Diff line number Diff line change
Expand Up @@ -1610,6 +1610,7 @@ def diagram(
*,
tick: Union[None, int, range] = None,
filter_coords: Iterable[Union[Iterable[float], stim.DemTarget]] = ((),),
rows: int | None = None,
) -> 'stim._DiagramHelper':
"""Returns a diagram of the circuit, from a variety of options.
Expand Down
Loading

0 comments on commit 7c26e1c

Please sign in to comment.