From d782e869e130e9067d5152047da872a50f399714 Mon Sep 17 00:00:00 2001 From: Wilton Wu Date: Tue, 7 May 2019 18:38:29 -0400 Subject: [PATCH] change: add test for network isolation mode training (#192) --- test/functional/test_training_framework.py | 36 ++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/test/functional/test_training_framework.py b/test/functional/test_training_framework.py index bb2ab17..a91aacc 100644 --- a/test/functional/test_training_framework.py +++ b/test/functional/test_training_framework.py @@ -225,6 +225,42 @@ def test_training_framework(user_script, capture_error): assert model.optimizer == 'Adam' +@pytest.mark.parametrize('user_script, capture_error', + [[USER_SCRIPT_WITH_SAVE, False], [USER_SCRIPT_WITH_SAVE, True]]) +def test_training_framework_network_isolation(user_script, capture_error): + with pytest.raises(ImportError): + importlib.import_module(modules.DEFAULT_MODULE_NAME) + + channel_train = test.Channel.create(name='training') + channel_code = test.Channel.create(name='code') + + features = [1, 2, 3, 4] + labels = [0, 1, 0, 1] + np.savez(os.path.join(channel_train.path, 'training_data'), features=features, labels=labels) + + with open(os.path.join(channel_code.path, 'user_script.py'), 'w') as f: + f.write(user_script) + f.close() + + module = test.UserModule(test.File(name='user_script.py', data="")) # dummy module for hyperparameters + + submit_dir = env.input_dir + '/data/code' + hyperparameters = dict(training_data_file='training_data.npz', sagemaker_program='user_script.py', + sagemaker_submit_directory=submit_dir, epochs=10, batch_size=64, optimizer='Adam') + + test.prepare(user_module=module, hyperparameters=hyperparameters, channels=[channel_train, channel_code], + local=True) + + assert execute_an_wrap_exit(framework_training_fn) == trainer.SUCCESS_CODE + + model_path = os.path.join(env.model_dir, 'saved_model') + model = fake_ml_framework.Model.load(model_path) + + assert model.epochs == 10 + assert model.batch_size == 64 + assert model.optimizer == 'Adam' + + @pytest.mark.parametrize('user_script, sagemaker_program', [ [USER_MODE_SCRIPT, 'user_script.py'], [BASH_SCRIPT, 'bash_script']