Commit

🐛 fix assign bug issue 93
Yonv1943 committed Apr 1, 2023
1 parent 206fd9b commit b52b05d
Showing 2 changed files with 100 additions and 52 deletions.
18 changes: 9 additions & 9 deletions rlsolver/rlsolver_learn2opt/tensor_train/TNCO_L2O.py
@@ -12,10 +12,8 @@

TEN = th.Tensor

# NodesList = NodesSycamoreN53M12
# NumBanEdge = 0
NodesList = get_nodes_list_of_tensor_train(len_list=100)
NumBanEdge = 100
NodesList, BanEdges = NodesSycamoreN53M12, 0
# NodesList, BanEdges = get_nodes_list_of_tensor_train(len_list=100), 100


def build_mlp(dims: [int], activation: nn = None, if_raw_out: bool = True) -> nn.Sequential:
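The body of build_mlp is collapsed in this diff. As a rough, assumed sketch only (none of it is taken from the commit), a builder with this signature typically stacks Linear layers with an activation between them and, when if_raw_out is True, leaves the output layer raw:

import torch.nn as nn

def build_mlp_sketch(dims, activation=None, if_raw_out=True):  # hypothetical stand-in for build_mlp
    activation = nn.ReLU if activation is None else activation
    layers = []
    for i in range(len(dims) - 1):
        layers += [nn.Linear(dims[i], dims[i + 1]), activation()]
    if if_raw_out:
        layers = layers[:-1]  # drop the activation after the final Linear layer
    return nn.Sequential(*layers)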
@@ -164,11 +162,13 @@ def __init__(self, dim, device):
self.device = device
self.args = ()

self.env = TensorNetworkEnv(nodes_list=NodesList, device=device) # NodesSycamoreN53M12
self.dim = self.env.num_edges - NumBanEdge
print(self.dim) if self.dim != dim else None
env = TensorNetworkEnv(nodes_list=NodesList, device=device)
env.ban_edges = BanEdges
self.env = env
self.dim = env.num_edges - env.ban_edges
print(f"dim {self.dim} = num_edges {env.num_edges} - ban_edges {env.ban_edges}") if self.dim != dim else None

self.obj_model = MLP(inp_dim=self.dim, out_dim=1, dims=(256, 256, 256)).to(device)
self.obj_model = MLP(inp_dim=self.dim, out_dim=1, dims=(512, 256, 256)).to(device)

self.optimizer = th.optim.Adam(self.obj_model.parameters(), lr=1e-4)
self.criterion = nn.MSELoss()
@@ -225,7 +225,7 @@ def random_generate_input_output(self, warm_up_size: int = 1024, if_tqdm: bool =
thetas = th.randn((warm_up_size, self.dim), dtype=th.float32, device=self.device).clamp(-3, +3)
thetas = ((thetas - thetas.mean(dim=1, keepdim=True)) / thetas.std(dim=1, keepdim=True)).clamp(-3, +3)

thetas_iter = thetas.reshape((-1, 1024, self.dim))
thetas_iter = thetas.reshape((-1, 256, self.dim))
if if_tqdm:
from tqdm import tqdm
thetas_iter = tqdm(thetas_iter, ascii=True)
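In the random_generate_input_output hunk above, every warm-up theta row is standardized to zero mean and unit std, clamped to [-3, 3], and then consumed in chunks (now 256 rows per step instead of 1024). A self-contained sketch of that normalization, with illustrative sizes that are not taken from the commit:

import torch as th

warm_up_size, dim = 1024, 16  # illustrative sizes only
thetas = th.randn((warm_up_size, dim), dtype=th.float32).clamp(-3, +3)
# standardize each row, then clamp again, as in the diff above
thetas = ((thetas - thetas.mean(dim=1, keepdim=True)) / thetas.std(dim=1, keepdim=True)).clamp(-3, +3)
thetas_iter = thetas.reshape((-1, 256, dim))  # iterate the warm-up set 256 rows at a time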
134 changes: 91 additions & 43 deletions rlsolver/rlsolver_learn2opt/tensor_train/TNCO_env.py
@@ -197,14 +197,29 @@ def get_edges_ary(nodes_ary: TEN) -> TEN:
edges_ary[nodes_ary == -1] = -1  # -1 means this node is not connected to another node

num_edges = 0
for i, nodes in enumerate(nodes_ary):  # i is the index of the node
'''get nodes_ary'''
# for i, nodes in enumerate(nodes_ary):  # i is the index of the node
# for j, node in enumerate(nodes):  # node is another node connected to node i
# edge_i = edges_ary[i, j]
# if edge_i == -2:
# _j = th.where(nodes_ary[node] == i)
# edges_ary[i, j] = num_edges
# edges_ary[node, _j] = num_edges
# num_edges += 1
'''get nodes_ary and sort the ban edges to large indices'''
for i, nodes in list(enumerate(nodes_ary))[::-1]:  # i is the index of the node
for j, node in enumerate(nodes):  # node is another node connected to node i
edge_i = edges_ary[i, j]
if edge_i == -2:
_j = th.where(nodes_ary[node] == i)
nodes_ary_node: TEN = nodes_ary[node]
_j = th.where(nodes_ary_node == i)

edges_ary[i, j] = num_edges
edges_ary[node, _j] = num_edges
num_edges += 1
_edges_ary = edges_ary.max() - edges_ary
_edges_ary[edges_ary == -1] = -1
edges_ary = _edges_ary
return edges_ary
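The reversed enumeration numbers edges starting from the highest-index nodes, and the final edges_ary.max() - edges_ary flip then gives those edges the largest indices while the -1 "not connected" marker is restored, which is what pushes the banned edges to the tail. A tiny illustration of the flip step on a hypothetical index array (not a real edge map from the commit):

import torch as th

edges_ary = th.tensor([[2, 1], [1, 0], [0, -1]])  # hypothetical indices, -1 marks an unconnected slot
flipped = edges_ary.max() - edges_ary             # reverse the order: index 0 becomes the largest index
flipped[edges_ary == -1] = -1                     # restore the "not connected" marker
# flipped -> tensor([[ 0,  1], [ 1,  2], [ 2, -1]])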


@@ -248,6 +263,7 @@ def __init__(self, nodes_list: list, device: th.device):
self.edges_ary = edges_ary.to(device)
self.num_nodes = num_nodes
self.num_edges = num_edges
self.ban_edges = None # todo not elegant

'''build for get_log10_multiple_times'''
node_dims_arys = get_node_dims_arys(nodes_ary)
@@ -256,81 +272,113 @@ node_bool_arys = get_node_bool_arys(nodes_ary)
node_bool_arys = get_node_bool_arys(nodes_ary)
assert num_nodes == sum([ary.sum() for ary in node_bool_arys])

self.node_dims_arys = [ary.to(device) for ary in node_dims_arys]
self.node_bool_arys = [ary.to(device) for ary in node_bool_arys]
self.node_dims_ten = th.stack(node_dims_arys).to(device)
self.node_bool_ten = th.stack(node_bool_arys).to(device)

def get_log10_multiple_times(self, edge_argsorts: TEN) -> TEN:
# edge_argsort = th.rand(self.num_edges).argsort()
device = self.device
edges_ary: TEN = self.edges_ary
num_envs, run_edges = edge_argsorts.shape
assert run_edges == self.num_edges - self.ban_edges
vec_env_is = th.arange(num_envs, device=device)

num_envs, num_edges = edge_argsorts.shape
node_dims_aryss = [[ary.clone() for ary in self.node_dims_arys] for _ in range(num_envs)]
node_bool_aryss = [[ary.clone() for ary in self.node_bool_arys] for _ in range(num_envs)]
node_dims_tens = th.stack([self.node_dims_ten.clone() for _ in range(num_envs)])
node_bool_tens = th.stack([self.node_bool_ten.clone() for _ in range(num_envs)])

mult_pow_timess = th.zeros((num_envs, num_edges), dtype=th.float64, device=device)
mult_pow_timess = th.zeros((num_envs, run_edges), dtype=th.float64, device=device)

vec_env_is = th.arange(num_envs, device=device)
for i in range(num_edges):
for i in range(run_edges):
edge_is = edge_argsorts[:, i]
# [edge_i for edge_i in edge_is]

"""find two nodes of an edge_i"""
# node_i0, node_i1 = th.where(edges_ary == edge_i)[0]  # find the two nodes at the ends of this edge
# assert isinstance(node_i0.item(), int)
# assert isinstance(node_i1.item(), int)
"""Vanilla (single)"""
# for j in range(num_envs):
# edge_i = edge_is[j]
# node_dims_arys = node_dims_tens[j]
# node_bool_arys = node_bool_tens[j]
#
# '''find two nodes of an edge_i'''
# node_i0, node_i1 = th.where(edges_ary == edge_i)[0]  # find the two nodes at the ends of this edge
# # assert isinstance(node_i0.item(), int)
# # assert isinstance(node_i1.item(), int)
#
# '''calculate the multiple and avoid repeat'''
# contract_dims = node_dims_arys[node_i0] + node_dims_arys[node_i1]  # dims of the contracted node's adjacency tensor, and where they come from
# contract_bool = node_bool_arys[node_i0] | node_bool_arys[node_i1]  # which original nodes the contracted node is merged from
# # assert contract_dims.shape == (num_nodes, )
# # assert contract_bool.shape == (num_nodes, )
#
# # a contracted edge only needs one round of multiplication, so the doubly counted exponents above are summed and multiplied by 0.5
# mult_pow_time = contract_dims.sum(dim=0) - (contract_dims * contract_bool).sum(dim=0) * 0.5
# # assert mult_pow_time.shape == (1, )
# mult_pow_timess[j, i] = mult_pow_time
#
# '''adjust two list: node_dims_arys, node_bool_arys'''
# contract_dims[contract_bool] = 0  # set the multiplication count of the contracted edges to 2**0 so they no longer enter the multiplication count
#
# node_dims_arys[contract_bool] = contract_dims.repeat(1, 1)  # use the bool mask to refresh every merged node with the same information
# node_bool_arys[contract_bool] = contract_bool.repeat(1, 1)  # use the bool mask to refresh every merged node with the same information
#
# # print('\n;;;', i, edge_i, node_i0, node_i1)
# # [print(ary) for ary in node_dims_arys[:-self.ban_edges]]
# # [print(ary.int()) for ary in node_bool_arys[:-self.ban_edges]]

"""Vectorized"""
'''find two nodes of an edge_i'''
vec_edges_ary: TEN = edges_ary[None, :, :]
vec_edges_is: TEN = edge_is[:, None, None]
res = th.where(vec_edges_ary == vec_edges_is)[1]
res = res.reshape((num_envs, 2))
node_i0s, node_i1s = res[:, 0], res[:, 1]
# assert node_i0s.shape == (num_envs, )
# assert node_i1s.shape == (num_envs, )

node_dims_ten = th.stack([th.stack(arys) for arys in node_dims_aryss])
node_bool_ten = th.stack([th.stack(arys) for arys in node_bool_aryss])

# contract_dims = node_dims_arys[node_i0] + node_dims_arys[node_i1]  # dims of the contracted node's adjacency tensor, and where they come from
contract_dimss = node_dims_ten[vec_env_is, node_i0s] + node_dims_ten[vec_env_is, node_i1s]

# contract_bool = node_bool_arys[node_i0] | node_bool_arys[node_i1]  # which original nodes the contracted node is merged from
contract_bools = node_bool_ten[vec_env_is, node_i0s] + node_bool_ten[vec_env_is, node_i1s]
'''calculate the multiple and avoid repeat'''
contract_dimss = node_dims_tens[vec_env_is, node_i0s] + node_dims_tens[vec_env_is, node_i1s]
contract_bools = node_bool_tens[vec_env_is, node_i0s] | node_bool_tens[vec_env_is, node_i1s]
# assert contract_dimss.shape == (num_envs, num_nodes)
# assert contract_bools.shape == (num_envs, num_nodes)

# a contracted edge only needs one round of multiplication, so the doubly counted exponents above are summed and multiplied by 0.5
mult_pow_times = contract_dimss.sum(dim=1) - (contract_dimss * contract_bools).sum(dim=1) * 0.5

# assert mult_pow_times.shape == (num_envs,)
# assert mult_pow_times.shape == (num_envs, )
mult_pow_timess[:, i] = mult_pow_times

for j in range(num_envs):
node_i0 = node_i0s[j]
node_i1 = node_i1s[j]
node_dims_arys = node_dims_aryss[j]
node_bool_arys = node_bool_aryss[j]
'''adjust two list: node_dims_arys, node_bool_arys'''
contract_dimss[contract_bools] = 0  # set the multiplication count of the contracted edges to 2**0 so they no longer enter the multiplication count

for j in range(num_envs):  # use the bool mask to refresh every merged node with the same information
contract_dims = contract_dimss[j]
contract_bool = contract_bools[j]
node_dims_tens[j, contract_bool] = contract_dims.repeat(1, 1)
node_bool_tens[j, contract_bool] = contract_bool.repeat(1, 1)

contract_dimss[j][contract_bools[j]] = 0  # set the multiplication count of the contracted edges to 1 so they no longer enter the multiplication count
node_dims_arys[node_i0] = node_dims_arys[node_i1] = contract_dimss[j]  # make the two pre-contraction nodes point to the contracted node
node_bool_arys[node_i0] = node_bool_arys[node_i1] = contract_bools[j]  # make the two pre-contraction nodes point to the contracted node
# print('\n;;;', i, )
# env_i = 0
# [print(ary) for ary in node_dims_tens[env_i, :-self.ban_edges]]
# [print(ary.int()) for ary in node_bool_tens[env_i, :-self.ban_edges]]

temp_power = 10  # even with float64, computing this multiplication count can occasionally overflow, so divide by 2**temp_power first, take log10, then add it back
multiple_times = (2 ** (mult_pow_timess - temp_power)).sum(dim=1)
multiple_times = multiple_times.log10() + th.log10(th.tensor(2 ** temp_power))
return multiple_times.detach()
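The broadcast th.where in the vectorized block above finds both endpoints of one chosen edge per environment in a single call: each edge id appears exactly twice in edges_ary, and because th.where returns matches in row-major order, the flat result can be reshaped to (num_envs, 2). A small self-contained example with a hypothetical 3-node, 3-edge map (not from the commit):

import torch as th

edges_ary = th.tensor([[0, 2], [0, 1], [1, 2]])  # hypothetical: edge 0 joins nodes 0-1, edge 1 joins 1-2, edge 2 joins 0-2
edge_is = th.tensor([1, 2])                      # one chosen edge per environment (2 environments)
hits = th.where(edges_ary[None, :, :] == edge_is[:, None, None])[1]
node_pairs = hits.reshape((len(edge_is), 2))     # tensor([[1, 2], [0, 2]]): the two endpoints per environment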

"""
L2O_H_term.py", line 463, in opt_train
gradients = p.grad.view(hid_dim, 1).detach().clone().requires_grad_(True)
"""


def run():
gpu_id = int(sys.argv[1]) if len(sys.argv) > 1 else 0
device = th.device(f'cuda:{gpu_id}' if th.cuda.is_available() and gpu_id >= 0 else 'cpu')

env = TensorNetworkEnv(nodes_list=NodesSycamoreN53M12, device=device)
# nodes_list, ban_edges = NodesSycamoreN53M12, 0
# nodes_list, ban_edges = get_nodes_list_of_tensor_train(len_list=8), 8
nodes_list, ban_edges = get_nodes_list_of_tensor_train(len_list=100), 100

env = TensorNetworkEnv(nodes_list=nodes_list, device=device)
env.ban_edges = ban_edges
print(f"\nnum_nodes {env.num_nodes:9}"
f"\nnum_edges {env.num_edges:9}")
num_envs = 8
f"\nnum_edges {env.num_edges:9}"
f"\nban_edges {env.ban_edges:9}")
num_envs = 32

edge_arys = th.rand((num_envs, env.num_edges), device=device)
edge_arys = th.rand((num_envs, env.num_edges - env.ban_edges), device=device)
# th.save(edge_arys, 'temp.pth')
# edge_arys = th.load('temp.pth', map_location=device)

