From 4682043f66f68b388f40aa3d9037df6719ebdaf0 Mon Sep 17 00:00:00 2001
From: Riley Murray
Date: Mon, 24 Jun 2024 17:45:25 -0400
Subject: [PATCH 1/4] tweak stateless_data

---
 pygsti/modelmembers/operations/fulltpop.py | 16 ++++++++-------
 pygsti/modelmembers/povms/tppovm.py        | 23 +++++++++++-----------
 pygsti/modelmembers/states/tpstate.py      | 11 ++++++-----
 3 files changed, 26 insertions(+), 24 deletions(-)

diff --git a/pygsti/modelmembers/operations/fulltpop.py b/pygsti/modelmembers/operations/fulltpop.py
index 16866b893..f77307ad7 100644
--- a/pygsti/modelmembers/operations/fulltpop.py
+++ b/pygsti/modelmembers/operations/fulltpop.py
@@ -164,16 +164,18 @@ def from_vector(self, v, close=False, dirty_value=True):
         self._ptr_has_changed()  # because _rep.base == _ptr (same memory)
         self.dirty = dirty_value
 
-    def stateless_data(self) -> Tuple[int]:
-        return (self.dim,)
-
-    @staticmethod
-    def torch_base(sd: Tuple[int], t_param: _torch.Tensor) -> _torch.Tensor:
-        dim = sd[0]
+    def stateless_data(self) -> Tuple[int, _torch.Tensor]:
+        dim = self.dim
         t_const = _torch.zeros(size=(1, dim), dtype=_torch.double)
         t_const[0,0] = 1.0
-        t_param_mat = t_param.reshape((dim - 1, dim))
+        return (dim, t_const)
+
+    @staticmethod
+    def torch_base(sd: Tuple[int, _torch.Tensor], t_param: _torch.Tensor) -> _torch.Tensor:
+        dim, t_const = sd
+        t_param_mat = t_param.view(dim - 1, dim)
         t = _torch.row_stack((t_const, t_param_mat))
+        # t_const comes pre-built from stateless_data, so nothing needs to be cached here.
         return t
 
 
diff --git a/pygsti/modelmembers/povms/tppovm.py b/pygsti/modelmembers/povms/tppovm.py
index 80753385f..1183f5e3e 100644
--- a/pygsti/modelmembers/povms/tppovm.py
+++ b/pygsti/modelmembers/povms/tppovm.py
@@ -102,29 +102,28 @@ def to_vector(self):
         vec = _np.concatenate(effect_vecs)
         return vec
 
-    def stateless_data(self) -> Tuple[int, _np.ndarray]:
+    def stateless_data(self) -> Tuple[int, _torch.Tensor, int]:
         num_effects = len(self)
         complement_effect = self[self.complement_label]
         identity = complement_effect.identity.to_vector()
-        return (num_effects, identity)
-
-    @staticmethod
-    def torch_base(sd: Tuple[int, _np.ndarray], t_param: _torch.Tensor) -> _torch.Tensor:
-        num_effects, identity = sd
+        identity = identity.reshape((1, -1))  # make into a row vector
+        t_identity = _torch.from_numpy(identity)
+
         dim = identity.size
-
-        first_basis_vec = _np.zeros(dim)
-        first_basis_vec[0] = dim ** 0.25
+        first_basis_vec = _np.zeros((1,dim))
+        first_basis_vec[0,0] = dim ** 0.25
         TOL = 1e-15 * _np.sqrt(dim)
         if _np.linalg.norm(first_basis_vec - identity) > TOL:
             # Don't error out. The documentation for the class
             # clearly indicates that the meaning of "identity"
             # can be nonstandard.
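             # (For reference: in pyGSTi's normalized Pauli-product basis the
             # identity operator's first expansion coefficient is dim ** 0.25,
             # i.e. sqrt(d) for a d-dimensional Hilbert space with dim = d ** 2,
             # which is exactly what this check tests for.)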
             warnings.warn('Unexpected normalization!')
+        return (num_effects, t_identity, dim)
 
-        identity = identity.reshape((1, -1))  # make into a row vector
-        t_identity = _torch.from_numpy(identity)
-        t_param_mat = t_param.reshape((num_effects - 1, dim))
+    @staticmethod
+    def torch_base(sd: Tuple[int, _torch.Tensor, int], t_param: _torch.Tensor) -> _torch.Tensor:
+        num_effects, t_identity, dim = sd
+        t_param_mat = t_param.view(num_effects - 1, dim)
         t_func = t_identity - t_param_mat.sum(axis=0, keepdim=True)
         t = _torch.row_stack((t_param_mat, t_func))
         return t
diff --git a/pygsti/modelmembers/states/tpstate.py b/pygsti/modelmembers/states/tpstate.py
index 659d6da24..c74c49d78 100644
--- a/pygsti/modelmembers/states/tpstate.py
+++ b/pygsti/modelmembers/states/tpstate.py
@@ -166,13 +166,14 @@ def from_vector(self, v, close=False, dirty_value=True):
         self._ptr_has_changed()
         self.dirty = dirty_value
 
-    def stateless_data(self) -> Tuple[int]:
-        return (self.dim,)
+    def stateless_data(self) -> Tuple[_torch.Tensor]:
+        dim = self.dim
+        t_const = (dim ** -0.25) * _torch.ones(1, dtype=_torch.double)
+        return (t_const,)
 
     @staticmethod
-    def torch_base(sd: Tuple[int], t_param: _torch.Tensor) -> _torch.Tensor:
-        dim = sd[0]
-        t_const = (dim ** -0.25) * _torch.ones(1, dtype=_torch.double)
+    def torch_base(sd: Tuple[_torch.Tensor], t_param: _torch.Tensor) -> _torch.Tensor:
+        t_const = sd[0]
         t = _torch.concat((t_const, t_param))
         return t
 

From b0c630a2091a4da611ed6ae3f9322536be3c77d6 Mon Sep 17 00:00:00 2001
From: Riley Murray
Date: Mon, 24 Jun 2024 17:47:55 -0400
Subject: [PATCH 2/4] logging

---
 pygsti/forwardsims/torchfwdsim.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/pygsti/forwardsims/torchfwdsim.py b/pygsti/forwardsims/torchfwdsim.py
index 1285e51de..117d55177 100644
--- a/pygsti/forwardsims/torchfwdsim.py
+++ b/pygsti/forwardsims/torchfwdsim.py
@@ -159,6 +159,17 @@ def get_torch_bases(self, free_params: Tuple[torch.Tensor]) -> Dict[Label, torch
         fp in free_params. This can be done by calling fp._requires_grad(True) before
         calling this function.
         """
+        # The closest analog to this function in tgst is the first couple of lines in
+        # tgst.gst.MachineModel.circuit_outcome_probs(...).
+        # Those lines just assign values à la new_machine.params[i][:] = fp[:].
+        #
+        # The variables new_machine.params[i] are just references to Tensors
+        # that are attached to tgst.abstractions objects (Gate, Measurement, State).
+        #
+        # Calling abstr.rep_array for a given abstraction performs a computation on
+        # its attached Tensor, and that computation is roughly analogous to
+        # torchable.torch_base(...).
+        #
         assert len(free_params) == len(self.param_metadata)
         # ^ A sanity check that we're being called with the correct number of arguments.
         torch_bases = dict()
@@ -248,8 +259,10 @@ def _bulk_fill_dprobs(self, array_to_fill, layout, pr_array_to_fill) -> None:
         if slm.default_to_reverse_ad:
             # Then slm.circuit_probs_from_free_params will automatically construct the
             # torch_base dict to support reverse-mode AD.
+            print('USING REVERSE-MODE AD')
             J_func = torch.func.jacrev(slm.circuit_probs_from_free_params, argnums=argnums)
         else:
+            print('USING FORWARD-MODE AD')
             # Then slm.circuit_probs_from_free_params will automatically skip the extra
             # steps needed for torch_base to support reverse-mode AD.
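             # (jacfwd is a sensible default because this Jacobian is usually
             # "tall": there are far more circuit outcome probabilities than
             # free parameters, and forward-mode costs one JVP per parameter
             # column while reverse-mode costs one VJP per output row.)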
             J_func = torch.func.jacfwd(slm.circuit_probs_from_free_params, argnums=argnums)
@@ -258,7 +271,13 @@ def _bulk_fill_dprobs(self, array_to_fill, layout, pr_array_to_fill) -> None:
         # have a need to override the default in the future then we'd need to override
         # the ForwardSimulator function(s) that call self._bulk_fill_dprobs(...).
 
+        import time
+        print('Calling J_func at current free_params')
+        tic = time.time()
         J_val = J_func(*free_params)
+        toc = time.time()
+        print()
+        print(f'Done! --> {toc - tic} seconds elapsed')
         J_val = torch.column_stack(J_val)
         array_to_fill[:] = J_val.cpu().detach().numpy()
         return

From 62f37ec58e7845f93fd7bda72fb8eef6692ad1c0 Mon Sep 17 00:00:00 2001
From: Riley Murray
Date: Tue, 16 Jul 2024 12:03:16 -0400
Subject: [PATCH 3/4] add trivial __getstate__ and __setstate__ needed for serialization

---
 pygsti/modelmembers/povms/conjugatedeffect.py | 6 ++++++
 pygsti/modelmembers/states/densestate.py      | 6 ++++++
 2 files changed, 12 insertions(+)

diff --git a/pygsti/modelmembers/povms/conjugatedeffect.py b/pygsti/modelmembers/povms/conjugatedeffect.py
index 5af305a44..3b0b5ddec 100644
--- a/pygsti/modelmembers/povms/conjugatedeffect.py
+++ b/pygsti/modelmembers/povms/conjugatedeffect.py
@@ -80,6 +80,12 @@ def __setitem__(self, key, val):
         ret = self.columnvec.__setitem__(key, val)
         self._ptr_has_changed()
         return ret
+
+    def __getstate__(self):
+        return self.__dict__
+
+    def __setstate__(self, d):
+        self.__dict__.update(d)
 
     def __getattr__(self, attr):
         #use __dict__ so no chance for recursive __getattr__
diff --git a/pygsti/modelmembers/states/densestate.py b/pygsti/modelmembers/states/densestate.py
index 3c7df543f..2d9b17fc0 100644
--- a/pygsti/modelmembers/states/densestate.py
+++ b/pygsti/modelmembers/states/densestate.py
@@ -100,6 +100,12 @@ def __setitem__(self, key, val):
         ret = self.columnvec.__setitem__(key, val)
         self._ptr_has_changed()
         return ret
+
+    def __getstate__(self):
+        return self.__dict__
+
+    def __setstate__(self, d):
+        self.__dict__.update(d)
 
     def __getattr__(self, attr):
         #use __dict__ so no chance for recursive __getattr__

From f3bfb806c5d5401c7b5dce4bac24f9ea5db90dc5 Mon Sep 17 00:00:00 2001
From: Riley Murray
Date: Tue, 16 Jul 2024 12:03:47 -0400
Subject: [PATCH 4/4] remove some logging and profiling code (well, just comment out)

---
 pygsti/forwardsims/torchfwdsim.py | 38 +++++++++++++++++++++++--------
 1 file changed, 28 insertions(+), 10 deletions(-)

diff --git a/pygsti/forwardsims/torchfwdsim.py b/pygsti/forwardsims/torchfwdsim.py
index 117d55177..5079b26e1 100644
--- a/pygsti/forwardsims/torchfwdsim.py
+++ b/pygsti/forwardsims/torchfwdsim.py
@@ -31,6 +31,7 @@
 
 try:
     import torch
+    from torch.profiler import profile, record_function, ProfilerActivity
     TORCH_ENABLED = True
 except ImportError:
     TORCH_ENABLED = False
@@ -89,6 +90,7 @@ def __init__(self, model: ExplicitOpModel, layout: CircuitOutcomeProbabilityArra
         # framed in terms of the "layout._element_indices" dict.
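         # (The len(eind) > 0 assertion added below guards the next(items)
         # call, which would otherwise raise StopIteration on an empty layout.)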
         eind = layout._element_indices
         assert isinstance(eind, dict)
+        assert len(eind) > 0
         items = iter(eind.items())
         k_prev, v_prev = next(items)
         assert k_prev == 0
@@ -213,8 +215,23 @@ def circuit_probs_from_free_params(self, *free_params: Tuple[torch.Tensor], enab
         if enable_backward:
             for fp in free_params:
                 fp._requires_grad(True)
-        torch_bases = self.get_torch_bases(free_params)
-        probs = self.circuit_probs_from_torch_bases(torch_bases)
+
+        torch_bases = dict()
+        for i, val in enumerate(free_params):
+            label, type_handle, stateless_data = self.param_metadata[i]
+            param_t = type_handle.torch_base(stateless_data, val)
+            torch_bases[label] = param_t
+
+        probs = []
+        for c in self.circuits:
+            superket = torch_bases[c.prep_label]
+            superops = [torch_bases[ol] for ol in c.op_labels]
+            povm_mat = torch_bases[c.povm_label]
+            for superop in superops:
+                superket = superop @ superket
+            circuit_probs = povm_mat @ superket
+            probs.append(circuit_probs)
+        probs = torch.concat(probs)
         return probs
 
 
@@ -259,10 +276,10 @@ def _bulk_fill_dprobs(self, array_to_fill, layout, pr_array_to_fill) -> None:
         if slm.default_to_reverse_ad:
             # Then slm.circuit_probs_from_free_params will automatically construct the
             # torch_base dict to support reverse-mode AD.
-            print('USING REVERSE-MODE AD')
+            # print('USING REVERSE-MODE AD')
             J_func = torch.func.jacrev(slm.circuit_probs_from_free_params, argnums=argnums)
         else:
-            print('USING FORWARD-MODE AD')
+            # print('USING FORWARD-MODE AD')
             # Then slm.circuit_probs_from_free_params will automatically skip the extra
             # steps needed for torch_base to support reverse-mode AD.
             J_func = torch.func.jacfwd(slm.circuit_probs_from_free_params, argnums=argnums)
@@ -271,13 +288,14 @@ def _bulk_fill_dprobs(self, array_to_fill, layout, pr_array_to_fill) -> None:
         # have a need to override the default in the future then we'd need to override
         # the ForwardSimulator function(s) that call self._bulk_fill_dprobs(...).
 
-        import time
-        print('Calling J_func at current free_params')
-        tic = time.time()
+        # import time
+        # print('Calling J_func at current free_params')
+        # tic = time.time()
+        # with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as prof:
         J_val = J_func(*free_params)
-        toc = time.time()
-        print()
-        print(f'Done! --> {toc - tic} seconds elapsed')
+        # toc = time.time()
+        # print()
+        # print(f'Done! --> {toc - tic} seconds elapsed')
         J_val = torch.column_stack(J_val)
         array_to_fill[:] = J_val.cpu().detach().numpy()
         return
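
A note on how the pieces above compose. The sketch below is illustrative only:
the toy names (state_base, op_base, povm_base, circuit_probs) are stand-ins
rather than pyGSTi API, and the constants mirror the stateless_data tuples
built in PATCH 1/4. Only the torch.func calls correspond directly to what
_bulk_fill_dprobs does.

    import torch

    dim = 4           # superoperator dimension for one qubit: dim = d**2 with d = 2
    num_effects = 2   # a two-outcome POVM

    # "stateless_data" analogs: parameter-independent constants, built once.
    state_const = (dim ** -0.25) * torch.ones(1, dtype=torch.double)        # cf. TPState
    op_const = torch.zeros(1, dim, dtype=torch.double)                      # cf. FullTPOp
    op_const[0, 0] = 1.0
    povm_identity = (dim ** 0.25) * torch.eye(dim, dtype=torch.double)[:1]  # cf. TPPOVM

    # "torch_base" analogs: rebuild each member's dense tensor from free parameters.
    def state_base(p):
        return torch.concat((state_const, p))                      # fixed first entry

    def op_base(p):
        return torch.row_stack((op_const, p.view(dim - 1, dim)))   # fixed first row

    def povm_base(p):
        p_mat = p.view(num_effects - 1, dim)
        complement = povm_identity - p_mat.sum(axis=0, keepdim=True)
        return torch.row_stack((p_mat, complement))                # effects sum to identity

    def circuit_probs(sp, opp, pp):
        superket = state_base(sp)
        superket = op_base(opp) @ superket
        return povm_base(pp) @ superket

    sp = torch.zeros(dim - 1, dtype=torch.double)
    opp = torch.eye(dim, dtype=torch.double)[1:].reshape(-1)
    pp = torch.zeros((num_effects - 1) * dim, dtype=torch.double)

    # Forward-mode Jacobian, as in the default branch of _bulk_fill_dprobs.
    J = torch.func.jacfwd(circuit_probs, argnums=(0, 1, 2))(sp, opp, pp)
    J = torch.column_stack(J)
    print(J.shape)  # torch.Size([2, 19]): (dim-1) + (dim-1)*dim + (num_effects-1)*dim columns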
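
On why PATCH 3/4 is needed: these classes define __getattr__, and unpickling
probes getattr(obj, '__setstate__', None) on an instance whose __dict__ has
not been populated yet. A minimal reproduction under that assumption; the
class names here are hypothetical, not pyGSTi's:

    import pickle

    class Delegating:
        """Mimics the delegation pattern in conjugatedeffect.py / densestate.py."""

        def __init__(self):
            self.columnvec = [1.0, 0.0]

        def __getattr__(self, attr):
            # use __dict__ so no chance for recursive __getattr__; but this
            # raises KeyError (not AttributeError) when __dict__ is empty
            return getattr(self.__dict__['columnvec'], attr)

    class Fixed(Delegating):
        def __getstate__(self):
            return self.__dict__

        def __setstate__(self, d):
            self.__dict__.update(d)

    try:
        pickle.loads(pickle.dumps(Delegating()))
        print('round trip survived')
    except KeyError:
        # getattr's default only swallows AttributeError, so the KeyError
        # raised by __getattr__ escapes and the round trip fails
        print('round trip failed without __getstate__/__setstate__')

    print(pickle.loads(pickle.dumps(Fixed())).columnvec)  # [1.0, 0.0]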