-
Notifications
You must be signed in to change notification settings - Fork 0
/
rrt.py
180 lines (137 loc) · 4.24 KB
/
rrt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""
Custom Node and RRT classes for rapidly-exploring random tree (RRT) generation.

Classes:
    Node: A single node (vertex) of an RRT.
    RRT: The tree itself, with the functions needed to expand it.
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection as lc
import random
class Node:
    """One vertex of the tree: a position plus links to its parent and children."""

    def __init__(self, pos):
        """Create an unattached node located at ``pos``."""
        self.parent = None  # stays None until set_parent() is called
        self.child = []  # attached child nodes, in insertion order
        self.pos = pos

    # ---- position ----

    def get_pos(self):
        """Return the position stored on this node."""
        return self.pos

    def set_pos(self, pos):
        """Replace the position stored on this node."""
        self.pos = pos

    # ---- parent link ----

    def get_parent(self):
        """Return the parent node, or None if no parent has been set."""
        return self.parent

    def set_parent(self, parent):
        """Record ``parent`` as this node's parent."""
        self.parent = parent

    # ---- child links ----

    def get_child(self):
        """Return the list of child nodes (the live list, not a copy)."""
        return self.child

    def set_child(self, child):
        """Append ``child`` to the children; existing children are kept."""
        self.child.append(child)
class RRT:
    """Rapidly-exploring random tree grown from a single root node."""

    def __init__(self, q_init, k, delt, domain):
        """Create a tree that holds only the root node.

        Args:
            q_init: Position of the root node.
            k: Target number of nodes. Only stored; nothing in this class
                enforces it.
            delt: Distance between a new node and the node it is attached to.
            domain: Upper bound of the sampling region; reference positions
                are drawn from ``[0, domain]`` on each axis (assumes a
                scalar, as passed by ``expand_test``).
        """
        self.q_init = q_init
        self.k = k
        self.delt = delt
        self.domain = domain
        # dictionary of nodes (NOTE: nothing in this class ever fills it):
        self.node_pos_dict = {}
        root = Node(q_init)
        self.node_list = [root]

    def get_latest_node(self):
        """Return the most recently added node of the tree."""
        return self.node_list[-1]

    def get_ref_pos(self):
        """Return a uniformly random 2-D reference position inside the domain."""
        # x is drawn before y so a seeded run keeps its sample order.
        x_pos = random.uniform(0, self.domain)
        y_pos = random.uniform(0, self.domain)
        return np.array([x_pos, y_pos])

    def get_new_pos(self, tag_node=None):
        """Pick the node to grow from and the position of the node to add.

        A random reference position is sampled, the tree node nearest to it
        is selected, and the new position lies ``delt`` away from that node
        in the direction of the reference position.

        Args:
            tag_node: Optional node that must not be selected as the node
                to grow from; sampling repeats until another node is nearest.

        Returns:
            Tuple ``(nearest_node, new_pos)``.

        Raises:
            ValueError: If ``tag_node`` is the only node in the tree, so no
                other node could ever be selected, or if the tree is empty.
        """
        if tag_node is not None and all(
            node is tag_node for node in self.node_list
        ):
            # Without this guard the resampling loop below would never end.
            raise ValueError("tag_node is the only node in the tree")

        while True:
            ref_pos = self.get_ref_pos()
            nearest_node = self.find_nearest_node(ref_pos)
            if nearest_node is tag_node:
                # avoid selecting the tagged node: draw a new reference
                continue
            ref_curr_vec = ref_pos - nearest_node.get_pos()
            ref_curr_dist = np.linalg.norm(ref_curr_vec)
            if ref_curr_dist > 0:
                break
            # The sample coincides with the node, so there is no direction
            # to step in (and normalising would divide by zero): resample.

        # Step a distance of delt from the nearest node towards the reference.
        new_pos = nearest_node.get_pos() + self.delt * ref_curr_vec / ref_curr_dist
        return nearest_node, new_pos

    def expand(self, nearest_node: "Node" = None, new_pos=None):
        """Attach a new node at ``new_pos`` as a child of ``nearest_node``.

        When called without arguments, the node to grow from and the new
        position are sampled with ``get_new_pos()``.

        Args:
            nearest_node: Existing node that becomes the parent.
            new_pos: Position of the node to create.

        Returns:
            The newly created node.

        Raises:
            TypeError: If only one of the two arguments is supplied.
        """
        if nearest_node is None and new_pos is None:
            nearest_node, new_pos = self.get_new_pos()
        elif nearest_node is None or new_pos is None:
            raise TypeError(
                "expand() needs both nearest_node and new_pos, or neither"
            )
        new_node = Node(new_pos)
        nearest_node.set_child(new_node)
        new_node.set_parent(nearest_node)
        self.node_list.append(new_node)
        return new_node

    def find_nearest_node(self, ref_pt):
        """Return the tree node with the smallest Euclidean distance to ``ref_pt``.

        If several nodes are exactly equidistant, the earliest-added one is
        returned.

        Raises:
            ValueError: If the tree contains no nodes.
        """
        return min(
            self.node_list,
            key=lambda node: np.linalg.norm(ref_pt - node.get_pos()),
        )

    def get_node_num(self):
        """Return the number of nodes currently in the tree."""
        return len(self.node_list)

    def get_init_pos(self):
        """Return the root position the tree was created with."""
        return self.q_init

    def get_node_dict(self):
        """Return the node dictionary (never populated by this class)."""
        return self.node_pos_dict

    def get_node_list(self):
        """Return the list of all nodes, in insertion order (the live list)."""
        return self.node_list
def expand_test():
    """
    Test the RRT expansion.

    Builds a tree rooted at (50, 50) and adds one node at a time until the
    tree holds K nodes.

    Output:
        A list of all nodes on the RRT.
    """
    q_init = np.array([50, 50])
    delt = 1
    D = 100
    K = 300
    simple_rrt = RRT(q_init, K, delt, D)
    # expand the tree:
    while simple_rrt.get_node_num() < K:
        # expand() takes the node to grow from and the new position, so
        # sample both first and pass them on explicitly.
        nearest_node, new_pos = simple_rrt.get_new_pos()
        simple_rrt.expand(nearest_node, new_pos)
    print("finish tree expansion")
    return simple_rrt.get_node_list()
def draw_lines(line_seg):
    """
    Plot the RRT.

    Input:
        line_seg: List of line segments, one per tree edge. Each segment is
            a pair of (x, y) points: [parent_pos, child_pos].
    Output:
        NONE
    """
    l_c = lc(line_seg)
    _, ax = plt.subplots()
    # Fixed view of the 0..100 square.
    ax.set_xlim(0, 100)
    ax.set_ylim(0, 100)
    # add_collection's second parameter is `autolim`, not a plot format
    # string, so no '-o' style argument is passed here.
    ax.add_collection(l_c)
    plt.show()
if __name__ == "__main__":
    # Grow a tree, then draw every parent -> child edge as a line segment.
    nodes_list = expand_test()
    line_seg = []
    for node in nodes_list:
        pos = node.get_pos()
        parent_node_pos = (pos[0], pos[1])
        # get a line collection: one segment per child of this node
        for child in node.get_child():
            c_pos = child.get_pos()
            child_pos = (c_pos[0], c_pos[1])
            line_seg.append([parent_node_pos, child_pos])
    draw_lines(line_seg)