Skip to content

Commit

Permalink
Refactor/chain (#172)
Browse files Browse the repository at this point in the history
* define `Chain` interface

* implement `AbstractChain`

* refactor `svm` chain development

* refactor `neural network` chain development

* refactor `neighbors` chain development

* refactor `naive bayes` chain development

* refactor `decision tree` chain development

* refactor `cross decomposition` chain development

* refactor `clustering` chain development

* refactor and update `util.py`

* update

* update

* update docstring

* update `test_pymilo.py` to make compatible with new changes

* refactor `ensemble` chain development

* refactor `linear` chain development

* reorder imports

* refactor spacings

* `autopep8.sh` applied

* refactor concrete chain implementations and move unnecessary object-level functions out of their classes into module-level functions

* refactorings

* apply docstring feedback
  • Loading branch information
AHReccese authored Jan 10, 2025
1 parent 96d46ed commit d3b9857
Show file tree
Hide file tree
Showing 13 changed files with 552 additions and 1,452 deletions.
214 changes: 214 additions & 0 deletions pymilo/chains/chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
# -*- coding: utf-8 -*-
"""PyMilo Chain Module."""

from traceback import format_exc
from abc import ABC, abstractmethod

from ..utils.util import get_sklearn_type
from ..transporters.transporter import Command
from ..exceptions.serialize_exception import PymiloSerializationException, SerializationErrorTypes
from ..exceptions.deserialize_exception import PymiloDeserializationException, DeserializationErrorTypes


class Chain(ABC):
    """
    Chain Interface.

    A Chain is responsible for serializing and deserializing the ML models it supports.
    """

    @abstractmethod
    def is_supported(self, model):
        """
        Tell whether the given model is a sklearn ML model handled by this chain.

        :param model: either the string name of an ML model or a sklearn object of it
        :type model: any object
        :return: check result as bool
        """

    @abstractmethod
    def transport(self, request, command, is_inner_model=False):
        """
        Serialize or deserialize the given request, depending on the command.

        :param request: given ML model to be transported
        :type request: any object
        :param command: command selecting between serialization and deserialization
        :type command: transporter.Command
        :param is_inner_model: whether the request is an inner model of a super ML model
        :type is_inner_model: boolean
        :return: the transported request as a json string or sklearn ML model
        """

    @abstractmethod
    def serialize(self, model):
        """
        Serialize the given model.

        :param model: given ML model to get serialized
        :type model: sklearn ML model
        :return: the serialized json string of the given ML model
        """

    @abstractmethod
    def deserialize(self, serialized_model, is_inner_model=False):
        """
        Rebuild the sklearn ML model associated with a previously serialized ML model.

        :param serialized_model: json string of a ML model to get deserialized to its sklearn ML model
        :type serialized_model: obj
        :param is_inner_model: whether it is an inner ML model of a super ML model
        :type is_inner_model: boolean
        :return: associated sklearn ML model
        """

    @abstractmethod
    def validate(self, model, command):
        """
        Check that the given model and command are valid in relation to each other.

        :param model: a sklearn ML model or a json string of it, serialized through the pymilo export
        :type model: obj
        :param command: command selecting between serialization and deserialization
        :type command: transporter.Command
        :return: None
        """


class AbstractChain(Chain):
    """Abstract Chain with the general implementation of the Chain interface."""

    def __init__(self, transporters, supported_models):
        """
        Initialize the AbstractChain instance.

        :param transporters: worker transporters dedicated to this chain, keyed by transporter name
        :type transporters: dict
        :param supported_models: supported sklearn ML models belonging to this chain, keyed by model name
        :type supported_models: dict
        :return: an instance of the AbstractChain class
        """
        self._transporters = transporters
        self._supported_models = supported_models

    def is_supported(self, model):
        """
        Check if the given model is a sklearn's ML model supported by this chain.

        :param model: a string name of an ML model or a sklearn object of it
        :type model: any object
        :return: check result as bool
        """
        model_name = model if isinstance(model, str) else get_sklearn_type(model)
        return model_name in self._supported_models

    def transport(self, request, command, is_inner_model=False):
        """
        Return the transported (serialized or deserialized) model.

        Top-level requests are validated first; inner models skip validation.
        Any exception raised while serializing/deserializing is wrapped in the
        matching Pymilo exception and chained to the original cause.
        A command other than SERIALIZE or DESERIALIZE yields None.

        :param request: given ML model to be transported
        :type request: any object
        :param command: command to specify whether the request should be serialized or deserialized
        :type command: transporter.Command
        :param is_inner_model: determines whether it is an inner model of a super ML model
        :type is_inner_model: boolean
        :return: the transported request as a json string or sklearn ML model
        """
        if not is_inner_model:
            self.validate(request, command)

        if command == Command.SERIALIZE:
            try:
                return self.serialize(request)
            except Exception as e:
                raise PymiloSerializationException(
                    {
                        'error_type': SerializationErrorTypes.VALID_MODEL_INVALID_INTERNAL_STRUCTURE,
                        'error': {
                            'Exception': repr(e),
                            'Traceback': format_exc(),
                        },
                        'object': request,
                    }) from e

        elif command == Command.DESERIALIZE:
            try:
                return self.deserialize(request, is_inner_model)
            except Exception as e:
                raise PymiloDeserializationException(
                    {
                        'error_type': DeserializationErrorTypes.VALID_MODEL_INVALID_INTERNAL_STRUCTURE,
                        'error': {
                            'Exception': repr(e),
                            'Traceback': format_exc(),
                        },
                        'object': request,
                    }) from e

    def serialize(self, model):
        """
        Return the serialized json string of the given model.

        Every transporter of this chain is applied to the model in turn, then
        the model's ``__dict__`` is returned.

        :param model: given ML model to be get serialized
        :type model: sklearn ML model
        :return: the serialized json string of the given ML model
        """
        for transporter in self._transporters.values():
            transporter.transport(model, Command.SERIALIZE)
        return model.__dict__

    def deserialize(self, serialized_model, is_inner_model=False):
        """
        Return the associated sklearn ML model of the given previously serialized ML model.

        :param serialized_model: given json string of a ML model to get deserialized to associated sklearn ML model
        :type serialized_model: obj
        :param is_inner_model: determines whether it is an inner ML model of a super ML model
        :type is_inner_model: boolean
        :return: associated sklearn ML model
        """
        # Inner models are accessed by key, top-level ones by attribute.
        if is_inner_model:
            raw_model = self._supported_models[serialized_model["type"]]()
            data = serialized_model["data"]
        else:
            raw_model = self._supported_models[serialized_model.type]()
            data = serialized_model.data

        for transporter in self._transporters.values():
            transporter.transport(serialized_model, Command.DESERIALIZE, is_inner_model)

        # Copy the fields from `data` onto the freshly constructed model.
        for name, value in data.items():
            setattr(raw_model, name, value)
        return raw_model

    def validate(self, model, command):
        """
        Check if the provided inputs are valid in relation to each other.

        Raises the matching Pymilo exception with an INVALID_MODEL error type
        when the model is not supported by this chain. Commands other than
        SERIALIZE and DESERIALIZE are not checked.

        :param model: a sklearn ML model or a json string of it, serialized through the pymilo export
        :type model: obj
        :param command: command to specify whether the request should be serialized or deserialized
        :type command: transporter.Command
        :return: None
        """
        if command == Command.SERIALIZE:
            if not self.is_supported(model):
                raise PymiloSerializationException(
                    {
                        'error_type': SerializationErrorTypes.INVALID_MODEL,
                        'object': model,
                    }
                )
        elif command == Command.DESERIALIZE:
            # A serialized model carries its model name in `.type`.
            if not self.is_supported(model.type):
                raise PymiloDeserializationException(
                    {
                        'error_type': DeserializationErrorTypes.INVALID_MODEL,
                        'object': model,
                    }
                )
150 changes: 8 additions & 142 deletions pymilo/chains/clustering_chain.py
Original file line number Diff line number Diff line change
@@ -1,158 +1,24 @@
# -*- coding: utf-8 -*-
"""PyMilo chain for clustering models."""
from ..transporters.transporter import Command
"""PyMilo chain for Clustering models."""

from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter
from ..transporters.function_transporter import FunctionTransporter
from ..chains.chain import AbstractChain
from ..pymilo_param import SKLEARN_CLUSTERING_TABLE, NOT_SUPPORTED
from ..transporters.cfnode_transporter import CFNodeTransporter
from ..transporters.function_transporter import FunctionTransporter
from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter
from ..transporters.preprocessing_transporter import PreprocessingTransporter

from ..utils.util import get_sklearn_type

from ..pymilo_param import SKLEARN_CLUSTERING_TABLE, NOT_SUPPORTED
from ..exceptions.serialize_exception import PymiloSerializationException, SerializationErrorTypes
from ..exceptions.deserialize_exception import PymiloDeserializationException, DeserializationErrorTypes
from traceback import format_exc

# True when the "BisectingKMeans" entry of SKLEARN_CLUSTERING_TABLE is not marked NOT_SUPPORTED.
bisecting_kmeans_support = SKLEARN_CLUSTERING_TABLE["BisectingKMeans"] != NOT_SUPPORTED
# Transporters used by the clustering chain, keyed by transporter class name.
CLUSTERING_CHAIN = {
"PreprocessingTransporter": PreprocessingTransporter(),
"GeneralDataStructureTransporter": GeneralDataStructureTransporter(),
"FunctionTransporter": FunctionTransporter(),
"CFNodeTransporter": CFNodeTransporter(),
}

# Register the BisectingKMeans-specific transporters only when that model is supported.
# NOTE(review): the lines below appear to interleave the removed and added sides of a diff
# (two `if` headers, a duplicated RandomStateTransporter import) and have lost their
# indentation - confirm against the repository before relying on them.
if bisecting_kmeans_support:
from ..transporters.randomstate_transporter import RandomStateTransporter
if SKLEARN_CLUSTERING_TABLE["BisectingKMeans"] != NOT_SUPPORTED:
from ..transporters.bisecting_tree_transporter import BisectingTreeTransporter
from ..transporters.randomstate_transporter import RandomStateTransporter
CLUSTERING_CHAIN["RandomStateTransporter"] = RandomStateTransporter()
CLUSTERING_CHAIN["BisectingTreeTransporter"] = BisectingTreeTransporter()


def is_clusterer(model):
    """
    Tell whether the input model is one of sklearn's clustering models.

    :param model: either the string name of a clusterer or a sklearn object of it
    :type model: any object
    :return: check result as bool
    """
    # A string is taken as the model name itself; anything else is resolved to its sklearn type name.
    model_name = model if isinstance(model, str) else get_sklearn_type(model)
    return model_name in SKLEARN_CLUSTERING_TABLE


def transport_clusterer(request, command, is_inner_model=False):
    """
    Return the transported (Serialized or Deserialized) model.

    Top-level requests are validated first; inner models skip validation.
    Any exception raised while serializing/deserializing is wrapped in the
    matching Pymilo exception and chained to the original cause.
    A command other than SERIALIZE or DESERIALIZE yields None.

    :param request: given clusterer to be transported
    :type request: any object
    :param command: command to specify whether the request should be serialized or deserialized
    :type command: transporter.Command
    :param is_inner_model: determines whether it is an inner model of a super ml model
    :type is_inner_model: boolean
    :return: the transported request as a json string or sklearn clustering model
    """
    if not is_inner_model:
        _validate_input(request, command)

    if command == Command.SERIALIZE:
        try:
            return serialize_clusterer(request)
        except Exception as e:
            raise PymiloSerializationException(
                {
                    'error_type': SerializationErrorTypes.VALID_MODEL_INVALID_INTERNAL_STRUCTURE,
                    'error': {
                        'Exception': repr(e),
                        'Traceback': format_exc(),
                    },
                    'object': request,
                }) from e

    elif command == Command.DESERIALIZE:
        try:
            return deserialize_clusterer(request, is_inner_model)
        except Exception as e:
            raise PymiloDeserializationException(
                {
                    # Deserialization failures use the deserialization error enum,
                    # not the serialization one.
                    'error_type': DeserializationErrorTypes.VALID_MODEL_INVALID_INTERNAL_STRUCTURE,
                    'error': {
                        'Exception': repr(e),
                        'Traceback': format_exc(),
                    },
                    'object': request,
                }) from e


def serialize_clusterer(clusterer_object):
    """
    Serialize the given clustering model.

    Each transporter registered in CLUSTERING_CHAIN is applied to the model in
    turn, after which the model's ``__dict__`` is returned.

    :param clusterer_object: given model to get serialized
    :type clusterer_object: any sklearn clustering model
    :return: the serialized json string of the given clusterer
    """
    for worker in CLUSTERING_CHAIN.values():
        worker.transport(clusterer_object, Command.SERIALIZE)
    return clusterer_object.__dict__


def deserialize_clusterer(clusterer, is_inner_model=False):
    """
    Rebuild the sklearn clustering model associated with the given serialized clusterer.

    :param clusterer: json string of a clustering model to get deserialized to its sklearn clustering model
    :type clusterer: obj
    :param is_inner_model: whether it is an inner model of a super ml model
    :type is_inner_model: boolean
    :return: associated sklearn clustering model
    """
    # Inner models are accessed by key, top-level ones by attribute.
    if is_inner_model:
        estimator = SKLEARN_CLUSTERING_TABLE[clusterer["type"]]()
        attributes = clusterer["data"]
    else:
        estimator = SKLEARN_CLUSTERING_TABLE[clusterer.type]()
        attributes = clusterer.data

    for worker in CLUSTERING_CHAIN.values():
        worker.transport(clusterer, Command.DESERIALIZE, is_inner_model)

    # Copy the fields from the data mapping onto the freshly constructed estimator.
    for attr_name, attr_value in attributes.items():
        setattr(estimator, attr_name, attr_value)
    return estimator


def _validate_input(model, command):
    """
    Check that the given model and command are valid in relation to each other.

    Raises the matching Pymilo exception with an INVALID_MODEL error type when
    the model is not a clusterer. Commands other than SERIALIZE and
    DESERIALIZE are not checked.

    :param model: a sklearn clusterer model or a json string of it, serialized through the pymilo export.
    :type model: obj
    :param command: command selecting between serialization and deserialization
    :type command: transporter.Command
    :return: None
    """
    if command == Command.SERIALIZE:
        if not is_clusterer(model):
            raise PymiloSerializationException(
                {
                    'error_type': SerializationErrorTypes.INVALID_MODEL,
                    'object': model
                }
            )
    elif command == Command.DESERIALIZE:
        # A serialized model carries its model name in `.type`.
        if not is_clusterer(model.type):
            raise PymiloDeserializationException(
                {
                    'error_type': DeserializationErrorTypes.INVALID_MODEL,
                    'object': model
                }
            )
# Module-level chain instance for clustering models, built from the clustering transporters and model table.
clustering_chain = AbstractChain(CLUSTERING_CHAIN, SKLEARN_CLUSTERING_TABLE)
Loading

0 comments on commit d3b9857

Please sign in to comment.