From 628e415f7fbec9c2616215b6dfdeea6029389d7e Mon Sep 17 00:00:00 2001 From: Johnnie Gray Date: Wed, 10 Jul 2024 15:24:15 -0700 Subject: [PATCH] SimpleUpdateGen: add simple .plot() method --- quimb/tensor/tensor_arbgeom_tebd.py | 43 +++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/quimb/tensor/tensor_arbgeom_tebd.py b/quimb/tensor/tensor_arbgeom_tebd.py index a76773e9..ff9d4b8b 100644 --- a/quimb/tensor/tensor_arbgeom_tebd.py +++ b/quimb/tensor/tensor_arbgeom_tebd.py @@ -685,6 +685,49 @@ def compute_energy(self): **self.compute_energy_opts ) + @default_to_neutral_style + def plot( + self, + zoom="auto", + xscale="symlog", + xscale_linthresh=20, + hlines=() + ): + import numpy as np + import matplotlib.pyplot as plt + from matplotlib.colors import hsv_to_rgb + + fig, ax = plt.subplots() + + xs = np.array(self.its) + ys = np.array(self.energies) + + ax.plot(xs, ys, '.-') + ax.set_xlabel("Iteration") + ax.set_ylabel("Energy") + + if xscale == "symlog": + ax.set_xscale(xscale, linthresh=xscale_linthresh) + ax.axvline(xscale_linthresh, color=(.5, .5, .5), ls="-", lw=0.5) + else: + ax.set_xscale(xscale) + + if hlines: + hlines = dict(hlines) + for i, (label, value) in enumerate(hlines.items()): + color = hsv_to_rgb([(0.1 * i) % 1.0, 0.9, 0.9]) + ax.axhline(value, color=color, ls="--", label=label) + ax.text(1, value, label, color=color, va="bottom", ha="left") + + if zoom is not None: + if zoom == "auto": + zoom = min(50, ys.size // 2) + + iax = ax.inset_axes([0.5, 0.5, 0.5, 0.5]) + iax.plot(xs[-zoom:], ys[-zoom:], ".-") + + return fig, ax + def __repr__(self): s = "<{}(n={}, tau={}, D={})>" return s.format(