Skip to content

Commit

Permalink
feat: Add linage info to tree search
Browse files Browse the repository at this point in the history
  • Loading branch information
lenianiva committed Oct 28, 2024
1 parent 93ecd0d commit de93309
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions pantograph/search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod
import time
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Self
import collections, unittest

from pantograph.server import Server, TacticFailure, ServerError
Expand All @@ -11,15 +11,15 @@
@dataclass
class SearchState:

state: GoalState
parent: Optional[int]
goal_state: GoalState
parent: Optional[Self]
parent_goal_id: Optional[int]
priorities: list[float]

def __post_init__(self):
assert len(self.priorities) == len(self.state.goals)
self.solved = [False for _ in self.state.goals]
self.trials = [0 for _ in self.state.goals]
assert len(self.priorities) == len(self.goal_state.goals)
self.solved = [False for _ in self.goal_state.goals]
self.trials = [0 for _ in self.goal_state.goals]

@property
def next_goal_id(self) -> int:
Expand Down Expand Up @@ -89,7 +89,7 @@ def search(self,
time_start = time.time()

initial_state = SearchState(
state=goal_state,
goal_state,
parent=None,
parent_goal_id=None,
priorities=[0.0 for _ in goal_state.goals]
Expand Down Expand Up @@ -120,7 +120,7 @@ def search(self,
tactic = None
else:
# Generate tactic for this goal
tactic = self.next_tactic(search_state.state, goal_id)
tactic = self.next_tactic(search_state.goal_state, goal_id)

if verbose:
print(f"Next tactic: {tactic}")
Expand All @@ -143,18 +143,18 @@ def search(self,

try:
search_state.trials[goal_id] += 1
state = search_state.state
goal_state = search_state.goal_state
if verbose:
print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.state.goals[goal_id]}")
next_goal_state = server.goal_tactic(search_state.state, goal_id, tactic)
print(f"{goal_state.state_id}.{goal_id}: {tactic} on {goal_state.goals[goal_id]}")
next_goal_state = server.goal_tactic(goal_state, goal_id, tactic)
# Generate priorities for the next goal state
priorities = [0.0 for _ in next_goal_state.goals] \
if len(next_goal_state.goals) <= 1 else \
self.guidance(next_goal_state)
parent = len(search_stack) - 1
next_state = SearchState(
state=next_goal_state,
parent=parent,
goal_state=next_goal_state,
parent=search_state,
parent_goal_id=goal_id,
priorities=priorities
)
Expand Down

0 comments on commit de93309

Please sign in to comment.