Skip to content

Commit

Permalink
Fix cryptic error on smb, plus numpy==2.0 compatibility (#203)
Browse files Browse the repository at this point in the history
* Adds a check and a more helpful error for SMB

* Updates SMB to numpy 2.0
  • Loading branch information
miguelgondu authored Jun 17, 2024
1 parent b0f2384 commit 841087e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,15 @@ class SMBIsolatedLogic(AbstractIsolatedFunction):
Runs the given input x as a flattened level (14x14)
through the model and returns the number
of jumps Mario makes in the level. If the
level is not solvable, returns np.NaN.
level is not solvable, returns np.nan.
"""

def __init__(
self,
alphabet: List[str] = smb_info.alphabet,
max_time: int = 30,
visualize: bool = False,
value_on_unplayable: float = np.NaN,
value_on_unplayable: float = np.nan,
):
self.alphabet = alphabet
self.alphabet_s_to_i = {s: i for i, s in enumerate(alphabet)}
Expand All @@ -92,8 +92,15 @@ def __call__(self, x: np.ndarray, context=None) -> np.ndarray:
level, max_time=self.max_time, visualize=self.visualize
)

if not isinstance(res, dict):
raise ValueError(
"Something probably went wrong with the Java simulation "
"of the level. It is quite likely you haven't set up a "
"virtual screen/frame buffer. Check the docs."
)

# Return the number of jumps if the level was
# solved successfully, else return np.NaN
# solved successfully, else return np.nan
if res["marioStatus"] == 1:
jumps = res["jumpActionsPerformed"]
else:
Expand Down
6 changes: 3 additions & 3 deletions src/poli/objective_repository/super_mario_bros/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ class SuperMarioBrosBlackBox(AbstractBlackBox):
Runs the given input x as a latent code
through the model and returns the number
of jumps Mario makes in the level. If the
level is not solvable, returns np.NaN.
level is not solvable, returns np.nan.
"""

def __init__(
self,
max_time: int = 30,
visualize: bool = False,
value_on_unplayable: float = np.NaN,
value_on_unplayable: float = np.nan,
batch_size: int = None,
parallelize: bool = False,
num_workers: int = None,
Expand Down Expand Up @@ -146,7 +146,7 @@ def create(
self,
max_time: int = 30,
visualize: bool = False,
value_on_unplayable: float = np.NaN,
value_on_unplayable: float = np.nan,
seed: int = None,
batch_size: int = None,
parallelize: bool = False,
Expand Down

0 comments on commit 841087e

Please sign in to comment.