From 930e1d3e3fa38852dc5ee7d46f33a11070934f10 Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Wed, 30 Oct 2024 12:55:40 +0000 Subject: [PATCH] Use ordered dict to only have one instance of plugin registered --- src/axolotl/integrations/base.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index f79a9b93c6..a33ebd17f1 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -18,9 +18,10 @@ To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods. """ +import collections import importlib import logging -from typing import List +from typing import OrderedDict class BasePlugin: @@ -233,7 +234,7 @@ class PluginManager: pre_model_load(cfg): Calls the pre_model_load method of all registered plugins. """ - plugins: List[BasePlugin] = [] + plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict() _instance = None @@ -243,7 +244,7 @@ def __new__(cls): """ if cls._instance is None: cls._instance = super(PluginManager, cls).__new__(cls) - cls._instance.plugins: List[BasePlugin] = [] + cls._instance.plugins = collections.OrderedDict() return cls._instance @staticmethod @@ -271,7 +272,7 @@ def register(self, plugin_name: str): """ try: plugin = load_plugin(plugin_name) - self.plugins.append(plugin) + self.plugins[plugin_name] = plugin except ImportError: logging.error(f"Failed to load plugin: {plugin_name}") @@ -283,7 +284,7 @@ def get_input_args(self): list[str]: A list of Pydantic classes for all registered plugins' input arguments.' """ input_args = [] - for plugin in self.plugins: + for _, plugin in self.plugins.items(): input_args_from_plugin = plugin.get_input_args() if input_args_from_plugin is not None: input_args.append(input_args_from_plugin) @@ -299,7 +300,7 @@ def pre_model_load(self, cfg): Returns: None """ - for plugin in self.plugins: + for _, plugin in self.plugins.items(): plugin.pre_model_load(cfg) def post_model_load(self, cfg, model): @@ -313,7 +314,7 @@ def post_model_load(self, cfg, model): Returns: None """ - for plugin in self.plugins: + for _, plugin in self.plugins.items(): plugin.post_model_load(cfg, model) def pre_lora_load(self, cfg, model): @@ -327,7 +328,7 @@ def pre_lora_load(self, cfg, model): Returns: None """ - for plugin in self.plugins: + for _, plugin in self.plugins.items(): plugin.pre_lora_load(cfg, model) def post_lora_load(self, cfg, model): @@ -341,7 +342,7 @@ def post_lora_load(self, cfg, model): Returns: None """ - for plugin in self.plugins: + for _, plugin in self.plugins.items(): plugin.post_lora_load(cfg, model) def create_optimizer(self, cfg, trainer): @@ -355,7 +356,7 @@ def create_optimizer(self, cfg, trainer): Returns: object: The created optimizer, or None if none was found. """ - for plugin in self.plugins: + for _, plugin in self.plugins.items(): optimizer = plugin.create_optimizer(cfg, trainer) if optimizer is not None: return optimizer @@ -373,7 +374,7 @@ def create_lr_scheduler(self, cfg, trainer, optimizer): Returns: object: The created learning rate scheduler, or None if none was found. """ - for plugin in self.plugins: + for _, plugin in self.plugins.items(): scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer) if scheduler is not None: return scheduler @@ -391,7 +392,7 @@ def add_callbacks_pre_trainer(self, cfg, model): List[callable]: A list of callback functions to be added to the TrainingArgs. """ callbacks = [] - for plugin in self.plugins: + for _, plugin in self.plugins.items(): callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model)) return callbacks @@ -407,7 +408,7 @@ def add_callbacks_post_trainer(self, cfg, trainer): List[callable]: A list of callback functions to be added to the TrainingArgs. """ callbacks = [] - for plugin in self.plugins: + for _, plugin in self.plugins.items(): callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer)) return callbacks @@ -422,5 +423,5 @@ def post_train_unload(self, cfg): Returns: None """ - for plugin in self.plugins: + for _, plugin in self.plugins.items(): plugin.post_train_unload(cfg)