Skip to content

Commit

Permalink
Changed plot y-axis labels to more accurately express the values. (#34)
Browse files Browse the repository at this point in the history
I think it would be a good idea to update the y-axis labels from "Pr" to
something like "Sqrt(Pr)" to reflect the actual values.

As of now, it is confusing to see "Pr" for a set of values which do not
add up to 1.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
gkpotter and pre-commit-ci[bot] authored Jun 21, 2024
1 parent bea16b8 commit ef89169
Showing 1 changed file with 37 additions and 31 deletions.
68 changes: 37 additions & 31 deletions src/mqt/qudits/visualisation/plot_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,44 @@


class HistogramWithErrors:
def __init__(self, labels, counts, errors, title="Simulation") -> None:
def __init__(self, labels, counts, errors, title="", xlabel="Labels", ylabel="Counts") -> None:
self.labels = labels
self.counts = counts
self.errors = errors

self.title = title
self.xlabel = xlabel
self.ylabel = ylabel

def generate_histogram(self) -> None:
plt.bar(self.labels, self.counts, yerr=self.errors, capsize=5, color="b", alpha=0.7, align="center")
plt.xlabel("States")
plt.ylabel("Pr")
plt.bar(
self.labels,
self.counts,
yerr=self.errors,
capsize=5,
color="b",
alpha=0.7,
align="center",
)
plt.xlabel(self.xlabel)
plt.ylabel(self.ylabel)
plt.title(self.title)
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()

def save_to_png(self, filename) -> None:
plt.bar(self.labels, self.counts, yerr=self.errors, capsize=5, color="b", alpha=0.7, align="center")
plt.xlabel("States")
plt.ylabel("Pr")
plt.bar(
self.labels,
self.counts,
yerr=self.errors,
capsize=5,
color="b",
alpha=0.7,
align="center",
)
plt.xlabel(self.xlabel)
plt.ylabel(self.ylabel)
plt.title(self.title)
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
Expand Down Expand Up @@ -58,34 +77,21 @@ def state_labels(circuit):
return string_states


def plot_state(result: np.ndarray, circuit: QuantumCircuit, errors=None) -> None:
result = np.squeeze(result).tolist()
if errors is None:
errors = len(result) * [0]
def plot_state(state_vector: np.ndarray, circuit: QuantumCircuit, errors=None) -> None:
labels = state_labels(circuit)

string_states = state_labels(circuit)
state_vector_list = np.squeeze(state_vector).tolist()
counts = [abs(coeff) for coeff in state_vector_list]

result = [abs(coeff) for coeff in result]
h_plotter = HistogramWithErrors(string_states, result, errors, title="Simulation")
h_plotter = HistogramWithErrors(labels, counts, errors, title="Simulation", xlabel="States", ylabel="Sqrt(Pr)")
h_plotter.generate_histogram()


def plot_counts(result, circuit: QuantumCircuit) -> None:
custom_labels = state_labels(circuit)

# Count the frequency of each outcome
counts = {label: result.count(i) for i, label in enumerate(custom_labels)}

# Create a bar plot with custom labels
plt.bar(custom_labels, counts.values())
def plot_counts(measurements, circuit: QuantumCircuit) -> None:
labels = state_labels(circuit)
counts = [measurements.count(i) for i in range(len(labels))]

# Add labels and title
plt.xlabel("States")
plt.ylabel("Counts")
plt.title("Simulation")
errors = len(labels) * [0]

plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()

return counts
h_plotter = HistogramWithErrors(labels, counts, errors, title="Simulation", xlabel="States", ylabel="Counts")
h_plotter.generate_histogram()

0 comments on commit ef89169

Please sign in to comment.