From ef89169bfda88226d23359bcf29519efadb4c627 Mon Sep 17 00:00:00 2001 From: Greyson <38485117+gkpotter@users.noreply.github.com> Date: Fri, 21 Jun 2024 09:38:52 +0200 Subject: [PATCH] Changed plot y-axis labels to more accurately express the values. (#34) I think it would be a good idea to update the y-axis labels from "Pr" to something like "Sqrt(Pr)" to reflect the actual values. As of now, it is confusing to see "Pr" for a set of values which do not add up to 1. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../qudits/visualisation/plot_information.py | 68 ++++++++++--------- 1 file changed, 37 insertions(+), 31 deletions(-) diff --git a/src/mqt/qudits/visualisation/plot_information.py b/src/mqt/qudits/visualisation/plot_information.py index 7a8bce9..d0c6cb0 100644 --- a/src/mqt/qudits/visualisation/plot_information.py +++ b/src/mqt/qudits/visualisation/plot_information.py @@ -11,25 +11,44 @@ class HistogramWithErrors: - def __init__(self, labels, counts, errors, title="Simulation") -> None: + def __init__(self, labels, counts, errors, title="", xlabel="Labels", ylabel="Counts") -> None: self.labels = labels self.counts = counts self.errors = errors + self.title = title + self.xlabel = xlabel + self.ylabel = ylabel def generate_histogram(self) -> None: - plt.bar(self.labels, self.counts, yerr=self.errors, capsize=5, color="b", alpha=0.7, align="center") - plt.xlabel("States") - plt.ylabel("Pr") + plt.bar( + self.labels, + self.counts, + yerr=self.errors, + capsize=5, + color="b", + alpha=0.7, + align="center", + ) + plt.xlabel(self.xlabel) + plt.ylabel(self.ylabel) plt.title(self.title) plt.xticks(rotation=45, ha="right") plt.tight_layout() plt.show() def save_to_png(self, filename) -> None: - plt.bar(self.labels, self.counts, yerr=self.errors, capsize=5, color="b", alpha=0.7, align="center") - plt.xlabel("States") - plt.ylabel("Pr") + plt.bar( + self.labels, + self.counts, + yerr=self.errors, + capsize=5, + color="b", + alpha=0.7, + align="center", + ) + plt.xlabel(self.xlabel) + plt.ylabel(self.ylabel) plt.title(self.title) plt.xticks(rotation=45, ha="right") plt.tight_layout() @@ -58,34 +77,21 @@ def state_labels(circuit): return string_states -def plot_state(result: np.ndarray, circuit: QuantumCircuit, errors=None) -> None: - result = np.squeeze(result).tolist() - if errors is None: - errors = len(result) * [0] +def plot_state(state_vector: np.ndarray, circuit: QuantumCircuit, errors=None) -> None: + labels = state_labels(circuit) - string_states = state_labels(circuit) + state_vector_list = np.squeeze(state_vector).tolist() + counts = [abs(coeff) for coeff in state_vector_list] - result = [abs(coeff) for coeff in result] - h_plotter = HistogramWithErrors(string_states, result, errors, title="Simulation") + h_plotter = HistogramWithErrors(labels, counts, errors, title="Simulation", xlabel="States", ylabel="Sqrt(Pr)") h_plotter.generate_histogram() -def plot_counts(result, circuit: QuantumCircuit) -> None: - custom_labels = state_labels(circuit) - - # Count the frequency of each outcome - counts = {label: result.count(i) for i, label in enumerate(custom_labels)} - - # Create a bar plot with custom labels - plt.bar(custom_labels, counts.values()) +def plot_counts(measurements, circuit: QuantumCircuit) -> None: + labels = state_labels(circuit) + counts = [measurements.count(i) for i in range(len(labels))] - # Add labels and title - plt.xlabel("States") - plt.ylabel("Counts") - plt.title("Simulation") + errors = len(labels) * [0] - plt.xticks(rotation=45, ha="right") - plt.tight_layout() - plt.show() - - return counts + h_plotter = HistogramWithErrors(labels, counts, errors, title="Simulation", xlabel="States", ylabel="Counts") + h_plotter.generate_histogram()