-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* define `Chain` interface * implement `AbstractChain` * refactor `svm` chain developement * refactor `neural network` chain developement * refactor `neighbors` chain developement * refactor `naive bayes` chain developement * refactor `decision tree` chain developement * refactor `cross decomposition` chain developement * refactor `clustering` chain developement * refactor and update `util.py` * update * update * update docstring * update `test_pymilo.py` to make compatible with new changes * refactor `ensemble` chain developement * refactor `linear` chain developement * reorder imports * refactor spacings * `autopep8.sh` applied * refactor concrete chain implementations and remove non-necessary object level functions to be out of class functions * refactorings * apply docstring feedbacks
- Loading branch information
Showing
13 changed files
with
552 additions
and
1,452 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
# -*- coding: utf-8 -*- | ||
"""PyMilo Chain Module.""" | ||
|
||
from traceback import format_exc | ||
from abc import ABC, abstractmethod | ||
|
||
from ..utils.util import get_sklearn_type | ||
from ..transporters.transporter import Command | ||
from ..exceptions.serialize_exception import PymiloSerializationException, SerializationErrorTypes | ||
from ..exceptions.deserialize_exception import PymiloDeserializationException, DeserializationErrorTypes | ||
|
||
|
||
class Chain(ABC): | ||
""" | ||
Chain Interface. | ||
Each Chain serializes/deserializes the given model. | ||
""" | ||
|
||
@abstractmethod | ||
def is_supported(self, model): | ||
""" | ||
Check if the given model is a sklearn's ML model supported by this chain. | ||
:param model: a string name of an ML model or a sklearn object of it | ||
:type model: any object | ||
:return: check result as bool | ||
""" | ||
|
||
@abstractmethod | ||
def transport(self, request, command, is_inner_model=False): | ||
""" | ||
Return the transported (serialized or deserialized) model. | ||
:param request: given ML model to be transported | ||
:type request: any object | ||
:param command: command to specify whether the request should be serialized or deserialized | ||
:type command: transporter.Command | ||
:param is_inner_model: determines whether it is an inner model of a super ML model | ||
:type is_inner_model: boolean | ||
:return: the transported request as a json string or sklearn ML model | ||
""" | ||
|
||
@abstractmethod | ||
def serialize(self, model): | ||
""" | ||
Return the serialized json string of the given model. | ||
:param model: given ML model to be get serialized | ||
:type model: sklearn ML model | ||
:return: the serialized json string of the given ML model | ||
""" | ||
|
||
@abstractmethod | ||
def deserialize(self, serialized_model, is_inner_model=False): | ||
""" | ||
Return the associated sklearn ML model of the given previously serialized ML model. | ||
:param serialized_model: given json string of a ML model to get deserialized to associated sklearn ML model | ||
:type serialized_model: obj | ||
:param is_inner_model: determines whether it is an inner ML model of a super ML model | ||
:type is_inner_model: boolean | ||
:return: associated sklearn ML model | ||
""" | ||
|
||
@abstractmethod | ||
def validate(self, model, command): | ||
""" | ||
Check if the provided inputs are valid in relation to each other. | ||
:param model: a sklearn ML model or a json string of it, serialized through the pymilo export | ||
:type model: obj | ||
:param command: command to specify whether the request should be serialized or deserialized | ||
:type command: transporter.Command | ||
:return: None | ||
""" | ||
|
||
|
||
class AbstractChain(Chain): | ||
"""Abstract Chain with the general implementation of the Chain interface.""" | ||
|
||
def __init__(self, transporters, supported_models): | ||
""" | ||
Initialize the AbstractChain instance. | ||
:param transporters: worker transporters dedicated to this chain | ||
:type transporters: transporter.AbstractTransporter[] | ||
:param supported_models: supported sklearn ML models belong to this chain | ||
:type supported_models: dict | ||
:return: an instance of the AbstractChain class | ||
""" | ||
self._transporters = transporters | ||
self._supported_models = supported_models | ||
|
||
def is_supported(self, model): | ||
""" | ||
Check if the given model is a sklearn's ML model supported by this chain. | ||
:param model: a string name of an ML model or a sklearn object of it | ||
:type model: any object | ||
:return: check result as bool | ||
""" | ||
model_name = model if isinstance(model, str) else get_sklearn_type(model) | ||
return model_name in self._supported_models | ||
|
||
def transport(self, request, command, is_inner_model=False): | ||
""" | ||
Return the transported (serialized or deserialized) model. | ||
:param request: given ML model to be transported | ||
:type request: any object | ||
:param command: command to specify whether the request should be serialized or deserialized | ||
:type command: transporter.Command | ||
:param is_inner_model: determines whether it is an inner model of a super ML model | ||
:type is_inner_model: boolean | ||
:return: the transported request as a json string or sklearn ML model | ||
""" | ||
if not is_inner_model: | ||
self.validate(request, command) | ||
|
||
if command == Command.SERIALIZE: | ||
try: | ||
return self.serialize(request) | ||
except Exception as e: | ||
raise PymiloSerializationException( | ||
{ | ||
'error_type': SerializationErrorTypes.VALID_MODEL_INVALID_INTERNAL_STRUCTURE, | ||
'error': { | ||
'Exception': repr(e), | ||
'Traceback': format_exc(), | ||
}, | ||
'object': request, | ||
}) | ||
|
||
elif command == Command.DESERIALIZE: | ||
try: | ||
return self.deserialize(request, is_inner_model) | ||
except Exception as e: | ||
raise PymiloDeserializationException( | ||
{ | ||
'error_type': DeserializationErrorTypes.VALID_MODEL_INVALID_INTERNAL_STRUCTURE, | ||
'error': { | ||
'Exception': repr(e), | ||
'Traceback': format_exc()}, | ||
'object': request | ||
}) | ||
|
||
def serialize(self, model): | ||
""" | ||
Return the serialized json string of the given model. | ||
:param model: given ML model to be get serialized | ||
:type model: sklearn ML model | ||
:return: the serialized json string of the given ML model | ||
""" | ||
for transporter in self._transporters: | ||
self._transporters[transporter].transport(model, Command.SERIALIZE) | ||
return model.__dict__ | ||
|
||
def deserialize(self, serialized_model, is_inner_model=False): | ||
""" | ||
Return the associated sklearn ML model of the given previously serialized ML model. | ||
:param serialized_model: given json string of a ML model to get deserialized to associated sklearn ML model | ||
:type serialized_model: obj | ||
:param is_inner_model: determines whether it is an inner ML model of a super ML model | ||
:type is_inner_model: boolean | ||
:return: associated sklearn ML model | ||
""" | ||
raw_model = None | ||
data = None | ||
if is_inner_model: | ||
raw_model = self._supported_models[serialized_model["type"]]() | ||
data = serialized_model["data"] | ||
else: | ||
raw_model = self._supported_models[serialized_model.type]() | ||
data = serialized_model.data | ||
for transporter in self._transporters: | ||
self._transporters[transporter].transport( | ||
serialized_model, Command.DESERIALIZE, is_inner_model) | ||
for item in data: | ||
setattr(raw_model, item, data[item]) | ||
return raw_model | ||
|
||
def validate(self, model, command): | ||
""" | ||
Check if the provided inputs are valid in relation to each other. | ||
:param model: a sklearn ML model or a json string of it, serialized through the pymilo export | ||
:type model: obj | ||
:param command: command to specify whether the request should be serialized or deserialized | ||
:type command: transporter.Command | ||
:return: None | ||
""" | ||
if command == Command.SERIALIZE: | ||
if self.is_supported(model): | ||
return | ||
else: | ||
raise PymiloSerializationException( | ||
{ | ||
'error_type': SerializationErrorTypes.INVALID_MODEL, | ||
'object': model | ||
} | ||
) | ||
elif command == Command.DESERIALIZE: | ||
if self.is_supported(model.type): | ||
return | ||
else: | ||
raise PymiloDeserializationException( | ||
{ | ||
'error_type': DeserializationErrorTypes.INVALID_MODEL, | ||
'object': model | ||
} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,158 +1,24 @@ | ||
# -*- coding: utf-8 -*- | ||
"""PyMilo chain for clustering models.""" | ||
from ..transporters.transporter import Command | ||
"""PyMilo chain for Clustering models.""" | ||
|
||
from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter | ||
from ..transporters.function_transporter import FunctionTransporter | ||
from ..chains.chain import AbstractChain | ||
from ..pymilo_param import SKLEARN_CLUSTERING_TABLE, NOT_SUPPORTED | ||
from ..transporters.cfnode_transporter import CFNodeTransporter | ||
from ..transporters.function_transporter import FunctionTransporter | ||
from ..transporters.general_data_structure_transporter import GeneralDataStructureTransporter | ||
from ..transporters.preprocessing_transporter import PreprocessingTransporter | ||
|
||
from ..utils.util import get_sklearn_type | ||
|
||
from ..pymilo_param import SKLEARN_CLUSTERING_TABLE, NOT_SUPPORTED | ||
from ..exceptions.serialize_exception import PymiloSerializationException, SerializationErrorTypes | ||
from ..exceptions.deserialize_exception import PymiloDeserializationException, DeserializationErrorTypes | ||
from traceback import format_exc | ||
|
||
bisecting_kmeans_support = SKLEARN_CLUSTERING_TABLE["BisectingKMeans"] != NOT_SUPPORTED | ||
CLUSTERING_CHAIN = { | ||
"PreprocessingTransporter": PreprocessingTransporter(), | ||
"GeneralDataStructureTransporter": GeneralDataStructureTransporter(), | ||
"FunctionTransporter": FunctionTransporter(), | ||
"CFNodeTransporter": CFNodeTransporter(), | ||
} | ||
|
||
if bisecting_kmeans_support: | ||
from ..transporters.randomstate_transporter import RandomStateTransporter | ||
if SKLEARN_CLUSTERING_TABLE["BisectingKMeans"] != NOT_SUPPORTED: | ||
from ..transporters.bisecting_tree_transporter import BisectingTreeTransporter | ||
from ..transporters.randomstate_transporter import RandomStateTransporter | ||
CLUSTERING_CHAIN["RandomStateTransporter"] = RandomStateTransporter() | ||
CLUSTERING_CHAIN["BisectingTreeTransporter"] = BisectingTreeTransporter() | ||
|
||
|
||
def is_clusterer(model): | ||
""" | ||
Check if the input model is a sklearn's clustering model. | ||
:param model: is a string name of a clusterer or a sklearn object of it | ||
:type model: any object | ||
:return: check result as bool | ||
""" | ||
if isinstance(model, str): | ||
return model in SKLEARN_CLUSTERING_TABLE | ||
else: | ||
return get_sklearn_type(model) in SKLEARN_CLUSTERING_TABLE.keys() | ||
|
||
|
||
def transport_clusterer(request, command, is_inner_model=False): | ||
""" | ||
Return the transported (Serialized or Deserialized) model. | ||
:param request: given clusterer to be transported | ||
:type request: any object | ||
:param command: command to specify whether the request should be serialized or deserialized | ||
:type command: transporter.Command | ||
:param is_inner_model: determines whether it is an inner model of a super ml model | ||
:type is_inner_model: boolean | ||
:return: the transported request as a json string or sklearn clustering model | ||
""" | ||
if not is_inner_model: | ||
_validate_input(request, command) | ||
|
||
if command == Command.SERIALIZE: | ||
try: | ||
return serialize_clusterer(request) | ||
except Exception as e: | ||
raise PymiloSerializationException( | ||
{ | ||
'error_type': SerializationErrorTypes.VALID_MODEL_INVALID_INTERNAL_STRUCTURE, | ||
'error': { | ||
'Exception': repr(e), | ||
'Traceback': format_exc(), | ||
}, | ||
'object': request, | ||
}) | ||
|
||
elif command == Command.DESERIALIZE: | ||
try: | ||
return deserialize_clusterer(request, is_inner_model) | ||
except Exception as e: | ||
raise PymiloDeserializationException( | ||
{ | ||
'error_type': SerializationErrorTypes.VALID_MODEL_INVALID_INTERNAL_STRUCTURE, | ||
'error': { | ||
'Exception': repr(e), | ||
'Traceback': format_exc()}, | ||
'object': request}) | ||
|
||
|
||
def serialize_clusterer(clusterer_object): | ||
""" | ||
Return the serialized json string of the given clustering model. | ||
:param clusterer_object: given model to be get serialized | ||
:type clusterer_object: any sklearn clustering model | ||
:return: the serialized json string of the given clusterer | ||
""" | ||
for transporter in CLUSTERING_CHAIN: | ||
CLUSTERING_CHAIN[transporter].transport( | ||
clusterer_object, Command.SERIALIZE) | ||
return clusterer_object.__dict__ | ||
|
||
|
||
def deserialize_clusterer(clusterer, is_inner_model=False): | ||
""" | ||
Return the associated sklearn clustering model of the given clusterer. | ||
:param clusterer: given json string of a clustering model to get deserialized to associated sklearn clustering model | ||
:type clusterer: obj | ||
:param is_inner_model: determines whether it is an inner model of a super ml model | ||
:type is_inner_model: boolean | ||
:return: associated sklearn clustering model | ||
""" | ||
raw_model = None | ||
data = None | ||
if is_inner_model: | ||
raw_model = SKLEARN_CLUSTERING_TABLE[clusterer["type"]]() | ||
data = clusterer["data"] | ||
else: | ||
raw_model = SKLEARN_CLUSTERING_TABLE[clusterer.type]() | ||
data = clusterer.data | ||
|
||
for transporter in CLUSTERING_CHAIN: | ||
CLUSTERING_CHAIN[transporter].transport( | ||
clusterer, Command.DESERIALIZE, is_inner_model) | ||
for item in data: | ||
setattr(raw_model, item, data[item]) | ||
return raw_model | ||
|
||
|
||
def _validate_input(model, command): | ||
""" | ||
Check if the provided inputs are valid in relation to each other. | ||
:param model: a sklearn clusterer model or a json string of it, serialized through the pymilo export. | ||
:type model: obj | ||
:param command: command to specify whether the request should be serialized or deserialized | ||
:type command: transporter.Command | ||
:return: None | ||
""" | ||
if command == Command.SERIALIZE: | ||
if is_clusterer(model): | ||
return | ||
else: | ||
raise PymiloSerializationException( | ||
{ | ||
'error_type': SerializationErrorTypes.INVALID_MODEL, | ||
'object': model | ||
} | ||
) | ||
elif command == Command.DESERIALIZE: | ||
if is_clusterer(model.type): | ||
return | ||
else: | ||
raise PymiloDeserializationException( | ||
{ | ||
'error_type': DeserializationErrorTypes.INVALID_MODEL, | ||
'object': model | ||
} | ||
) | ||
clustering_chain = AbstractChain(CLUSTERING_CHAIN, SKLEARN_CLUSTERING_TABLE) |
Oops, something went wrong.