Skip to content

Commit

Permalink
Update train_toric_L8.py
Browse files Browse the repository at this point in the history
correct some comments
  • Loading branch information
sisimiao authored Apr 6, 2024
1 parent f31178f commit f43b4dc
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions train_toric_L8.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,15 +783,14 @@ def addErrorGivenWeight(n: int, w: int, batch_size: int = 1):
lr = 1
# training for fixed epsilon_0
ep0 = 0.1
# training error sampled from ep1, ep1+sep, ep1+2*sep, ... (num_points values in total)
ep1 = 0.03
num_points = 6
sep = 0.01
if m == 3*n:
    # NOTE(review): presumably a larger/different code (m == 3n checks) needs a
    # shifted error-rate schedule — confirm against the code-construction part of the file.
    ep0 = 0.37
    ep1 += 0.06
# train on errors of weight ranging from r1 to r2
r1 = 1
r2 = 5

# number of updates (mini-batch gradient steps)
n_batches = 100
# number of error patterns in each mini batch
Expand All @@ -814,12 +813,13 @@ def addErrorGivenWeight(n: int, w: int, batch_size: int = 1):
# plt.show()


# Plain SGD over the decoder's three learnable parameter groups
# (LLR weights, variable-node weights, check-node weights), all sharing
# the same learning rate `lr`.
optimizer = torch.optim.SGD([
    {'params': decoder.weights_llr, 'lr': lr},
    {'params': decoder.weights_vn, 'lr': lr},
    {'params': decoder.weights_cn, 'lr': lr}
])
# could also use Adam, not making too much difference
# optimizer = torch.optim.Adam(parameters, lr=lr)
# Linearly decay the learning rate from 1.0*lr down to 0.1*lr over 1200 scheduler steps.
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=1200)

Expand All @@ -835,12 +835,12 @@ def addErrorGivenWeight(n: int, w: int, batch_size: int = 1):

cpp_executable = './sim_FER'
cpp_parameters = ['-d','128','2',str(m), '25', '1', '-i',str(ep0),'-r','0.15','0.015','0.015']
# pre-training stage, basically only the parameters for the first iteration is trained
# training stage
loss = torch.Tensor()

loss_pre_train = training_loop(decoder, optimizer, ep1, sep,num_points, ep0, n_batches, path, scheduler=scheduler)
loss = torch.cat((loss, loss_pre_train), dim=0)
plot_loss(loss, path)
plot_loss(loss, path) #its ok if it doesn't converge to 0
subprocess.call([cpp_executable] + cpp_parameters)


Expand Down

0 comments on commit f43b4dc

Please sign in to comment.