diff --git a/playground/offline.py b/playground/offline.py index 76193da..178398c 100644 --- a/playground/offline.py +++ b/playground/offline.py @@ -27,6 +27,7 @@ log = logging.getLogger(__name__) model_id = ModelIdentifier(language="english", model_name="wiki-50", is_stemmed=False) +# model_id = ModelIdentifier(language="english", model_name="glove-twitter-25", is_stemmed=False) # model_id = ModelIdentifier(language="english", model_name="google-300", is_stemmed=False) # model_id = ModelIdentifier(language="hebrew", model_name="twitter", is_stemmed=False) # model_id = ModelIdentifier(language="hebrew", model_name="ft-200", is_stemmed=False) @@ -46,7 +47,13 @@ def run_offline(board: Board = ENGLISH_BOARDS[2]): # noqa: F405 game_runner = None try: # blue_hinter = GPTHinter(name="Yoda", api_key=GPT_API_KEY) - blue_hinter = NaiveHinter(name="Yoda", team_color=TeamColor.BLUE, model_adapter=adapter, max_group_size=4) + blue_hinter = NaiveHinter( + name="Yoda", + team_color=TeamColor.BLUE, + model_identifier=model_id, + model_adapter=adapter, + max_group_size=4, + ) red_hinter = NaiveHinter( name="Einstein", team_color=TeamColor.RED, diff --git a/playground/printer.py b/playground/printer.py index 662b352..b0a3d3f 100644 --- a/playground/printer.py +++ b/playground/printer.py @@ -20,7 +20,7 @@ def print_results(game_runner: Optional[GameRunner]): def _print_board(state: GameState): log.info("") - log.info(f"{state.board}") + log.info(f"\n{state.board}") def _print_moves(game_runner: GameRunner): diff --git a/pyproject.toml b/pyproject.toml index 426f025..940eb44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ [tool.poetry] name = "codenames-solvers" -version = "1.6.1" +version = "1.7.0" description = "Solvers implementation for Codenames board game in python." authors = ["Michael Kali ", "Asaf Kali "] readme = "README.md" diff --git a/solvers/models/cache.py b/solvers/models/cache.py index c0be07a..70b0c4b 100644 --- a/solvers/models/cache.py +++ b/solvers/models/cache.py @@ -4,6 +4,7 @@ from threading import Lock from typing import Dict +import gensim.downloader as gensim_api from generic_iterative_stemmer.models import StemmedKeyedVectors from gensim.models import KeyedVectors @@ -13,8 +14,8 @@ class ModelCache: - def __init__(self): - self.language_data_folder = "~/.cache/language_data" + def __init__(self, language_data_folder: str = "~/.cache/language_data"): + self.language_data_folder = language_data_folder self._cache: Dict[ModelIdentifier, KeyedVectors] = {} self._main_lock = Lock() self._model_locks: Dict[ModelIdentifier, Lock] = {} @@ -35,25 +36,33 @@ def load_model(self, model_identifier: ModelIdentifier) -> KeyedVectors: return self._cache[model_identifier] def _load_model(self, model_identifier: ModelIdentifier) -> KeyedVectors: - # TODO: in case loading fails, try gensim downloader - # import gensim.downloader as api - # model = api.load("wiki-he") log.info("Loading model...", extra={"model": model_identifier.dict()}) language_base_folder = expanduser(os.path.join(self.language_data_folder, model_identifier.language)) - model = load_kv_format( - language_base_folder=language_base_folder, - model_name=model_identifier.model_name, - is_stemmed=model_identifier.is_stemmed, - ) - log.info("Model loaded", extra={"model": model_identifier.dict()}) - return model + try: + return load_kv_format( + language_base_folder=language_base_folder, + model_name=model_identifier.model_name, + is_stemmed=model_identifier.is_stemmed, + ) + except Exception as e: + log.warning(f"Failed to load model: {e}", exc_info=True) + return load_from_gensim(model_identifier) def load_kv_format(language_base_folder: str, model_name: str, is_stemmed: bool = False) -> KeyedVectors: model_folder = os.path.join(language_base_folder, model_name) - file_path = os.path.join(model_folder, "model.kv") # TODO: This needs fixing + file_path = os.path.join(model_folder, "model.kv") + log.debug(f"Looking for [{model_name}] in {file_path}...") if is_stemmed: model = StemmedKeyedVectors.load(file_path) else: model = KeyedVectors.load(file_path) + log.debug(f"Successfully loaded [{model_name}] from {file_path}") + return model + + +def load_from_gensim(model_identifier: ModelIdentifier) -> KeyedVectors: + log.debug(f"Looking for [{model_identifier.model_name}] in gensim API...") + model = gensim_api.load(model_identifier.model_name) + log.debug(f"Successfully loaded [{model_identifier.model_name}] from gensim API") return model diff --git a/solvers/models/identifier.py b/solvers/models/identifier.py index aa1c01f..1d91d09 100644 --- a/solvers/models/identifier.py +++ b/solvers/models/identifier.py @@ -8,3 +8,6 @@ class ModelIdentifier(BaseModel): def __hash__(self): return hash(f"{self.language}-{self.model_name}-{self.is_stemmed}") + + def __str__(self) -> str: + return f"{self.language}-{self.model_name}"