Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring #180

Merged
merged 12 commits into from
Jan 19, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- `OVERVIEW` in `pymilo_param.py`
- `get_sklearn_class` in `utils.util.py`
### Changed
- `ML Streaming` testcases modified to use PyMilo CLI
- `to_pymilo_issue` function in `PymiloException`
- `valid_url_valid_file` testcase added in `test_exceptions.py`
- `valid_url_valid_file` function in `import_exceptions.py`
Expand Down
21 changes: 12 additions & 9 deletions pymilo/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
import re
import argparse
from art import tprint
from .pymilo_param import PYMILO_VERSION, URL_REGEX
from .pymilo_param import (
PYMILO_VERSION,
URL_REGEX,
CLI_MORE_INFO,
CLI_UNKNOWN_MODEL,
CLI_ML_STREAMING_NOT_INSTALLED,
)
from .pymilo_func import print_supported_ml_models, pymilo_help
from .pymilo_obj import Import
from .utils.util import get_sklearn_class
Expand Down Expand Up @@ -70,10 +76,8 @@ def main():
print(PYMILO_VERSION)
return
if not ml_streaming_support:
print("Error: ML Streaming is not installed.")
print("To install ML Streaming, run the following command:")
print("pip install pymilo[streaming]")
print("For more information, visit the PyMilo README at https://github.com/openscilab/pymilo")
print(CLI_ML_STREAMING_NOT_INSTALLED)
print(CLI_MORE_INFO)
tprint("PyMilo")
tprint("V:" + PYMILO_VERSION)
pymilo_help()
Expand All @@ -88,19 +92,17 @@ def main():
path = args.load
run_ps = True
_model = Import(url=path) if re.match(URL_REGEX, path) else Import(file_adr=path)
_model = _model.to_model()
elif args.init:
model_name = args.init
model_class = get_sklearn_class(model_name)
if model_class is None:
print(
"The given ML model name is neither valid nor supported, use the list below: \n{print_supported_ml_models}")
print_supported_ml_models()
print(f"{CLI_UNKNOWN_MODEL}\n{print_supported_ml_models()}")
return
run_ps = True
_model = model_class()
elif args.bare:
run_ps = True
_model = model_class()
if not run_ps:
tprint("PyMilo")
tprint("V:" + PYMILO_VERSION)
Expand All @@ -114,5 +116,6 @@ def main():
communication_protocol=_communication_protocol,
).communicator.run()


if __name__ == '__main__':
main()
6 changes: 6 additions & 0 deletions pymilo/chains/ensemble_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,10 @@ def deserialize(self, ensemble, is_inner_model=False):
setattr(raw_model, item, data[item])
return raw_model


ensemble_chain = EnsembleModelChain(ENSEMBLE_CHAIN, SKLEARN_ENSEMBLE_TABLE)


def get_transporter(model):
"""
Get associated transporter for the given ML model.
Expand All @@ -188,6 +190,7 @@ def get_transporter(model):
else:
return get_concrete_transporter(model)


def serialize_possible_ml_model(possible_ml_model):
"""
Check whether the given object is a ML model and if it is, serialize it.
Expand All @@ -209,6 +212,7 @@ def serialize_possible_ml_model(possible_ml_model):
else:
return False, possible_ml_model


def deserialize_possible_ml_model(possible_serialized_ml_model):
"""
Check whether the given object is previously serialized ML model and if it is, deserialize it back to the associated ML model.
Expand All @@ -226,6 +230,7 @@ def deserialize_possible_ml_model(possible_serialized_ml_model):
else:
return False, possible_serialized_ml_model


def serialize_models_in_ndarray(ndarray_instance):
"""
Serialize the ml models inside the given ndarray.
Expand Down Expand Up @@ -268,6 +273,7 @@ def serialize_models_in_ndarray(ndarray_instance):
'pymiloed-data-structure': 'numpy.ndarray'
}


def deserialize_models_in_ndarray(serialized_ndarray):
"""
Deserializes possible ML models within the given ndarray instance.
Expand Down
2 changes: 2 additions & 0 deletions pymilo/chains/linear_model_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,10 @@ def deserialize(self, linear_model, is_inner_model=False):
setattr(raw_model, item, data[item])
return raw_model


linear_chain = LinearModelChain(LINEAR_MODEL_CHAIN, SKLEARN_LINEAR_MODEL_TABLE)


def is_deserialized_linear_model(content):
"""
Check if the given content is a previously serialized model by Pymilo's Export or not.
Expand Down
6 changes: 6 additions & 0 deletions pymilo/pymilo_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@
INVALID_DOWNLOADED_MODEL = "The downloaded content is not a valid JSON file."
BATCH_IMPORT_INVALID_DIRECTORY = "The given directory does not exist."

CLI_ML_STREAMING_NOT_INSTALLED = """ML Streaming is not installed.
To install ML Streaming, run the following command:
pip install pymilo[streaming]"""
CLI_MORE_INFO = "For more information, visit the PyMilo README at https://github.com/openscilab/pymilo"
CLI_UNKNOWN_MODEL = "The provided ML model name is either invalid or unsupported."

SKLEARN_LINEAR_MODEL_TABLE = {
"DummyRegressor": dummy.DummyRegressor,
"DummyClassifier": dummy.DummyClassifier,
Expand Down
1 change: 0 additions & 1 deletion pymilo/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,5 +179,4 @@ def get_sklearn_class(model_name):
for _, category_models in SKLEARN_SUPPORTED_CATEGORIES.items():
if model_name in category_models:
return category_models[model_name]
# todo raise exception
return None
2 changes: 1 addition & 1 deletion tests/test_ml_streaming/run_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ def main():
communicator.run()

if __name__ == '__main__':
main()
main()
36 changes: 29 additions & 7 deletions tests/test_ml_streaming/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@
params=["NULL", "GZIP", "ZLIB", "LZMA", "BZ2"])
def prepare_bare_server(request):
compression_method = request.param
# Using PyMilo direct CLI
# server_proc = subprocess.Popen(
# [
# executable,
# "-m", "pymilo",
# "--compression", compression_method,
# "--protocol", "REST",
# "--port", "8000",
# "--bare",
# ],
# )
path = os.path.join(
os.getcwd(),
"tests",
Expand All @@ -28,7 +39,7 @@ def prepare_bare_server(request):
],
)
time.sleep(10)
yield (server_proc, compression_method, "REST")
yield (compression_method, "REST")
server_proc.terminate()


Expand All @@ -38,7 +49,18 @@ def prepare_bare_server(request):
def prepare_ml_server(request):
communication_protocol = request.param
compression_method = "ZLIB"
print(communication_protocol)
# Using PyMilo direct CLI
# server_proc = subprocess.Popen(
# [
# executable,
# "-m", "pymilo",
# "--compression", compression_method,
# "--protocol", communication_protocol,
# "--port", "9000",
# "--load", os.path.join(os.getcwd(), "tests", "test_exceptions", "valid_jsons", "linear_regression.json")
# # "--load", "https://raw.githubusercontent.com/openscilab/pymilo/main/tests/test_exceptions/valid_jsons/linear_regression.json",
# ],
# )
path = os.path.join(
os.getcwd(),
"tests",
Expand All @@ -54,21 +76,21 @@ def prepare_ml_server(request):
"--init",
],
)
time.sleep(5)
yield (server_proc, compression_method, communication_protocol)
time.sleep(10)
yield (compression_method, communication_protocol)
server_proc.terminate()


def test1(prepare_bare_server):
_, compression_method, communication_protocol = prepare_bare_server
compression_method, communication_protocol = prepare_bare_server
assert scenario1(compression_method, communication_protocol) == 0


def test2(prepare_bare_server):
_, compression_method, communication_protocol = prepare_bare_server
compression_method, communication_protocol = prepare_bare_server
assert scenario2(compression_method, communication_protocol) == 0


def test3(prepare_ml_server):
_, compression_method, communication_protocol = prepare_ml_server
compression_method, communication_protocol = prepare_ml_server
assert scenario3(compression_method, communication_protocol) == 0