Skip to content

Commit

Permalink
resolving comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
weinbe58 committed Oct 16, 2023
1 parent 8d491ff commit f3db67d
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 51 deletions.
10 changes: 5 additions & 5 deletions src/bloqade/codegen/hardware/piecewise_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def visit_linear(self, ast: waveform.Linear) -> Tuple[List[Decimal], List[Decima
if start != stop:
raise ValueError(
"Failed to compile Waveform to piecewise constant, "
"found non-constant Linear piecce."
"found non-constant Linear piece."
)

self.append_timeseries(start, duration)
Expand All @@ -109,16 +109,16 @@ def visit_constant(
self.append_timeseries(value, duration)

def visit_poly(self, ast: waveform.Poly) -> Tuple[List[Decimal], List[Decimal]]:
order = len(ast.coeffs)
order = len(ast.coeffs) - 1
duration = ast.duration(**self.assignments)

if order == 0:
if len(ast.coeffs) == 0:
value = Decimal(0)

elif order == 1:
elif len(ast.coeffs) == 1:
value = ast.coeffs[0](**self.assignments)

elif order == 2:
elif len(ast.coeffs) == 2:
start = ast.coeffs[0](**self.assignments)
stop = start + ast.coeffs[1](**self.assignments) * duration

Expand Down
23 changes: 11 additions & 12 deletions src/bloqade/codegen/hardware/piecewise_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,27 +81,27 @@ def append(self, other: "PiecewiseLinear"):
)


def check_continiuity(left, right):
if left != right:
diff = abs(left - right)
raise ValueError(
f"discontinuity with a jump of {diff} found when compiling to "
"piecewise linear."
)


class PiecewiseLinearCodeGen(WaveformVisitor):
def __init__(self, assignments: Dict[str, Union[numbers.Real, List[numbers.Real]]]):
self.assignments = assignments
self.times = []
self.values = []

@staticmethod
def check_continiuity(left, right):
if left != right:
diff = abs(left - right)
raise ValueError(
f"discontinuity with a jump of {diff} found when compiling to "
"piecewise linear."
)

def append_timeseries(self, start, stop, duration):
if len(self.times) == 0:
self.times = [Decimal(0), duration]
self.values = [start, stop]
else:
check_continiuity(self.values[-1], start)
self.check_continiuity(self.values[-1], start)

self.times.append(duration + self.times[-1])
self.values.append(stop)
Expand Down Expand Up @@ -130,7 +130,6 @@ def visit_poly(self, ast: waveform.Poly) -> Tuple[List[Decimal], List[Decimal]]:
start = ast.coeffs[0](**self.assignments)
stop = start
elif len(ast.coeffs) == 2:
duration = ast.duration(**self.assignments)
start = ast.coeffs[0](**self.assignments)
stop = start + ast.coeffs[1](**self.assignments) * duration
else:
Expand Down Expand Up @@ -196,7 +195,7 @@ def visit_append(self, ast: waveform.Append) -> Tuple[List[Decimal], List[Decima
if new_pwl.times[-1] == Decimal(0):
continue

check_continiuity(pwl.values[-1], new_pwl.values[0])
self.check_continiuity(pwl.values[-1], new_pwl.values[0])
pwl = pwl.append(new_pwl)

self.times = pwl.times
Expand Down
79 changes: 45 additions & 34 deletions src/bloqade/codegen/hardware/quera.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,22 +73,29 @@ def append(self, other: "AHSCodegenResult") -> "AHSCodegenResult":
self.local_detuning.append(other.local_detuning),
)

@staticmethod
def convert_position_units(position):
return tuple(coordinate * Decimal("1e-6") for coordinate in position)

@staticmethod
def convert_time_units(time):
return Decimal("1e-6") * time

@staticmethod
def convert_energy_units(energy):
return Decimal("1e6") * energy

@cached_property
def braket_task_ir(self) -> BraketTaskSpecification:
import braket.ir.ahs as ir

def convert_time_units(time):
return Decimal("1e-6") * time

def convert_energy_units(energy):
return Decimal("1e6") * energy

return BraketTaskSpecification(
nshots=self.nshots,
program=ir.Program(
setup=ir.Setup(
ahs_register=ir.AtomArrangement(
sites=self.sites, filling=self.filling
sites=list(map(self.convert_position_units, self.sites)),
filling=self.filling,
)
),
hamiltonian=ir.Hamiltonian(
Expand All @@ -98,13 +105,13 @@ def convert_energy_units(energy):
time_series=ir.TimeSeries(
times=list(
map(
convert_time_units,
self.convert_time_units,
self.global_rabi_amplitude.times,
)
),
values=list(
map(
convert_energy_units,
self.convert_energy_units,
self.global_rabi_amplitude.values,
)
),
Expand All @@ -115,7 +122,7 @@ def convert_energy_units(energy):
time_series=ir.TimeSeries(
times=list(
map(
convert_time_units,
self.convert_time_units,
self.global_rabi_phase.times,
)
),
Expand All @@ -127,13 +134,13 @@ def convert_energy_units(energy):
time_series=ir.TimeSeries(
times=list(
map(
convert_time_units,
self.convert_time_units,
self.global_detuning.times,
)
),
values=list(
map(
convert_energy_units,
self.convert_energy_units,
self.global_detuning.values,
)
),
Expand All @@ -151,13 +158,13 @@ def convert_energy_units(energy):
time_series=ir.TimeSeries(
times=list(
map(
convert_time_units,
self.convert_time_units,
self.local_detuning.times,
)
),
values=list(
map(
convert_energy_units,
self.convert_energy_units,
self.local_detuning.values,
)
),
Expand All @@ -175,27 +182,25 @@ def convert_energy_units(energy):
def quera_task_ir(self) -> QuEraTaskSpecification:
import bloqade.submission.ir.task_specification as task_spec

def convert_time_units(time):
return Decimal("1e-6") * time

def convert_energy_units(energy):
return Decimal("1e6") * energy

return task_spec.QuEraTaskSpecification(
nshots=self.nshots,
lattice=task_spec.Lattice(sites=self.sites, filling=self.filling),
lattice=task_spec.Lattice(
sites=list(map(self.convert_position_units, self.sites)),
filling=self.filling,
),
effective_hamiltonian=task_spec.EffectiveHamiltonian(
rydberg=task_spec.RydbergHamiltonian(
rabi_frequency_amplitude=task_spec.RabiFrequencyAmplitude(
global_=task_spec.GlobalField(
times=list(
map(
convert_time_units, self.global_rabi_amplitude.times
self.convert_time_units,
self.global_rabi_amplitude.times,
)
),
values=list(
map(
convert_energy_units,
self.convert_energy_units,
self.global_rabi_amplitude.values,
)
),
Expand All @@ -204,30 +209,40 @@ def convert_energy_units(energy):
rabi_frequency_phase=task_spec.RabiFrequencyPhase(
global_=task_spec.GlobalField(
times=list(
map(convert_time_units, self.global_rabi_phase.times)
map(
self.convert_time_units,
self.global_rabi_phase.times,
)
),
values=self.global_rabi_phase.values,
)
),
detuning=task_spec.Detuning(
global_=task_spec.GlobalField(
times=list(
map(convert_time_units, self.global_detuning.times)
map(self.convert_time_units, self.global_detuning.times)
),
values=list(
map(convert_energy_units, self.global_detuning.values)
map(
self.convert_energy_units,
self.global_detuning.values,
)
),
),
local=(
None
if self.lattice_site_coefficients is None
else task_spec.LocalField(
times=list(
map(convert_time_units, self.local_detuning.times)
map(
self.convert_time_units,
self.local_detuning.times,
)
),
values=list(
map(
convert_energy_units, self.local_detuning.values
self.convert_energy_units,
self.local_detuning.values,
)
),
lattice_site_coefficients=self.lattice_site_coefficients,
Expand Down Expand Up @@ -303,10 +318,6 @@ def fix_up_missing_fields(self) -> None:
if self.local_detuning is None:
pass

@staticmethod
def convert_position_to_SI_units(position: Tuple[Decimal]):
return tuple(coordinate * Decimal("1e-6") for coordinate in position)

def post_visit_spatial_modulation(self, lattice_site_coefficients):
self.lattice_site_coefficients = []
if self.parallel_decoder:
Expand Down Expand Up @@ -495,7 +506,7 @@ def visit_register(self, ast: AtomArrangement):

for location_info in ast.enumerate():
site = tuple(ele(**self.assignments) for ele in location_info.position)
self.sites.append(AHSCodegen.convert_position_to_SI_units(site))
self.sites.append(site)
self.filling.append(location_info.filling.value)

self.n_atoms = len(self.sites)
Expand Down Expand Up @@ -573,7 +584,7 @@ def visit_parallel_register(self, ast: ParallelRegister) -> Any:
for cluster_location_index, (location, filled) in enumerate(
zip(new_register_locations[:], register_filling)
):
site = AHSCodegen.convert_position_to_SI_units(tuple(location))
site = tuple(location)
sites.append(site)
filling.append(filled)

Expand Down

0 comments on commit f3db67d

Please sign in to comment.