diff --git a/src/bloqade/emulate/ir/state_vector.py b/src/bloqade/emulate/ir/state_vector.py index 8e465cdf8..3fdc1e93f 100644 --- a/src/bloqade/emulate/ir/state_vector.py +++ b/src/bloqade/emulate/ir/state_vector.py @@ -163,8 +163,7 @@ def _error_check(solver_name: str, status_code: int): elif solver_name in ["dop853", "dopri5"]: AnalogGate._error_check_dop(status_code) - @beartype - def apply( + def _apply( self, state: StateArray, solver_name: str = "dop853", @@ -172,7 +171,6 @@ def apply( rtol: float = 1e-14, nsteps: int = 2_147_483_647, times: Union[List[float], RealArray] = [], - interaction_picture: bool = False, ): if state is None: state = self.hamiltonian.space.zero_state() @@ -184,10 +182,44 @@ def apply( state = np.asarray(state).astype(np.complex128, copy=False) - if interaction_picture: - solver = ode(self.hamiltonian._ode_real_kernel_int) - else: - solver = ode(self.hamiltonian._ode_real_kernel) + solver = ode(self.hamiltonian._ode_real_kernel) + + solver.set_initial_value(state.view(np.float64)) + solver.set_integrator(solver_name, atol=atol, rtol=rtol, nsteps=nsteps) + + if any(time >= duration or time < 0.0 for time in times): + raise ValueError("Times must be between 0 and duration.") + + times = [*times, duration] + + for time in times: + if time == 0.0: + yield state + continue + solver.integrate(time) + AnalogGate._error_check(solver_name, solver.get_return_code()) + yield solver.y.view(np.complex128) + + def _apply_interation_picture( + self, + state: StateArray, + solver_name: str = "dop853", + atol: float = 1e-7, + rtol: float = 1e-14, + nsteps: int = 2_147_483_647, + times: Union[List[float], RealArray] = [], + ): + if state is None: + state = self.hamiltonian.space.zero_state() + + if solver_name not in AnalogGate.SUPPORTED_SOLVERS: + raise ValueError(f"'{solver_name}' not supported.") + + duration = self.hamiltonian.emulator_ir.duration + + state = np.asarray(state).astype(np.complex128, copy=False) + + solver = ode(self.hamiltonian._ode_real_kernel_int) solver.set_initial_value(state.view(np.float64)) solver.set_integrator(solver_name, atol=atol, rtol=rtol, nsteps=nsteps) @@ -198,11 +230,44 @@ def apply( times = [*times, duration] for time in times: + if time == 0.0: + yield state + continue solver.integrate(time) AnalogGate._error_check(solver_name, solver.get_return_code()) u = np.exp(-1j * time * self.hamiltonian.rydberg) yield u * solver.y.view(np.complex128) + @beartype + def apply( + self, + state: StateArray, + solver_name: str = "dop853", + atol: float = 1e-7, + rtol: float = 1e-14, + nsteps: int = 2_147_483_647, + times: Union[List[float], RealArray] = [], + interaction_picture: bool = False, + ): + if interaction_picture: + return self._apply_interation_picture( + state, + solver_name=solver_name, + atol=atol, + rtol=rtol, + nsteps=nsteps, + times=times, + ) + else: + return self._apply( + state, + solver_name=solver_name, + atol=atol, + rtol=rtol, + nsteps=nsteps, + times=times, + ) + @beartype def run( self, @@ -224,9 +289,11 @@ def run( nsteps=nsteps, interaction_picture=interaction_picture, ) + state = self.hamiltonian.space.zero_state() (result,) = self.apply(state, **options) result /= np.linalg.norm(result) + return self.hamiltonian.space.sample_state_vector( result, shots, project_hyperfine=project_hyperfine ) diff --git a/src/bloqade/task/bloqade.py b/src/bloqade/task/bloqade.py index dcc4b604e..db0e9cfa5 100644 --- a/src/bloqade/task/bloqade.py +++ b/src/bloqade/task/bloqade.py @@ -63,7 +63,8 @@ def run( shot_result = QuEraShotResult( shot_status=QuEraShotStatusCode.Completed, pre_sequence=[1 for _ in shot], - post_sequence=list(shot), + # flip the bits so that 1 = ground state and 0 = excited state + post_sequence=list(1 - shot), ) shot_outputs.append(shot_result) diff --git a/tests/test_python_emulator.py b/tests/test_python_emulator.py index c2fc14273..3f01cb5ac 100644 --- a/tests/test_python_emulator.py +++ b/tests/test_python_emulator.py @@ -15,7 +15,7 @@ def test_integration_1(): .assign(ramp_time=3.0) .batch_assign(r=np.linspace(4, 10, 11).tolist()) .bloqade.python() - .run(10000, cache_matrices=True, blockade_radius=6.0) + .run(10000, cache_matrices=True, blockade_radius=6.0, interaction_picture=True) .report() .bitstrings() )