From 38ddb5242a6d9a2824b3578655716a337e4fbaa4 Mon Sep 17 00:00:00 2001 From: hsahovic <25432355+hsahovic@users.noreply.github.com> Date: Sat, 23 May 2020 21:13:18 -0400 Subject: [PATCH] Feature baseline and evaluation (#50) * Add stats property to mon * Small additions in player and PNI logging * Add damage multiplier method to pokemon * Add MaxBasePowerPlayer to player.baselines * Add SimpleHeuristicsPlayer * Add comment in readme code example * Add format property to Player * Add battle against method to Player object * Add pretty format json pre-commit * Add evaluate_player function - Add evaluate function for estimating relative player strength in a comparable way - Add helper function _estimate_strength_from_results - Add corresponding _EVALUATION_RATINGS dictionary containing data used to compute ratings - Add simple unit tests for edge cases and helper function * Use math.* instead of np.* in player.utils * Add edge case unit test for player evaluation (inf value) * Add unit test for max base power player * Add Pokemon.damage_multiplier unit test * Add unit test for SimpleHeuristicsPlayer._estimate_matchup * Add SimpleHeuristicsPlayer._should_dynamax unit test * Expand SimpleHeuristicsPlayer unit tests * Update examples and doc with battle_against method instead of cross_evaluate --- .pre-commit-config.yaml | 1 + README.md | 1 + docs/source/max_damage_player.rst | 20 +- docs/source/ou_max_player.rst | 9 +- docs/source/using_custom_teambuilder.rst | 8 +- examples/custom_teambuilder.py | 6 +- examples/max_damage_player.py | 10 +- examples/ou_max_player.py | 10 +- src/poke_env/environment/pokemon.py | 32 +++ src/poke_env/player/baselines.py | 178 ++++++++++++++++ src/poke_env/player/player.py | 25 ++- .../player/player_network_interface.py | 2 + src/poke_env/player/utils.py | 133 ++++++++++++ unit_tests/environment/test_pokemon.py | 15 ++ unit_tests/player/test_baselines.py | 197 ++++++++++++++++++ unit_tests/player/test_player_evaluation.py | 92 ++++++++ 16 
files changed, 693 insertions(+), 46 deletions(-) create mode 100644 src/poke_env/player/baselines.py create mode 100644 unit_tests/player/test_baselines.py create mode 100644 unit_tests/player/test_player_evaluation.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b98925be7..2e150ef18 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,5 +17,6 @@ repos: - id: detect-private-key - id: fix-encoding-pragma - id: mixed-line-ending + - id: pretty-format-json - id: requirements-txt-fixer - id: trailing-whitespace diff --git a/README.md b/README.md index f102ef0de..534339267 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ class YourFirstAgent(Player): # A powerful move! Let's use it return self.create_order(move) + # No available move? Let's switch then! for switch in battle.available_switches: if switch.current_hp_fraction > battle.active_pokemon.current_hp_fraction: # This other pokemon has more HP left... Let's switch it in? diff --git a/docs/source/max_damage_player.rst b/docs/source/max_damage_player.rst index c0493ff9a..653b2e81a 100644 --- a/docs/source/max_damage_player.rst +++ b/docs/source/max_damage_player.rst @@ -82,7 +82,7 @@ We also have to return an order corresponding to a random switch if the player c Running and testing our agent ***************************** -We can now test our agent by crossing evaluating it with a random agent. The complete code is: +We can now test our agent by making it battle a random agent. The complete code is: .. code-block:: python @@ -92,7 +92,6 @@ We can now test our agent by crossing evaluating it with a random agent. The com from poke_env.player.player import Player from poke_env.player.random_player import RandomPlayer - from poke_env.player.utils import cross_evaluate class MaxDamagePlayer(Player): @@ -119,18 +118,15 @@ We can now test our agent by crossing evaluating it with a random agent. 
The com battle_format="gen8randombattle", ) - # Now, let's evaluate our player - cross_evaluation = await cross_evaluate( - [random_player, max_damage_player], n_challenges=100 - ) + # Now, let's evaluate our player + await max_damage_player.battle_against(random_player, n_battles=100) - print( - "Max damage player won %d / 100 battles [this took %f seconds]" - % ( - cross_evaluation[max_damage_player.username][random_player.username] * 100, - time.time() - start, - ) + print( + "Max damage player won %d / 100 battles [this took %f seconds]" + % ( + max_damage_player.n_won_battles, time.time() - start ) + ) if __name__ == "__main__": diff --git a/docs/source/ou_max_player.rst b/docs/source/ou_max_player.rst index 92a80cef1..94dbe0100 100644 --- a/docs/source/ou_max_player.rst +++ b/docs/source/ou_max_player.rst @@ -304,7 +304,7 @@ To attribute a team to an agent, you need to pass a ``team`` argument to the age Running and testing our agent ***************************** -We can now test our agent by crossing evaluating it with a random agent. The complete code is: +We can now test our agent. To do so, we can use the ``cross_evaluate`` function from ``poke_env.player.utils`` or the ``battle_against`` method from ``Player``. .. code-block:: python @@ -313,7 +313,6 @@ We can now test our agent by crossing evaluating it with a random agent. The com from poke_env.player.player import Player from poke_env.player.random_player import RandomPlayer - from poke_env.player.utils import cross_evaluate class MaxDamagePlayer(Player): @@ -493,13 +492,11 @@ We can now test our agent by crossing evaluating it with a random agent. 
The com ) # Now, let's evaluate our player - cross_evaluation = await cross_evaluate( - [random_player, max_damage_player], n_challenges=50 - ) + await max_damage_player.battle_against(random_player, n_battles = 100) print( "Max damage player won %d / 100 battles" - % (cross_evaluation[max_damage_player.username][random_player.username] * 100) + % max_damage_player.n_won_battles ) diff --git a/docs/source/using_custom_teambuilder.rst b/docs/source/using_custom_teambuilder.rst index 5b50dbf5a..336831a24 100644 --- a/docs/source/using_custom_teambuilder.rst +++ b/docs/source/using_custom_teambuilder.rst @@ -192,7 +192,7 @@ Now that we have two players with custom teambuilders, we can make them battle! .. code-block:: python - await cross_evaluate([player_1, player_2], n_challenges=5) + await player_1.battle_against(player_2, n_battles=5) The complete example looks like that: @@ -203,7 +203,6 @@ The complete example looks like that: import numpy as np from poke_env.player.random_player import RandomPlayer - from poke_env.player.utils import cross_evaluate from poke_env.teambuilder.teambuilder import Teambuilder @@ -347,10 +346,7 @@ The complete example looks like that: max_concurrent_battles=10, ) - await cross_evaluate([player_1, player_2], n_challenges=5) - - for battle in player_1.battles: - print(battle) + await player_1.battle_against(player_2, n_battles=5) if __name__ == "__main__": diff --git a/examples/custom_teambuilder.py b/examples/custom_teambuilder.py index 197de6bef..b3efa1772 100644 --- a/examples/custom_teambuilder.py +++ b/examples/custom_teambuilder.py @@ -3,7 +3,6 @@ import numpy as np from poke_env.player.random_player import RandomPlayer -from poke_env.player.utils import cross_evaluate from poke_env.teambuilder.teambuilder import Teambuilder @@ -143,10 +142,7 @@ async def main(): battle_format="gen8ou", team=custom_builder, max_concurrent_battles=10 ) - await cross_evaluate([player_1, player_2], n_challenges=5) - - for battle in player_1.battles: 
- print(battle) + await player_1.battle_against(player_2, n_battles=5) if __name__ == "__main__": diff --git a/examples/max_damage_player.py b/examples/max_damage_player.py index 57461a7bb..f189a8087 100644 --- a/examples/max_damage_player.py +++ b/examples/max_damage_player.py @@ -4,7 +4,6 @@ from poke_env.player.player import Player from poke_env.player.random_player import RandomPlayer -from poke_env.player.utils import cross_evaluate class MaxDamagePlayer(Player): @@ -28,16 +27,11 @@ async def main(): max_damage_player = MaxDamagePlayer(battle_format="gen8randombattle") # Now, let's evaluate our player - cross_evaluation = await cross_evaluate( - [random_player, max_damage_player], n_challenges=100 - ) + await max_damage_player.battle_against(random_player, n_battles=100) print( "Max damage player won %d / 100 battles [this took %f seconds]" - % ( - cross_evaluation[max_damage_player.username][random_player.username] * 100, - time.time() - start, - ) + % (max_damage_player.n_won_battles, time.time() - start) ) diff --git a/examples/ou_max_player.py b/examples/ou_max_player.py index b38e12bd8..bc7f7af37 100644 --- a/examples/ou_max_player.py +++ b/examples/ou_max_player.py @@ -4,7 +4,6 @@ from poke_env.player.player import Player from poke_env.player.random_player import RandomPlayer -from poke_env.player.utils import cross_evaluate class MaxDamagePlayer(Player): @@ -180,14 +179,9 @@ async def main(): ) # Now, let's evaluate our player - cross_evaluation = await cross_evaluate( - [random_player, max_damage_player], n_challenges=50 - ) + await max_damage_player.battle_against(random_player, n_battles=100) - print( - "Max damage player won %d / 100 battles" - % (cross_evaluation[max_damage_player.username][random_player.username] * 100) - ) + print("Max damage player won %d / 100 battles" % max_damage_player.n_won_battles) if __name__ == "__main__": diff --git a/src/poke_env/environment/pokemon.py b/src/poke_env/environment/pokemon.py index 8f2f1ec16..dd4cd10c4 
100644 --- a/src/poke_env/environment/pokemon.py +++ b/src/poke_env/environment/pokemon.py @@ -5,6 +5,7 @@ from typing import Optional from typing import Set from typing import Tuple +from typing import Union from poke_env.data import POKEDEX from poke_env.environment.effect import Effect @@ -59,6 +60,13 @@ def __init__( self._heightm: int self._possible_abilities: List[str] self._species: str + self._stats: Dict[str, Optional[int]] = { + "atk": None, + "def": None, + "spa": None, + "spd": None, + "spe": None, + } self._type_1: PokemonType self._type_2: Optional[PokemonType] = None self._weightkg: int @@ -359,6 +367,22 @@ def _was_illusionned(self): self._status = None self._switch_out() + def damage_multiplier(self, type_or_move: Union[PokemonType, Move]) -> float: + """ + Returns the damage multiplier associated with a given type or move on this + pokemon. + + This method is a shortcut for PokemonType.damage_multiplier with relevant types. + + :param type_or_move: The type or move of interest. + :type type_or_move: PokemonType or Move + :return: The damage multiplier associated with given type on the pokemon. + :rtype: float + """ + if isinstance(type_or_move, Move): + type_or_move = type_or_move.type + return type_or_move.damage_multiplier(self._type_1, self._type_2) + @property def ability(self) -> Optional[str]: """ @@ -565,6 +589,14 @@ def species(self) -> str: """ return self._species + @property + def stats(self) -> Dict[str, Optional[int]]: + """ + :return: The pokemon's stats, as a dictionary. 
+ :rtype: Dict[str, Optional[int]] + """ + return self._stats + @property def status(self) -> Optional[Status]: """ diff --git a/src/poke_env/player/baselines.py b/src/poke_env/player/baselines.py new file mode 100644 index 000000000..85884faef --- /dev/null +++ b/src/poke_env/player/baselines.py @@ -0,0 +1,178 @@ +# -*- coding: utf-8 -*- +from poke_env.environment.move_category import MoveCategory +from poke_env.environment.side_condition import SideCondition +from poke_env.player.player import Player + + +class MaxBasePowerPlayer(Player): + def choose_move(self, battle): + if battle.available_moves: + best_move = max(battle.available_moves, key=lambda move: move.base_power) + return self.create_order(best_move) + return self.choose_random_move(battle) + + +class SimpleHeuristicsPlayer(Player): + ENTRY_HAZARDS = { + "spikes": SideCondition.SPIKES, + "stealhrock": SideCondition.STEALTH_ROCK, + "stickyweb": SideCondition.STICKY_WEB, + "toxicspikes": SideCondition.TOXIC_SPIKES, + } + + ANTI_HAZARDS_MOVES = {"rapidspin", "defog"} + + SPEED_TIER_COEFICIENT = 0.1 + HP_FRACTION_COEFICIENT = 0.4 + SWITCH_OUT_MATCHUP_THRESHOLD = -2 + + def _estimate_matchup(self, mon, opponent): + score = max([opponent.damage_multiplier(t) for t in mon.types if t is not None]) + score -= max( + [mon.damage_multiplier(t) for t in opponent.types if t is not None] + ) + if mon.base_stats["spe"] > opponent.base_stats["spe"]: + score += self.SPEED_TIER_COEFICIENT + elif opponent.base_stats["spe"] > mon.base_stats["spe"]: + score -= self.SPEED_TIER_COEFICIENT + + score += mon.current_hp_fraction * self.HP_FRACTION_COEFICIENT + score -= opponent.current_hp_fraction * self.HP_FRACTION_COEFICIENT + + return score + + def _should_dynamax(self, battle, n_remaining_mons): + if battle.can_dynamax: + # Last full HP mon + if ( + len([m for m in battle.team.values() if m.current_hp_fraction == 1]) + == 1 + and battle.active_pokemon.current_hp_fraction == 1 + ): + return True + # Matchup advantage and full 
hp on full hp + if ( + self._estimate_matchup( + battle.active_pokemon, battle.opponent_active_pokemon + ) + > 0 + and battle.active_pokemon.current_hp_fraction == 1 + and battle.opponent_active_pokemon.current_hp_fraction == 1 + ): + return True + if n_remaining_mons == 1: + return True + return False + + def _should_switch_out(self, battle): + active = battle.active_pokemon + opponent = battle.opponent_active_pokemon + # If there is a decent switch in... + if [ + m + for m in battle.available_switches + if self._estimate_matchup(m, opponent) > 0 + ]: + # ...and a 'good' reason to switch out + if active.boosts["def"] <= -3 or active.boosts["spd"] <= -3: + return True + if ( + active.boosts["atk"] <= -3 + and active.stats["atk"] >= active.stats["spa"] + ): + return True + if ( + active.boosts["spa"] <= -3 + and active.stats["atk"] <= active.stats["spa"] + ): + return True + if ( + self._estimate_matchup(active, opponent) + < self.SWITCH_OUT_MATCHUP_THRESHOLD + ): + return True + return False + + def _stat_estimation(self, mon, stat): + # Stats boosts value + if mon.boosts[stat] > 1: + boost = (2 + mon.boosts[stat]) / 2 + else: + boost = 2 / (2 - mon.boosts[stat]) + return ((2 * mon.base_stats[stat] + 31) + 5) * boost + + def choose_move(self, battle): + # Main mons shortcuts + active = battle.active_pokemon + opponent = battle.opponent_active_pokemon + + # Rough estimation of damage ratio + physical_ratio = self._stat_estimation(active, "atk") / self._stat_estimation( + opponent, "def" + ) + special_ratio = self._stat_estimation(active, "spa") / self._stat_estimation( + opponent, "spd" + ) + + if battle.available_moves and (not self._should_switch_out(battle)): + n_remaining_mons = len( + [m for m in battle.team.values() if m.fainted is False] + ) + n_opp_remaining_mons = 6 - len( + [m for m in battle.team.values() if m.fainted is True] + ) + + # Entry hazard... 
+ for move in battle.available_moves: + # ...setup + if ( + n_opp_remaining_mons >= 3 + and move.id in self.ENTRY_HAZARDS + and self.ENTRY_HAZARDS[move.id] + not in battle.opponent_side_conditions + ): + return self.create_order(move) + + # ...removal + elif ( + battle.side_conditions + and move.id in self.ANTI_HAZARDS_MOVES + and n_remaining_mons >= 2 + ): + return self.create_order(move) + + # Setup moves + if ( + active.current_hp_fraction == 1 + and self._estimate_matchup(active, opponent) > 0 + ): + for move in battle.available_moves: + if ( + move.boosts + and sum(move.boosts.values()) >= 2 + and move.target == "self" + ): + return self.create_order(move) + + move = max( + battle.available_moves, + key=lambda m: m.base_power + * (1.5 if m.type in active.types else 1) + * ( + physical_ratio + if m.category == MoveCategory.PHYSICAL + else special_ratio + ) + * m.accuracy + * opponent.damage_multiplier(m), + ) + return self.create_order( + move, dynamax=self._should_dynamax(battle, n_remaining_mons) + ) + + return self.create_order( + max( + battle.available_switches, + key=lambda s: self._estimate_matchup(s, opponent), + ) + ) diff --git a/src/poke_env/player/player.py b/src/poke_env/player/player.py index 656fdf978..fac816f92 100644 --- a/src/poke_env/player/player.py +++ b/src/poke_env/player/player.py @@ -2,6 +2,7 @@ """This module defines a base class for players. """ +import asyncio import random from abc import ABC @@ -29,6 +30,7 @@ from poke_env.server_configuration import ServerConfiguration from poke_env.teambuilder.teambuilder import Teambuilder from poke_env.teambuilder.constant_teambuilder import ConstantTeambuilder +from poke_env.utils import to_id_str class Player(PlayerNetwork, ABC): @@ -389,6 +391,23 @@ async def ladder(self, n_games): perf_counter() - start_time, ) + async def battle_against(self, opponent: "Player", n_battles: int) -> None: + """Make the player play n_battles against opponent. 
+ + This function is a wrapper around send_challenges and accept challenges. + + :param opponent: The opponent to play against. + :type opponent: Player + :param n_battles: The number of games to play. + :type n_battles: int + """ + await asyncio.gather( + self.send_challenges( + to_id_str(opponent.username), n_battles, to_wait=opponent.logged_in + ), + opponent.accept_challenges(to_id_str(self.username), n_battles), + ) + async def send_challenges( self, opponent: str, n_challenges: int, to_wait: Optional[Event] = None ) -> None: @@ -409,7 +428,7 @@ async def send_challenges( :type to_wait: Event, optional. """ await self._logged_in.wait() - self.logger.info("Event logged in received in challenge") + self.logger.info("Event logged in received in send challenge") if to_wait is not None: await to_wait.wait() @@ -500,6 +519,10 @@ def create_order( def battles(self) -> Dict[str, Battle]: return self._battles + @property + def format(self) -> str: + return self._format + @property def n_finished_battles(self) -> int: return len([None for b in self._battles.values() if b.finished]) diff --git a/src/poke_env/player/player_network_interface.py b/src/poke_env/player/player_network_interface.py index 7d1421e5f..7c6a29fe0 100644 --- a/src/poke_env/player/player_network_interface.py +++ b/src/poke_env/player/player_network_interface.py @@ -257,6 +257,8 @@ async def listen(self) -> None: ) except (CancelledError, RuntimeError) as e: self.logger.critical("Listen interrupted by %s", e) + except Exception as e: + self.logger.exception(e) finally: for coroutine in coroutines: coroutine.cancel() diff --git a/src/poke_env/player/utils.py b/src/poke_env/player/utils.py index 987212e05..76c8a18df 100644 --- a/src/poke_env/player/utils.py +++ b/src/poke_env/player/utils.py @@ -3,12 +3,22 @@ """ from poke_env.player.player import Player +from poke_env.player.random_player import RandomPlayer +from poke_env.player.baselines import MaxBasePowerPlayer, SimpleHeuristicsPlayer from 
poke_env.utils import to_id_str from typing import Dict from typing import List from typing import Optional +from typing import Tuple import asyncio +import math + +_EVALUATION_RATINGS = { + RandomPlayer: 1, + MaxBasePowerPlayer: 7.608901, + SimpleHeuristicsPlayer: 121.885905, +} async def cross_evaluate( @@ -35,3 +45,126 @@ async def cross_evaluate( p_1.reset_battles() p_2.reset_battles() return results # pyre-ignore + + +def _estimate_strength_from_results( + number_of_games: int, number_of_wins: int, opponent_rating: float +) -> Tuple[float, Tuple[float, float]]: + """Estimate player strength based on game results and opponent rating. + + :param number_of_games: Number of performance games for evaluation. + :type number_of_games: int + :param number_of_win: Number of won evaluation games. + :type number_of_win: int + :param opponent_rating: The opponent's rating. + :type opponent_rating: float + :raises: ValueError if the results are too extreme to be interpreted. + :return: A tuple containing the estimated player strength and a 95% confidence + interval + :rtype: tuple of float and tuple of floats + """ + n, p = number_of_games, number_of_wins / number_of_games + q = 1 - p + + if n * p * q < 9: # Cannot apply normal approximation of binomial distribution + raise ValueError( + "The results obtained in evaluate_player are too extreme to obtain an " + "accuracte player evaluation. You can try to solve this issue by increasing" + " the total number of battles. Obtained results: %d victories out of %d" + " games." 
% (p * n, n) + ) + + estimate = opponent_rating * p / q + error = ( + math.sqrt(n * p * q) / n * 1.96 + ) # 95% confidence interval for normal distribution + + lower_bound = max(0, p - error) + lower_bound = opponent_rating * lower_bound / (1 - lower_bound) + + higher_bound = min(1, p + error) + + if higher_bound == 1: + higher_bound = math.inf + else: + higher_bound = opponent_rating * higher_bound / (1 - higher_bound) + + return estimate, (lower_bound, higher_bound) + + +async def evaluate_player( + player, n_battles: int = 1000, n_placement_battles: int = 30 +) -> Tuple[float, Tuple[float, float]]: + """Estimate player strength. + + This functions calculates an estimate of a player's strength, measured as its + expected performance against a random opponent in a gen 8 random battle. The + returned number can be interpreted as follows: a strength of k means that the + probability of the player winning a gen 8 random battl against a random player is k + times higher than the probability of the random player winning. + + The function returns a tuple containing a best guess based on the the played games " + as well as a tuple describing a 95% confidence interval for that estimated strength. + + The actual evaluation can be performed against any baseline player for which an + accurate strength estimate is available. This baseline is determined at the start of + the process, by playing a limited number of placement battles and choosing the + opponent closest to the player in terms of performance. + + :param player: The player to evaluate. + :type player: Player + :param n_battles: The total number of battle to perform, including placement + battles. + :type n_battles: int + :param n_placement_battles: Number of placement battles to perform per baseline + player. + :type n_placement_battles: int + :raises: ValueError if the results are too extreme to be interpreted. 
+ :raises: AssertionError if the player is not configured to play gen8battles or the + selected number of games to play it too small. + :return: A tuple containing the estimated player strength and a 95% confidence + interval + :rtype: tuple of float and tuple of floats + """ + # Input checks + assert player.format == "gen8randombattle", ( + "Player %s can not be evaluated as its current format (%s) is not " + "gen8randombattle." % (player, player.format) + ) + + if n_placement_battles * len(_EVALUATION_RATINGS) > n_battles // 2: + player.logger.warning( + "Number of placement battles reduced from %d to %d due to limited number of" + " battles (%d). A more accuracte evaluation can be performed by increasing " + "the total number of players.", + n_placement_battles, + n_battles // len(_EVALUATION_RATINGS) // 2, + n_battles, + ) + n_placement_battles = n_battles // len(_EVALUATION_RATINGS) // 2 + + assert ( + n_placement_battles > 0 + ), "Not enough battles to perform placement battles. Please increase the number of " + "battles to perform to evaluate the player." 
+ + # Initial placement battles + baselines = [p(max_concurrent_battles=n_battles) for p in _EVALUATION_RATINGS] + + for p in baselines: + await p.battle_against(player, n_placement_battles) + + # Select the best opponent for evaluation + best_opp = min( + baselines, key=lambda p: (abs(p.win_rate - 0.5), -_EVALUATION_RATINGS[type(p)]) + ) + + # Performing the main evaluation + remaining_battles = n_battles - len(_EVALUATION_RATINGS) * n_placement_battles + await best_opp.battle_against(player, remaining_battles) + + return _estimate_strength_from_results( + best_opp.n_finished_battles, + best_opp.n_lost_battles, + _EVALUATION_RATINGS[type(best_opp)], + ) diff --git a/unit_tests/environment/test_pokemon.py b/unit_tests/environment/test_pokemon.py index 85bab2d41..9baaa7fc6 100644 --- a/unit_tests/environment/test_pokemon.py +++ b/unit_tests/environment/test_pokemon.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +from poke_env.environment.move import Move from poke_env.environment.pokemon import Pokemon from poke_env.environment.pokemon_type import PokemonType @@ -49,3 +50,17 @@ def test_pokemon_types(): mon._primal() assert mon.type_1 == PokemonType.GROUND assert mon.type_2 == PokemonType.FIRE + + +def test_pokemon_damage_multiplier(): + mon = Pokemon(species="pikachu") + assert mon.damage_multiplier(PokemonType.GROUND) == 2 + assert mon.damage_multiplier(PokemonType.ELECTRIC) == 0.5 + + mon = Pokemon(species="garchomp") + assert mon.damage_multiplier(Move("icebeam")) == 4 + assert mon.damage_multiplier(Move("dracometeor")) == 2 + + mon = Pokemon(species="linoone") + assert mon.damage_multiplier(Move("closecombat")) == 2 + assert mon.damage_multiplier(PokemonType.GHOST) == 0 diff --git a/unit_tests/player/test_baselines.py b/unit_tests/player/test_baselines.py new file mode 100644 index 000000000..6050b310e --- /dev/null +++ b/unit_tests/player/test_baselines.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- +from poke_env.environment.move import Move +from 
poke_env.environment.pokemon import Pokemon +from poke_env.player.baselines import MaxBasePowerPlayer +from poke_env.player.baselines import SimpleHeuristicsPlayer +from collections import namedtuple + + +def test_max_base_power_player(): + player = MaxBasePowerPlayer(start_listening=False) + + PseudoBattle = namedtuple( + "PseudoBattle", + ( + "available_moves", + "available_switches", + "can_z_move", + "can_dynamax", + "can_mega_evolve", + ), + ) + battle = PseudoBattle([], [], False, False, False) + + assert player.choose_move(battle) == "/choose default" + + battle.available_switches.append(Pokemon(species="ponyta")) + assert player.choose_move(battle) == "/choose switch ponyta" + + battle.available_moves.append(Move("protect")) + assert player.choose_move(battle) == "/choose move protect" + + battle.available_moves.append(Move("quickattack")) + assert player.choose_move(battle) == "/choose move quickattack" + + battle.available_moves.append(Move("flamethrower")) + assert player.choose_move(battle) == "/choose move flamethrower" + + +def test_simple_heuristics_player_estimate_matchup(): + player = SimpleHeuristicsPlayer(start_listening=False) + + dragapult = Pokemon(species="dragapult") + assert player._estimate_matchup(dragapult, dragapult) == 0 + + gengar = Pokemon(species="gengar") + assert player._estimate_matchup(dragapult, gengar) == -player._estimate_matchup( + gengar, dragapult + ) + assert player._estimate_matchup(dragapult, gengar) == player.SPEED_TIER_COEFICIENT + + mamoswine = Pokemon(species="mamoswine") + assert ( + player._estimate_matchup(dragapult, mamoswine) + == -1 + player.SPEED_TIER_COEFICIENT + ) + + dragapult._set_hp("100/100") + mamoswine._set_hp("50/100") + assert ( + player._estimate_matchup(dragapult, mamoswine) + == -1 + player.SPEED_TIER_COEFICIENT + player.HP_FRACTION_COEFICIENT / 2 + ) + + +def test_simple_heuristics_player_should_dynamax(): + PseudoBattle = namedtuple( + "PseudoBattle", + ["active_pokemon", 
"opponent_active_pokemon", "team", "can_dynamax"], + ) + player = SimpleHeuristicsPlayer(start_listening=False) + + battle = PseudoBattle( + Pokemon(species="charmander"), Pokemon(species="charmander"), {}, False + ) + assert player._should_dynamax(battle, 4) is False + + battle = PseudoBattle( + Pokemon(species="charmander"), Pokemon(species="charmander"), {}, True + ) + assert player._should_dynamax(battle, 1) is True + + battle.active_pokemon._set_hp("100/100") + battle.team["charmander"] = battle.active_pokemon + assert player._should_dynamax(battle, 4) is True + + battle = PseudoBattle( + Pokemon(species="squirtle"), + Pokemon(species="charmander"), + { + "kakuna": Pokemon(species="kakuna"), + "venusaur": Pokemon(species="venusaur"), + "charmander": Pokemon(species="charmander"), + }, + True, + ) + for mon in battle.team.values(): + mon._set_hp("100/100") + battle.active_pokemon._set_hp("100/100") + battle.opponent_active_pokemon._set_hp("100/100") + + assert player._should_dynamax(battle, 4) is True + + +def test_simple_heuristics_player_should_switch_out(): + PseudoBattle = namedtuple( + "PseudoBattle", + ["active_pokemon", "opponent_active_pokemon", "available_switches"], + ) + player = SimpleHeuristicsPlayer(start_listening=False) + + battle = PseudoBattle( + Pokemon(species="charmander"), Pokemon(species="charmander"), [] + ) + assert player._should_switch_out(battle) is False + + battle.available_switches.append(Pokemon(species="venusaur")) + assert player._should_switch_out(battle) is False + + battle.available_switches.append(Pokemon(species="gyarados")) + assert player._should_switch_out(battle) is False + + battle.active_pokemon._boost("spa", -3) + battle.active_pokemon.stats.update({"atk": 10, "spa": 20}) + assert player._should_switch_out(battle) is True + + battle.active_pokemon.stats.update({"atk": 30, "spa": 20}) + assert player._should_switch_out(battle) is False + + battle.active_pokemon._boost("atk", -3) + assert 
player._should_switch_out(battle) is True + + battle = PseudoBattle( + Pokemon(species="gible"), + Pokemon(species="mamoswine"), + [Pokemon(species="charizard")], + ) + assert player._should_switch_out(battle) is True + + +def test_simple_heuristics_player_stat_estimation(): + player = SimpleHeuristicsPlayer(start_listening=False) + mon = Pokemon(species="charizard") + + assert player._stat_estimation(mon, "spe") == 236 + + mon._boost("spe", 2) + assert player._stat_estimation(mon, "spe") == 472 + + mon._boost("atk", -1) + assert player._stat_estimation(mon, "atk") == 136 + + +def test_simple_heuristics_player(): + player = SimpleHeuristicsPlayer(start_listening=False) + + PseudoBattle = namedtuple( + "PseudoBattle", + ( + "active_pokemon", + "opponent_active_pokemon", + "available_moves", + "available_switches", + "team", + "can_dynamax", + "side_conditions", + "opponent_side_conditions", + ), + ) + battle = PseudoBattle( + Pokemon(species="dragapult"), + Pokemon(species="gengar"), + [], + [Pokemon(species="togekiss")], + {}, + True, + set(), + set(), + ) + battle.active_pokemon._stats = {stat: 100 for stat in battle.active_pokemon._stats} + + battle.available_switches[0]._set_hp("100/100") + assert player.choose_move(battle) == "/choose switch togekiss" + + battle.available_moves.append(Move("quickattack")) + assert player.choose_move(battle) == "/choose move quickattack" + + battle.available_moves.append(Move("flamethrower")) + assert player.choose_move(battle) == "/choose move flamethrower" + + battle.available_moves.append(Move("dracometeor")) + assert player.choose_move(battle) == "/choose move dracometeor" + + battle.active_pokemon._boost("atk", -3) + battle.active_pokemon._boost("spa", -3) + battle.available_switches.append(Pokemon(species="sneasel")) + battle.available_switches[1]._set_hp("100/100") + assert player.choose_move(battle) == "/choose switch sneasel" diff --git a/unit_tests/player/test_player_evaluation.py 
b/unit_tests/player/test_player_evaluation.py new file mode 100644 index 000000000..512d6231d --- /dev/null +++ b/unit_tests/player/test_player_evaluation.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +from poke_env.player.random_player import RandomPlayer +from poke_env.player.utils import evaluate_player, _estimate_strength_from_results + +import math +import pytest + + +def test_estimate_strength_from_results(): + # Test extreme values with decent number of games + with pytest.raises(ValueError): + _estimate_strength_from_results(1000, 1, 1) + + with pytest.raises(ValueError): + _estimate_strength_from_results(1000, 999, 1) + + # Test too few games + with pytest.raises(ValueError): + _estimate_strength_from_results(5, 2, 1) + # Test too few games + with pytest.raises(ValueError): + _estimate_strength_from_results(10, 3, 1) + + # Test estimate values + for n_games, n_victories, expected_value in ( + (100, 50, 1), + (999, 333, 0.5), + (999, 666, 2), + (250, 150, 1.5), + ): + assert ( + abs( + _estimate_strength_from_results(n_games, n_victories, 1)[0] + - expected_value + ) + < 10e-10 + ) + + eis = [] + cis = [] + for n_games in [10 ** i for i in range(2, 10)]: + ei, ci = _estimate_strength_from_results(n_games, n_games // 2 + 2, 1) + eis.append(ei) + cis.append(ci) + + assert ( + abs( + ei * 2 + - _estimate_strength_from_results(n_games, n_games // 2 + 2, 2)[0] + ) + < 10e-10 + ) + assert ( + abs( + ei * 5 + - _estimate_strength_from_results(n_games, n_games // 2 + 2, 5)[0] + ) + < 10e-10 + ) + + for i, ci in enumerate(cis): + assert ci[0] < eis[i] < ci[1] + if i: + assert ci[0] > cis[i - 1][0] + assert ci[1] < cis[i - 1][1] + + assert _estimate_strength_from_results(95, 43, 1) == ( + 0.8269230769230771, + (0.5444919968130197, 1.2357621837082962), + ) + + assert _estimate_strength_from_results(10 ** 17, 10 ** 17 - 10, 1)[1][1] == math.inf + + +@pytest.mark.asyncio +async def test_player_evaluation_assertions(): + p = RandomPlayer(battle_format="gen8ou", 
start_listening=False) + with pytest.raises(AssertionError): + await evaluate_player(p) + + p = RandomPlayer(start_listening=False) + with pytest.raises(AssertionError): + await evaluate_player(p, n_battles=0) + + with pytest.raises(AssertionError): + await evaluate_player(p, n_battles=-10) + + with pytest.raises(AssertionError): + await evaluate_player(p, n_placement_battles=0) + + with pytest.raises(AssertionError): + await evaluate_player(p, n_placement_battles=-10)