Skip to content

Commit

Permalink
Add multi-thread unittest for mlflow handler (#5755)
Browse files Browse the repository at this point in the history
Signed-off-by: binliu <[email protected]>
### Description

Add a multi-thread unit test for mlflow handler.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: binliu <[email protected]>
  • Loading branch information
binliunls authored Dec 15, 2022
1 parent 35db359 commit e50fa88
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions tests/test_handler_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

import glob
import os
import shutil
import tempfile
import unittest
from concurrent.futures import ThreadPoolExecutor

import numpy as np
from ignite.engine import Engine, Events
Expand All @@ -21,7 +23,38 @@
from monai.utils import path_to_uri


def dummy_train(tracking_folder):
tempdir = tempfile.mkdtemp()

# set up engine
def _train_func(engine, batch):
return [batch + 1.0]

engine = Engine(_train_func)

# set up testing handler
test_path = os.path.join(tempdir, tracking_folder)
handler = MLFlowHandler(
iteration_log=False,
epoch_log=True,
tracking_uri=path_to_uri(test_path),
state_attributes=["test"],
close_on_complete=True,
)
handler.attach(engine)
engine.run(range(3), max_epochs=2)
return test_path


class TestHandlerMLFlow(unittest.TestCase):
def setUp(self):
self.tmpdir_list = []

def tearDown(self):
for tmpdir in self.tmpdir_list:
if tmpdir and os.path.exists(tmpdir):
shutil.rmtree(tmpdir)

def test_metrics_track(self):
experiment_param = {"backbone": "efficientnet_b0"}
with tempfile.TemporaryDirectory() as tempdir:
Expand Down Expand Up @@ -61,6 +94,18 @@ def _update_metric(engine):
# check logging output
self.assertTrue(len(glob.glob(test_path)) > 0)

def test_multi_thread(self):
test_uri_list = ["monai_mlflow_test1", "monai_mlflow_test2"]
with ThreadPoolExecutor(2, "Training") as executor:
futures = {}
for t in test_uri_list:
futures[t] = executor.submit(dummy_train, t)

for _, future in futures.items():
res = future.result()
self.tmpdir_list.append(res)
self.assertTrue(len(glob.glob(res)) > 0)


if __name__ == "__main__":
unittest.main()

0 comments on commit e50fa88

Please sign in to comment.