Skip to content

Commit

Permalink
Use ordered dict to only have one instance of plugin registered
Browse files Browse the repository at this point in the history
  • Loading branch information
chiragjn committed Oct 30, 2024
1 parent af48625 commit 930e1d3
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions src/axolotl/integrations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
"""
import collections
import importlib
import logging
from typing import List
from typing import OrderedDict


class BasePlugin:
Expand Down Expand Up @@ -233,7 +234,7 @@ class PluginManager:
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
"""

plugins: List[BasePlugin] = []
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()

_instance = None

Expand All @@ -243,7 +244,7 @@ def __new__(cls):
"""
if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins: List[BasePlugin] = []
cls._instance.plugins = collections.OrderedDict()
return cls._instance

@staticmethod
Expand Down Expand Up @@ -271,7 +272,7 @@ def register(self, plugin_name: str):
"""
try:
plugin = load_plugin(plugin_name)
self.plugins.append(plugin)
self.plugins[plugin_name] = plugin
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")

Expand All @@ -283,7 +284,7 @@ def get_input_args(self):
list[str]: A list of Pydantic classes for all registered plugins' input arguments.
"""
input_args = []
for plugin in self.plugins:
for _, plugin in self.plugins.items():
input_args_from_plugin = plugin.get_input_args()
if input_args_from_plugin is not None:
input_args.append(input_args_from_plugin)
Expand All @@ -299,7 +300,7 @@ def pre_model_load(self, cfg):
Returns:
None
"""
for plugin in self.plugins:
for _, plugin in self.plugins.items():
plugin.pre_model_load(cfg)

def post_model_load(self, cfg, model):
Expand All @@ -313,7 +314,7 @@ def post_model_load(self, cfg, model):
Returns:
None
"""
for plugin in self.plugins:
for _, plugin in self.plugins.items():
plugin.post_model_load(cfg, model)

def pre_lora_load(self, cfg, model):
Expand All @@ -327,7 +328,7 @@ def pre_lora_load(self, cfg, model):
Returns:
None
"""
for plugin in self.plugins:
for _, plugin in self.plugins.items():
plugin.pre_lora_load(cfg, model)

def post_lora_load(self, cfg, model):
Expand All @@ -341,7 +342,7 @@ def post_lora_load(self, cfg, model):
Returns:
None
"""
for plugin in self.plugins:
for _, plugin in self.plugins.items():
plugin.post_lora_load(cfg, model)

def create_optimizer(self, cfg, trainer):
Expand All @@ -355,7 +356,7 @@ def create_optimizer(self, cfg, trainer):
Returns:
object: The created optimizer, or None if none was found.
"""
for plugin in self.plugins:
for _, plugin in self.plugins.items():
optimizer = plugin.create_optimizer(cfg, trainer)
if optimizer is not None:
return optimizer
Expand All @@ -373,7 +374,7 @@ def create_lr_scheduler(self, cfg, trainer, optimizer):
Returns:
object: The created learning rate scheduler, or None if none was found.
"""
for plugin in self.plugins:
for _, plugin in self.plugins.items():
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
if scheduler is not None:
return scheduler
Expand All @@ -391,7 +392,7 @@ def add_callbacks_pre_trainer(self, cfg, model):
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins:
for _, plugin in self.plugins.items():
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
return callbacks

Expand All @@ -407,7 +408,7 @@ def add_callbacks_post_trainer(self, cfg, trainer):
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins:
for _, plugin in self.plugins.items():
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
return callbacks

Expand All @@ -422,5 +423,5 @@ def post_train_unload(self, cfg):
Returns:
None
"""
for plugin in self.plugins:
for _, plugin in self.plugins.items():
plugin.post_train_unload(cfg)

0 comments on commit 930e1d3

Please sign in to comment.