From 512ec0dc47ac477251cf2d9b716a170cea425a7a Mon Sep 17 00:00:00 2001 From: Andeshog <157955984+Andeshog@users.noreply.github.com> Date: Sun, 17 Mar 2024 16:33:57 +0100 Subject: [PATCH] Fixed testing --- guidance/d_star_lite/CMakeLists.txt | 3 +- guidance/d_star_lite/d_star_lite/__init__.py | 2 - .../d_star_lite/d_star_lite/d_star_lite.py | 202 +++++++++--------- .../d_star_lite/d_star_lite_node.py | 8 +- guidance/d_star_lite/test/test_d_star_lite.py | 82 ------- .../d_star_lite/tests/test_d_star_lite.py | 82 +++++++ 6 files changed, 188 insertions(+), 191 deletions(-) delete mode 100644 guidance/d_star_lite/test/test_d_star_lite.py create mode 100644 guidance/d_star_lite/tests/test_d_star_lite.py diff --git a/guidance/d_star_lite/CMakeLists.txt b/guidance/d_star_lite/CMakeLists.txt index 1130c462..a1c1e0d0 100644 --- a/guidance/d_star_lite/CMakeLists.txt +++ b/guidance/d_star_lite/CMakeLists.txt @@ -25,10 +25,9 @@ install(PROGRAMS if(BUILD_TESTING) find_package(ament_lint_auto REQUIRED) find_package(ament_cmake_pytest REQUIRED) - ament_add_pytest_test(test_python test) + ament_add_pytest_test(python_tests tests) set(ament_cmake_copyright_FOUND TRUE) set(ament_cmake_cpplint_FOUND TRUE) - ament_lint_auto_find_test_dependencies() endif() ament_package() \ No newline at end of file diff --git a/guidance/d_star_lite/d_star_lite/__init__.py b/guidance/d_star_lite/d_star_lite/__init__.py index 0879f0ab..e69de29b 100644 --- a/guidance/d_star_lite/d_star_lite/__init__.py +++ b/guidance/d_star_lite/d_star_lite/__init__.py @@ -1,2 +0,0 @@ -from .d_star_lite import DStarLite -from .d_star_lite import Node \ No newline at end of file diff --git a/guidance/d_star_lite/d_star_lite/d_star_lite.py b/guidance/d_star_lite/d_star_lite/d_star_lite.py index 156a136c..1653abd3 100755 --- a/guidance/d_star_lite/d_star_lite/d_star_lite.py +++ b/guidance/d_star_lite/d_star_lite/d_star_lite.py @@ -4,66 +4,66 @@ # Link to the original code: # https://github.com/AtsushiSakai/PythonRobotics/blob/master/PathPlanning/DStarLite/d_star_lite.py -class Node: +class DSLNode: """ - Represents a node in the grid. + Represents a DSLNode in the grid. Attributes: - x (int): The x-coordinate of the node. - y (int): The y-coordinate of the node. - cost (float): The cost of moving to the node. + x (int): The x-coordinate of the DSLNode. + y (int): The y-coordinate of the DSLNode. + cost (float): The cost of moving to the DSLNode. """ def __init__(self, x: int = 0, y: int = 0, cost: float = 0.0): """ - Initializes a new instance of the Node class. + Initializes a new instance of the DSLNode class. Args: - x (int): The x-coordinate of the node. Defaults to 0. - y (int): The y-coordinate of the node. Defaults to 0. - cost (float): The cost of moving to the node. Defaults to 0.0. + x (int): The x-coordinate of the DSLNode. Defaults to 0. + y (int): The y-coordinate of the DSLNode. Defaults to 0. + cost (float): The cost of moving to the DSLNode. Defaults to 0.0. """ self.x = x self.y = y self.cost = cost -def combine_nodes(node1: Node, node2: Node) -> Node: +def combine_nodes(node1: DSLNode, node2: DSLNode) -> DSLNode: """ - Combines two Node objects by summing their x and y coordinates and their costs. + Combines two DSLNode objects by summing their x and y coordinates and their costs. Args: - node1 (Node): The first node to combine. - node2 (Node): The second node to combine. + node1 (DSLNode): The first DSLNode to combine. + node2 (DSLNode): The second DSLNode to combine. Returns: - Node: A new Node object with the combined x and y coordinates and costs. + DSLNode: A new DSLNode object with the combined x and y coordinates and costs. """ - new_node = Node() + new_node = DSLNode() new_node.x = node1.x + node2.x new_node.y = node1.y + node2.y new_node.cost = node1.cost + node2.cost return new_node -def compare_coordinates(node1: Node, node2: Node) -> bool: +def compare_coordinates(node1: DSLNode, node2: DSLNode) -> bool: """ - Checks if two Node objects have the same x and y coordinates. + Checks if two DSLNode objects have the same x and y coordinates. Args: - node1 (Node): The first node to compare. - node2 (Node): The second node to compare. + node1 (DSLNode): The first DSLNode to compare. + node2 (DSLNode): The second DSLNode to compare. Returns: bool: True if the nodes have the same x and y coordinates, False otherwise. """ return node1.x == node2.x and node1.y == node2.y -def distance(node1: Node, node2: Node) -> float: +def distance(node1: DSLNode, node2: DSLNode) -> float: """ - Computes the Euclidean distance between two Node objects. + Computes the Euclidean distance between two DSLNode objects. Args: - node1 (Node): The first node. - node2 (Node): The second node. + node1 (DSLNode): The first DSLNode. + node2 (DSLNode): The second DSLNode. Returns: float: The Euclidean distance between the two nodes. @@ -74,7 +74,7 @@ class DStarLite: """ Implements the D* Lite algorithm for path planning in a grid. - This class manages the pathfinding grid, obstacles and calculates the shortest path from a start node to a goal node. + This class manages the pathfinding grid, obstacles and calculates the shortest path from a start DSLNode to a goal DSLNode. Methods: ------------------------------------------------------------------------------------------- @@ -82,43 +82,43 @@ class DStarLite: create_grid(val: float) -> np.ndarray: Creates a grid initialized with a specific value. - is_obstacle(node: Node) -> bool: Check if the node is considered an obstacle or is too close to an obstacle. + is_obstacle(DSLNode: DSLNode) -> bool: Check if the DSLNode is considered an obstacle or is too close to an obstacle. - movement_cost(node1: Node, node2: Node) -> float: Calculates the cost of moving from node1 to node2. + movement_cost(node1: DSLNode, node2: DSLNode) -> float: Calculates the cost of moving from node1 to node2. - heuristic_distance(s: Node) -> float: Calculates the heuristic distance from node s to the goal using the Euclidean distance. + heuristic_distance(s: DSLNode) -> float: Calculates the heuristic distance from DSLNode s to the goal using the Euclidean distance. - calculate_key(s: Node) -> tuple: Calculates the priority key for a node 's' based on the D* Lite algorithm. + calculate_key(s: DSLNode) -> tuple: Calculates the priority key for a DSLNode 's' based on the D* Lite algorithm. - is_valid(node: Node) -> bool: Determines if a node is within the grid boundaries. + is_valid(DSLNode: DSLNode) -> bool: Determines if a DSLNode is within the grid boundaries. - get_neighbours(u: Node) -> list[Node]: Generates a list of valid neighbours of a node 'u'. + get_neighbours(u: DSLNode) -> list[DSLNode]: Generates a list of valid neighbours of a DSLNode 'u'. - pred(u: Node) -> list[Node]: Retrieves the predecessors of a node 'u'. + pred(u: DSLNode) -> list[DSLNode]: Retrieves the predecessors of a DSLNode 'u'. - initialize(start: Node, goal: Node): Initializes the grid and the D* Lite algorithm. + initialize(start: DSLNode, goal: DSLNode): Initializes the grid and the D* Lite algorithm. - update_vertex(u: Node): Updates the vertex in the priority queue and the rhs value of the node 'u'. + update_vertex(u: DSLNode): Updates the vertex in the priority queue and the rhs value of the DSLNode 'u'. - get_direction(node1: Node, node2: Node) -> tuple: Calculates the direction from node1 to node2. + get_direction(node1: DSLNode, node2: DSLNode) -> tuple: Calculates the direction from node1 to node2. - detect_and_update_waypoints(current_point: Node, next_point: Node): Updates the waypoints based on the current and next points. + detect_and_update_waypoints(current_point: DSLNode, next_point: DSLNode): Updates the waypoints based on the current and next points. compare_keys(key_pair1: tuple[float, float], key_pair2: tuple[float, float]) -> bool: Compares the priority keys of two nodes. compute_shortest_path(): Computes or recomputes the shortest path from the start to the goal using the D* Lite algorithm. - compute_current_path() -> list[Node]: Computes the current path from the start to the goal. + compute_current_path() -> list[DSLNode]: Computes the current path from the start to the goal. get_WP() -> list[list[int]]: Retrieves the waypoints and adjusts their coordinates to the original coordinate system. - dsl_main(start: Node, goal: Node) -> tuple[bool, list[int], list[int]]: Main function to run the D* Lite algorithm. + dsl_main(start: DSLNode, goal: DSLNode) -> tuple[bool, list[int], list[int]]: Main function to run the D* Lite algorithm. """ motions = [ - Node(1, 0, 1), Node(0, 1, 1), Node(-1, 0, 1), Node(0, -1, 1), - Node(1, 1, math.sqrt(2)), Node(1, -1, math.sqrt(2)), - Node(-1, 1, math.sqrt(2)), Node(-1, -1, math.sqrt(2)) + DSLNode(1, 0, 1), DSLNode(0, 1, 1), DSLNode(-1, 0, 1), DSLNode(0, -1, 1), + DSLNode(1, 1, math.sqrt(2)), DSLNode(1, -1, math.sqrt(2)), + DSLNode(-1, 1, math.sqrt(2)), DSLNode(-1, -1, math.sqrt(2)) ] def __init__(self, ox: list, oy: list, dist_to_obstacle: float = 4.5): @@ -128,18 +128,18 @@ def __init__(self, ox: list, oy: list, dist_to_obstacle: float = 4.5): Args: ox (list): The x-coordinates of the obstacles. oy (list): The y-coordinates of the obstacles. - dist_to_obstacle (float): The minimum distance a node must be from any obstacle to be considered valid. Defaults to 4.5. + dist_to_obstacle (float): The minimum distance a DSLNode must be from any obstacle to be considered valid. Defaults to 4.5. """ self.x_min_world = int(min(ox)) # The minimum x and y coordinates of the grid self.y_min_world = int(min(oy)) # The minimum x and y coordinates of the grid self.x_max = int(abs(max(ox) - self.x_min_world)) # The maximum x and y coordinates of the grid self.y_max = int(abs(max(oy) - self.y_min_world)) # The maximum x and y coordinates of the grid - self.obstacles = [Node(x - self.x_min_world, y - self.y_min_world) for x, y in zip(ox, oy)] # The obstacles + self.obstacles = [DSLNode(x - self.x_min_world, y - self.y_min_world) for x, y in zip(ox, oy)] # The obstacles self.obstacles_xy = np.array( # Numpy array for of obstacle coordinates [[obstacle.x, obstacle.y] for obstacle in self.obstacles] ) - self.start = Node(0, 0) # The start node - self.goal = Node(0, 0) # The goal node + self.start = DSLNode(0, 0) # The start DSLNode + self.goal = DSLNode(0, 0) # The goal DSLNode self.U = [] # Priority queue self.km = 0.0 # The minimum key in U self.kold = 0.0 # The old minimum key in U @@ -147,7 +147,7 @@ def __init__(self, ox: list, oy: list, dist_to_obstacle: float = 4.5): self.g = self.create_grid(float("inf")) # The g values self.initialized = False # Whether the grid has been initialized self.WP = [] # The waypoints - self.dist_to_obstacle = dist_to_obstacle # The minimum distance a node must be from any obstacle to be considered valid + self.dist_to_obstacle = dist_to_obstacle # The minimum distance a DSLNode must be from any obstacle to be considered valid def create_grid(self, val: float) -> np.ndarray: """ @@ -161,32 +161,32 @@ def create_grid(self, val: float) -> np.ndarray: """ return np.full((self.x_max, self.y_max), val) - def is_obstacle(self, node: Node) -> bool: + def is_obstacle(self, DSLNode: DSLNode) -> bool: """ - Check if the node is considered an obstacle or is too close to an obstacle. + Check if the DSLNode is considered an obstacle or is too close to an obstacle. Args: - node (Node): The node to check. + DSLNode (DSLNode): The DSLNode to check. Returns: - bool: True if the node is too close to an obstacle or is an obstacle, False otherwise. + bool: True if the DSLNode is too close to an obstacle or is an obstacle, False otherwise. """ - # Convert the node's coordinates to a numpy array for efficient distance computation - node_xy = np.array([node.x, node.y]) + # Convert the DSLNode's coordinates to a numpy array for efficient distance computation + node_xy = np.array([DSLNode.x, DSLNode.y]) - # Compute the euclidean distances from the node to all obstacles + # Compute the euclidean distances from the DSLNode to all obstacles distances = np.sqrt(np.sum((self.obstacles_xy - node_xy) ** 2, axis=1)) # Check if any distance is less than the minimum distance (default: 4.5) return np.any(distances < self.dist_to_obstacle) - def movement_cost(self, node1: Node, node2: Node) -> float: + def movement_cost(self, node1: DSLNode, node2: DSLNode) -> float: """ Calculates the cost of moving from node1 to node2. Returns infinity if node2 is an obstacle or if no valid motion is found. Args: - node1 (Node): The starting node. - node2 (Node): The ending node. + node1 (DSLNode): The starting DSLNode. + node2 (DSLNode): The ending DSLNode. Returns: float: The cost of moving from node1 to node2. @@ -194,27 +194,27 @@ def movement_cost(self, node1: Node, node2: Node) -> float: if self.is_obstacle(node2): return math.inf - movement_vector = Node(node1.x - node2.x, node1.y - node2.y) + movement_vector = DSLNode(node1.x - node2.x, node1.y - node2.y) for motion in self.motions: if compare_coordinates(motion, movement_vector): return motion.cost return math.inf - def heuristic_distance(self, s: Node) -> float: + def heuristic_distance(self, s: DSLNode) -> float: """ - Calculates the heuristic distance from node s to the goal using the Euclidean distance. + Calculates the heuristic distance from DSLNode s to the goal using the Euclidean distance. Args: - s (Node): The node to calculate the heuristic distance from. + s (DSLNode): The DSLNode to calculate the heuristic distance from. Returns: - float: The heuristic distance from node s to the goal. + float: The heuristic distance from DSLNode s to the goal. """ return distance(s, self.goal) # Euclidean distance - def calculate_key(self, s: Node) -> tuple: + def calculate_key(self, s: DSLNode) -> tuple: """ - Calculates the priority key for a node 's' based on the D* Lite algorithm. + Calculates the priority key for a DSLNode 's' based on the D* Lite algorithm. The key is a tuple consisting of two parts: 1. The estimated total cost from the start to the goal through 's', combining the minimum of g(s) and rhs(s), @@ -222,51 +222,51 @@ def calculate_key(self, s: Node) -> tuple: 2. The minimum of g(s) and rhs(s) representing the best known cost to reach 's'. Args: - s (Node): The node to calculate the key for. + s (DSLNode): The DSLNode to calculate the key for. Returns: - tuple: A tuple of two floats representing the priority key for the node. + tuple: A tuple of two floats representing the priority key for the DSLNode. """ return (min(self.g[s.x][s.y], self.rhs[s.x][s.y]) + self.heuristic_distance(s) + self.km, min(self.g[s.x][s.y], self.rhs[s.x][s.y])) - def is_valid(self, node: Node) -> bool: + def is_valid(self, DSLNode: DSLNode) -> bool: """ - Determines if a node is within the grid boundaries. + Determines if a DSLNode is within the grid boundaries. Args: - node (Node): The node to check. + DSLNode (DSLNode): The DSLNode to check. Returns: - bool: True if the node is within the grid boundaries, False otherwise. + bool: True if the DSLNode is within the grid boundaries, False otherwise. """ - return 0 <= node.x < self.x_max and 0 <= node.y < self.y_max + return 0 <= DSLNode.x < self.x_max and 0 <= DSLNode.y < self.y_max - def get_neighbours(self, u: Node) -> list[Node]: + def get_neighbours(self, u: DSLNode) -> list[DSLNode]: """ - Generates a list of valid neighbours of a node 'u'. + Generates a list of valid neighbours of a DSLNode 'u'. Args: - u (Node): The node to generate neighbours for. + u (DSLNode): The DSLNode to generate neighbours for. Returns: - list: A list of valid neighbours of the node 'u'. + list: A list of valid neighbours of the DSLNode 'u'. """ return [combine_nodes(u, motion) for motion in self.motions if self.is_valid(combine_nodes(u, motion))] - def pred(self, u: Node) -> list[Node]: + def pred(self, u: DSLNode) -> list[DSLNode]: """ - Retrieves the predecessors of a node 'u'. In this case, the predecessors are the neighbours of the node. + Retrieves the predecessors of a DSLNode 'u'. In this case, the predecessors are the neighbours of the DSLNode. Args: - u (Node): The node to retrieve predecessors for. + u (DSLNode): The DSLNode to retrieve predecessors for. Returns: - list: A list of predecessors of the node 'u'. + list: A list of predecessors of the DSLNode 'u'. """ return self.get_neighbours(u) - def initialize(self, start: Node, goal: Node): + def initialize(self, start: DSLNode, goal: DSLNode): """ Initializes the grid and the D* Lite algorithm. This function adjusts the coordinates of the start and goal nodes based on the grid's minimum world coordinates, @@ -275,8 +275,8 @@ def initialize(self, start: Node, goal: Node): function will have no effect. Args: - start (Node): The start node. - goal (Node): The goal node. + start (DSLNode): The start DSLNode. + goal (DSLNode): The goal DSLNode. """ self.start.x = start.x - self.x_min_world self.start.y = start.y - self.y_min_world @@ -292,23 +292,23 @@ def initialize(self, start: Node, goal: Node): self.rhs[self.goal.x][self.goal.y] = 0 self.U.append((self.goal, self.calculate_key(self.goal))) - def update_vertex(self, u: Node): + def update_vertex(self, u: DSLNode): """ - Updates the vertex in the priority queue and the rhs value of the node 'u'. + Updates the vertex in the priority queue and the rhs value of the DSLNode 'u'. - This method adjusts the right-hand side (rhs) value for a node unless it's the goal. It also ensures that the - node's priority in the queue reflects its current g and rhs values, reordering the queue as necessary. + This method adjusts the right-hand side (rhs) value for a DSLNode unless it's the goal. It also ensures that the + DSLNode's priority in the queue reflects its current g and rhs values, reordering the queue as necessary. Args: - u (Node): The node to update. + u (DSLNode): The DSLNode to update. """ if not compare_coordinates(u, self.goal): self.rhs[u.x][u.y] = min([self.movement_cost(u, sprime) + self.g[sprime.x][sprime.y] for sprime in self.pred(u)]) # Update the priority queue - if any([compare_coordinates(u, node) for node, key in self.U]): - self.U = [(node, key) for node, key in self.U if not compare_coordinates(node, u)] + if any([compare_coordinates(u, DSLNode) for DSLNode, key in self.U]): + self.U = [(DSLNode, key) for DSLNode, key in self.U if not compare_coordinates(DSLNode, u)] self.U.sort(key=lambda x: x[1]) if self.g[u.x][u.y] != self.rhs[u.x][u.y]: self.U.append((u, self.calculate_key(u))) @@ -316,13 +316,13 @@ def update_vertex(self, u: Node): # Resort the priority queue self.U.sort(key=lambda x: x[1]) - def get_direction(self, node1: Node, node2: Node) -> tuple: + def get_direction(self, node1: DSLNode, node2: DSLNode) -> tuple: """ Calculates the direction from node1 to node2. Args: - node1 (Node): The starting node. - node2 (Node): The ending node. + node1 (DSLNode): The starting DSLNode. + node2 (DSLNode): The ending DSLNode. Returns: tuple: A tuple of two integers representing the direction from node1 to node2. @@ -333,7 +333,7 @@ def get_direction(self, node1: Node, node2: Node) -> tuple: dy = dy/abs(dy) if dy != 0 else 0 return dx, dy - def detect_and_update_waypoints(self, current_point: Node, next_point: Node): + def detect_and_update_waypoints(self, current_point: DSLNode, next_point: DSLNode): """ Updates the waypoints based on the current and next points. @@ -342,8 +342,8 @@ def detect_and_update_waypoints(self, current_point: Node, next_point: Node): deviation in the path, the current point is added to the list of waypoints. Args: - current_point (Node): The current point. - next_point (Node): The next point. + current_point (DSLNode): The current point. + next_point (DSLNode): The next point. """ if not self.WP: # If the waypoint list is empty self.WP.append(current_point) @@ -377,8 +377,8 @@ def compute_shortest_path(self): Computes or recomputes the shortest path from the start to the goal using the D* Lite algorithm. This method iteratively updates the priorities and costs of nodes based on the graph's current state, - adjusting the path as necessary until the start node's key does not precede the smallest key in the - priority queue and the start node's rhs and g values are equal. + adjusting the path as necessary until the start DSLNode's key does not precede the smallest key in the + priority queue and the start DSLNode's rhs and g values are equal. """ self.U.sort(key=lambda x: x[1]) has_elements = len(self.U) > 0 @@ -403,15 +403,15 @@ def compute_shortest_path(self): start_key_not_updated = self.compare_keys(self.U[0][1], self.calculate_key(self.start)) rhs_not_equal_to_g = self.rhs[self.start.x][self.start.y] != self.g[self.start.x][self.start.y] - def compute_current_path(self) -> list[Node]: + def compute_current_path(self) -> list[DSLNode]: """ Computes the current path from the start to the goal. Returns: - list: A list of Node objects representing the current path from the start to the goal. + list: A list of DSLNode objects representing the current path from the start to the goal. """ path = list() - current_point = Node(self.start.x, self.start.y) + current_point = DSLNode(self.start.x, self.start.y) last_point = None while not compare_coordinates(current_point, self.goal): @@ -439,13 +439,13 @@ def get_WP(self) -> list[list[int]]: return WP_list - def dsl_main(self, start: Node, goal: Node) -> tuple[bool, list[int], list[int]]: + def dsl_main(self, start: DSLNode, goal: DSLNode) -> tuple[bool, list[int], list[int]]: """ Main function to run the D* Lite algorithm. Args: - start (Node): The start node. - goal (Node): The goal node. + start (DSLNode): The start DSLNode. + goal (DSLNode): The goal DSLNode. Returns: tuple: A tuple containing a boolean indicating if the path was found, and the x and y coordinates of the path. diff --git a/guidance/d_star_lite/d_star_lite/d_star_lite_node.py b/guidance/d_star_lite/d_star_lite/d_star_lite_node.py index cd2fecc3..b01c4815 100755 --- a/guidance/d_star_lite/d_star_lite/d_star_lite_node.py +++ b/guidance/d_star_lite/d_star_lite/d_star_lite_node.py @@ -1,12 +1,12 @@ #!/usr/bin/env python3 import rclpy -from rclpy.node import Node as rclpy_node +from rclpy.node import Node import numpy as np -from d_star_lite import DStarLite, Node +from d_star_lite import DStarLite, DSLNode from vortex_msgs.srv import MissionPlanner, Waypoint -class DStarLiteNode(rclpy_node): +class DStarLiteNode(Node): """ A ROS2 node implementing the D* Lite algorithm. @@ -43,7 +43,7 @@ def d_star_lite_callback(self, request, response): gx = request.gx gy = request.gy dsl = DStarLite(ox, oy) - dsl.dsl_main(Node(sx, sy), Node(gx, gy)) + dsl.dsl_main(DSLNode(sx, sy), DSLNode(gx, gy)) path = dsl.compute_current_path() WP = np.array(dsl.get_WP()).tolist() # Convert to float32[] for Waypoint service diff --git a/guidance/d_star_lite/test/test_d_star_lite.py b/guidance/d_star_lite/test/test_d_star_lite.py deleted file mode 100644 index fed94e5c..00000000 --- a/guidance/d_star_lite/test/test_d_star_lite.py +++ /dev/null @@ -1,82 +0,0 @@ -import unittest -from d_star_lite.d_star_lite import DStarLite -from d_star_lite.d_star_lite import combine_nodes, compare_coordinates, distance -from d_star_lite.d_star_lite import Node -import numpy as np - -class TestDStarLite(unittest.TestCase): - - def setUp(self): - # Create example nodes - self.node1 = Node(0, 0, 2.0) - self.node2 = Node(1, 1, 3.4) - self.node5 = Node(-1, -1, -5.8) - - # Create example obstacle coordinates - self.ox = [1, 2, 3, 4, 5] - self.oy = [0, 0, 0, 0, 0] - - # Create example dsl object - self.dsl = DStarLite(self.ox, self.oy) - - def tearDown(self): - pass - - def test_combine_nodes(self): - # Test the combine_nodes function - result = combine_nodes(self.node1, self.node2) - self.assertEqual(result.x, 1) - self.assertEqual(result.y, 1) - self.assertEqual(result.cost, 5.4) - - result = combine_nodes(self.node2, self.node5) - self.assertEqual(result.x, 0) - self.assertEqual(result.y, 0) - self.assertEqual(result.cost, -2.4) - - def test_compare_coordinates(self): - # Test the compare_coordinates function - result = compare_coordinates(self.node1, self.node2) - self.assertEqual(result, False) - - result = compare_coordinates(self.node1, self.node1) - self.assertEqual(result, True) - - def test_distance(self): - # Test the distance function - result = distance(self.node1, self.node2) - self.assertEqual(result, np.sqrt(2)) - - result = distance(self.node2, self.node5) - self.assertEqual(result, np.sqrt(8)) - - def test_is_obstacle(self): - # Test the is_obstacle function - self.assertEqual(self.dsl.is_obstacle(Node(1, 0)), True) - self.assertEqual(self.dsl.is_obstacle(Node(2, 0)), True) - self.assertEqual(self.dsl.is_obstacle(Node(5, 0)), True) - self.assertEqual(self.dsl.is_obstacle(Node(10, 0)), False) - - def test_movement_cost(self): - # Test the movement_cost function - self.assertEqual(self.dsl.movement_cost(Node(10, 0), Node(11, 0)), 1.0) - self.assertEqual(self.dsl.movement_cost(Node(10, 0), Node(10, 1)), 1.0) - self.assertEqual(self.dsl.movement_cost(Node(10, 10), Node(11, 11)), np.sqrt(2)) - self.assertEqual(self.dsl.movement_cost(Node(1, 0), Node(2, 0)), np.inf) - - def test_heuristic_distance(self): - # Test the heuristic_distance function - self.dsl.goal = Node(5, 5) - self.assertEqual(self.dsl.heuristic_distance(Node(0, 0)), np.sqrt(50)) - self.assertEqual(self.dsl.heuristic_distance(Node(5, 5)), 0.0) - self.assertEqual(self.dsl.heuristic_distance(Node(10, 10)), np.sqrt(50)) - - def test_get_direction(self): - # Test the get_direction function - self.assertEqual(self.dsl.get_direction(Node(0, 0), Node(1, 0)), (1, 0)) - self.assertEqual(self.dsl.get_direction(Node(0, 0), Node(2, 2)), (1, 1)) - self.assertEqual(self.dsl.get_direction(Node(0, 0), Node(-1, 0)), (-1, 0)) - self.assertEqual(self.dsl.get_direction(Node(0, 0), Node(0, -1)), (0, -1)) - -if __name__ == '__main__': - unittest.main() \ No newline at end of file diff --git a/guidance/d_star_lite/tests/test_d_star_lite.py b/guidance/d_star_lite/tests/test_d_star_lite.py new file mode 100644 index 00000000..d84a2869 --- /dev/null +++ b/guidance/d_star_lite/tests/test_d_star_lite.py @@ -0,0 +1,82 @@ +import unittest +import pytest +from d_star_lite.d_star_lite import DStarLite, DSLNode +from d_star_lite.d_star_lite import combine_nodes, compare_coordinates, distance +import numpy as np + +class TestDStarLite(unittest.TestCase): + + def setUp(self): + # Create example DSLNodes + self.DSLNode1 = DSLNode(0, 0, 2.0) + self.DSLNode2 = DSLNode(1, 1, 3.4) + self.DSLNode5 = DSLNode(-1, -1, -5.8) + + # Create example obstacle coordinates + self.ox = [1, 2, 3, 4, 5] + self.oy = [0, 0, 0, 0, 0] + + # Create example dsl object + self.dsl = DStarLite(self.ox, self.oy) + + def tearDown(self): + pass + + def test_combine_DSLNodes(self): + # Test the combine_DSLNodes function + result = combine_nodes(self.DSLNode1, self.DSLNode2) + self.assertEqual(result.x, 1) + self.assertEqual(result.y, 1) + self.assertEqual(result.cost, 5.4) + + result = combine_nodes(self.DSLNode2, self.DSLNode5) + self.assertEqual(result.x, 0) + self.assertEqual(result.y, 0) + self.assertEqual(result.cost, -2.4) + + def test_compare_coordinates(self): + # Test the compare_coordinates function + result = compare_coordinates(self.DSLNode1, self.DSLNode2) + self.assertEqual(result, False) + + result = compare_coordinates(self.DSLNode1, self.DSLNode1) + self.assertEqual(result, True) + + def test_distance(self): + # Test the distance function + result = distance(self.DSLNode1, self.DSLNode2) + self.assertEqual(result, np.sqrt(2)) + + result = distance(self.DSLNode2, self.DSLNode5) + self.assertEqual(result, np.sqrt(8)) + + def test_is_obstacle(self): + # Test the is_obstacle function + self.assertEqual(self.dsl.is_obstacle(DSLNode(1, 0)), True) + self.assertEqual(self.dsl.is_obstacle(DSLNode(2, 0)), True) + self.assertEqual(self.dsl.is_obstacle(DSLNode(5, 0)), True) + self.assertEqual(self.dsl.is_obstacle(DSLNode(10, 0)), False) + + def test_movement_cost(self): + # Test the movement_cost function + self.assertEqual(self.dsl.movement_cost(DSLNode(10, 0), DSLNode(11, 0)), 1.0) + self.assertEqual(self.dsl.movement_cost(DSLNode(10, 0), DSLNode(10, 1)), 1.0) + self.assertEqual(self.dsl.movement_cost(DSLNode(10, 10), DSLNode(11, 11)), np.sqrt(2)) + self.assertEqual(self.dsl.movement_cost(DSLNode(1, 0), DSLNode(2, 0)), np.inf) + + def test_heuristic_distance(self): + # Test the heuristic_distance function + self.dsl.goal = DSLNode(5, 5) + self.assertEqual(self.dsl.heuristic_distance(DSLNode(0, 0)), np.sqrt(50)) + self.assertEqual(self.dsl.heuristic_distance(DSLNode(5, 5)), 0.0) + self.assertEqual(self.dsl.heuristic_distance(DSLNode(10, 10)), np.sqrt(50)) + + def test_get_direction(self): + # Test the get_direction function + self.assertEqual(self.dsl.get_direction(DSLNode(0, 0), DSLNode(1, 0)), (1, 0)) + self.assertEqual(self.dsl.get_direction(DSLNode(0, 0), DSLNode(2, 2)), (1, 1)) + self.assertEqual(self.dsl.get_direction(DSLNode(0, 0), DSLNode(-1, 0)), (-1, 0)) + self.assertEqual(self.dsl.get_direction(DSLNode(0, 0), DSLNode(0, -1)), (0, -1)) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file