Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 find bug in TNCO env and the explain this env data struct #93

Open
Yonv1943 opened this issue Mar 31, 2023 · 2 comments
Open

🐛 find bug in TNCO env and the explain this env data struct #93

Yonv1943 opened this issue Mar 31, 2023 · 2 comments
Labels
bug Something isn't working

Comments

@Yonv1943
Copy link
Collaborator

Yonv1943 commented Mar 31, 2023

张量收缩计算图解.pptx


#92

已经提交PR 92 去修复此BUG


https://github.com/AI4Finance-Foundation/ElegantRL_Solver/blob/52b4dc3ac5b8461772751a7294f5c9c10fdba5a5/rlsolver/rlsolver_learn2opt/tensor_train/TNCO_env.py#L269-L271

上面的代码有bug,这里想要修改 list中存放的指针的指向,但是最右边的等号是一个【赋值】操作,它让指针重新指向新的地址,这是不正确的。

应该修改成

node_dims_tens = th.stack([self.node_dims_ten.clone() for _ in range(num_envs)])
node_bool_tens = th.stack([self.node_bool_ten.clone() for _ in range(num_envs)])
for i in range(run_edges):
    ...
    for j in range(num_envs):
        ...
        node_dims_arys = node_dims_tens[j]
        node_bool_arys = node_bool_tens[j]
        ...
        node_dims_arys[contract_bool] = contract_dims.repeat(1, 1)  # 根据 bool 将所有收缩后的节点都刷新成相同的信息
        node_bool_arys[contract_bool] = contract_bool.repeat(1, 1)  # 根据 bool 将所有收缩后的节点都刷新成相同的信息
        ...

举例:
初始化 arys = [torch.zeros(2) + i for i in range(5)],打印 arys

[tensor([0., 0.]),
 tensor([1., 1.]),
 tensor([2., 2.]),
 tensor([3., 3.]),
 tensor([4., 4.])]

修改指针指向 arys[0] = arys[1] = arys[2] = torch.zeros(2) -1,打印 arys

[tensor([-1., -1.]),
 tensor([-1., -1.]),
 tensor([-1., -1.]),
 tensor([3., 3.]),
 tensor([4., 4.])]

修改 arys[0],使用赋值操作 arys[0] = torch.zeros(2) + 0。重新打印 arys,确认指针指向是否正确

[tensor([0., 0.]),       -----> 我把 arys[0] 从-1 改成了 0
 tensor([-1., -1.]),    -----> 发现 arys[1] 没有跟着 arys[0] 一起变成 0,这是错误的
 tensor([-1., -1.]),    -----> 发现 arys[2] 没有跟着 arys[0] 一起变成 0,这是错误的
 tensor([3., 3.]),
 tensor([4., 4.])]

不能使用赋值操作改变指针指向的地址,应该用 arys[0][:] = torch.zeros(2) + 0
重新运行命令 修改指针指向 arys[0] = arys[1] = arys[2] = torch.zeros(2) -1
重新运行命令 修改 arys[0] 的赋值 arys[0][:] = torch.zeros(2) + 0
得到预期中的结果:

[tensor([0., 0.]),
 tensor([0., 0.]),
 tensor([0., 0.]),
 tensor([3., 3.]),
 tensor([4., 4.])]
@Yonv1943 Yonv1943 added the bug Something isn't working label Mar 31, 2023
@Yonv1943
Copy link
Collaborator Author

Yonv1943 commented Mar 31, 2023

我想要对一个列表进行一种操作,例如:
我的列表是 s = [0, 1, 2, 3, 4, 5]
我想要按照某个任意给定的顺序,对它们进行求和,例如:

第1个给定的顺序是 (0, 1), 那么我取列表s 索引为0 和 1 的两个元素进行相加,0+1 得到的结果 1 赋值给这两个元素,相当于我把这两个元素合并成一个元素了。

第2个给定的顺序是 (2, 3), 那么我取列表s 索引为2 和 3 的两个元素进行相加,2+3 得到的结果 5 赋值给这两个元素,相当于我把这两个元素合并成一个元素了。

第3个给定的顺序是 (0, 2), 那么我取列表s 索引为0 和 2 的两个元素进行相加,注意,由于这两个节点都被合并成新的元素了,因此不是计算 0+2,而是计算 1+5 得到的结果 6 赋值给这四个元素,相当于我把初始的这四个元素合并成一个元素了。

第4个给定的顺序是 (3, 4), 那么我取列表s 索引为3 和 4 的两个元素进行相加,注意,由于索引为3的元素被合并成新的元素了,因此不是计算 3+4,而是计算 6+4 得到的结果 10 赋值给这几个元素,相当于我把初始的这几个元素合并成一个元素了。

请帮我写成代码


下面是ChatGPT给出的答案:

s = [0, 1, 2, 3, 4, 5]

# 定义一个字典,存储每个元素的父节点,初始时每个元素的父节点都是它自己
parents = {i: i for i in range(len(s))}

# 定义一个函数,用于查找某个节点的根节点
def find_root(i):
    if parents[i] != i:
        parents[i] = find_root(parents[i])
    return parents[i]

# 定义一个函数,用于合并两个节点所在的集合
def merge_set(i, j):
    root_i = find_root(i)
    root_j = find_root(j)
    parents[root_i] = root_j

# 定义一个函数,用于对列表s按照给定顺序进行求和并合并节点
def sum_and_merge(s, order):
    sum = s[order[0]] + s[order[1]]
    merge_set(order[0], order[1])
    for i in range(len(s)):
        if find_root(i) == find_root(order[0]):
            s[i] = sum

# 按照给定的顺序进行求和并合并节点
sum_and_merge(s, (0, 1))
sum_and_merge(s, (2, 3))
sum_and_merge(s, (0, 2))
sum_and_merge(s, (3, 4))

# 打印最终结果
print(s)

最后果然得到想要的结果(修改了print 的位置):

[1, 1, 2, 3, 4, 5]
[1, 1, 5, 5, 4, 5]
[6, 6, 6, 6, 4, 5]
[10, 10, 10, 10, 10, 5]

根据 ChatGPT 的建议,在PyTorch实现相同的操作更简单,有:

import torch as th
a = th.tensor((0,1,0,1)).bool()  # tensor([False,  True, False,  True])

b = th.rand(4, 6)
'''
tensor([[0.4802, 0.3409, 0.3359, 0.5106, 0.1429, 0.2140],
        [0.2697, 0.3475, 0.3719, 0.4058, 0.6563, 0.5772],
        [0.6439, 0.0933, 0.7830, 0.0989, 0.2740, 0.2559],
        [0.1416, 0.6473, 0.4084, 0.0258, 0.5571, 0.2749]])
'''

b[a, :] = b[0].repeat(2, 1)
b[a, :] = b[0].repeat(1, 1)  # use `.repeat(1, 1)` it is OK.
'''
tensor([[0.4802, 0.3409, 0.3359, 0.5106, 0.1429, 0.2140],  # --> b[0]
        [0.4802, 0.3409, 0.3359, 0.5106, 0.1429, 0.2140],  # --> b[1] = b[0]
        [0.6439, 0.0933, 0.7830, 0.0989, 0.2740, 0.2559],
        [0.4802, 0.3409, 0.3359, 0.5106, 0.1429, 0.2140]])  # --> b[3] = b[0]
'''

@Yonv1943
Copy link
Collaborator Author

Yonv1943 commented Apr 1, 2023

image

image

image

image

image

image

image

image

image

Yonv1943 added a commit to Yonv1943/ElegantRL_Solver that referenced this issue Apr 1, 2023
@AI4Finance-Foundation AI4Finance-Foundation deleted a comment from Yonv1943 Apr 7, 2023
@Yonv1943 Yonv1943 changed the title 🐛 find bug in TNCO env 🐛 find bug in TNCO env and the explain this env data struct Apr 10, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

No branches or pull requests

1 participant