From 56032c115311c8a5e7f438c76cc4c85d307dd88b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?YonV1943=20=E6=9B=BE=E4=BC=8A=E8=A8=80?= <37322666+Yonv1943@users.noreply.github.com> Date: Thu, 28 Sep 2023 16:55:00 +0800 Subject: [PATCH] stable version --- helloworld/maxcut/graph_max_cut_mcpg.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/helloworld/maxcut/graph_max_cut_mcpg.py b/helloworld/maxcut/graph_max_cut_mcpg.py index 0d2abdb..84fc12a 100644 --- a/helloworld/maxcut/graph_max_cut_mcpg.py +++ b/helloworld/maxcut/graph_max_cut_mcpg.py @@ -58,7 +58,6 @@ def sampler_func(data, simulator, xs_sample, # local search for cnt in range(num_ls): - for node_index in range(0, num_nodes): node0_id = data.sorted_degree_nodes[node_index] node1_ids = data.neighbors[node0_id] @@ -72,7 +71,7 @@ def sampler_func(data, simulator, xs_sample, expected_cut_reshape = expected_cut.reshape((-1, total_mcmc_num)) index = torch.argmin(expected_cut_reshape, dim=0) - index = torch.arange(total_mcmc_num) + index * total_mcmc_num + index = torch.arange(total_mcmc_num, device=device) + index * total_mcmc_num max_cut = expected_cut[index] temp_max = (num_edges - max_cut) / 2 # todo @@ -223,16 +222,16 @@ def run(): repeat_times = 128 num_ls = 8 - reset_epoch_num = 128 - total_mcmc_num = 512 - path = 'data/gset_14.txt' + # reset_epoch_num = 128 + # total_mcmc_num = 512 + # path = 'data/gset_14.txt' # path = 'data/gset_15.txt' # path = 'data/gset_49.txt' # path = 'data/gset_50.txt' - # reset_epoch_num = 192 - # total_mcmc_num = 224 - # path = 'data/gset_22.txt' + reset_epoch_num = 192 + total_mcmc_num = 224 + path = 'data/gset_22.txt' # reset_epoch_num = 128 # total_mcmc_num = 256