From de933093932089fbfcca56de6c9c3976f2df19cf Mon Sep 17 00:00:00 2001 From: Leni Aniva Date: Mon, 28 Oct 2024 10:34:09 -0700 Subject: [PATCH] feat: Add linage info to tree search --- pantograph/search.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/pantograph/search.py b/pantograph/search.py index b6cd63e..0a1debf 100644 --- a/pantograph/search.py +++ b/pantograph/search.py @@ -1,7 +1,7 @@ from abc import abstractmethod import time from dataclasses import dataclass -from typing import Optional +from typing import Optional, Self import collections, unittest from pantograph.server import Server, TacticFailure, ServerError @@ -11,15 +11,15 @@ @dataclass class SearchState: - state: GoalState - parent: Optional[int] + goal_state: GoalState + parent: Optional[Self] parent_goal_id: Optional[int] priorities: list[float] def __post_init__(self): - assert len(self.priorities) == len(self.state.goals) - self.solved = [False for _ in self.state.goals] - self.trials = [0 for _ in self.state.goals] + assert len(self.priorities) == len(self.goal_state.goals) + self.solved = [False for _ in self.goal_state.goals] + self.trials = [0 for _ in self.goal_state.goals] @property def next_goal_id(self) -> int: @@ -89,7 +89,7 @@ def search(self, time_start = time.time() initial_state = SearchState( - state=goal_state, + goal_state, parent=None, parent_goal_id=None, priorities=[0.0 for _ in goal_state.goals] @@ -120,7 +120,7 @@ def search(self, tactic = None else: # Generate tactic for this goal - tactic = self.next_tactic(search_state.state, goal_id) + tactic = self.next_tactic(search_state.goal_state, goal_id) if verbose: print(f"Next tactic: {tactic}") @@ -143,18 +143,18 @@ def search(self, try: search_state.trials[goal_id] += 1 - state = search_state.state + goal_state = search_state.goal_state if verbose: - print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}") - next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic) + print(f"{goal_state.state_id}.{goal_id}: {tactic} on {goal_state.goals[goal_id]}") + next_goal_state = server.goal_tactic(goal_state, goal_id, tactic) # Generate priorities for the next goal state priorities = [0.0 for _ in next_goal_state.goals] \ if len(next_goal_state.goals) <= 1 else \ self.guidance(next_goal_state) parent = len(search_stack) - 1 next_state = SearchState( - state=next_goal_state, - parent=parent, + goal_state=next_goal_state, + parent=search_state, parent_goal_id=goal_id, priorities=priorities )