Skip to content
This repository has been archived by the owner on Aug 26, 2020. It is now read-only.

Commit

Permalink
change: add test for network isolation mode training (#192)
Browse files Browse the repository at this point in the history
  • Loading branch information
wiltonwu authored and mvsusp committed May 7, 2019
1 parent 52cb4b6 commit d782e86
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions test/functional/test_training_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,42 @@ def test_training_framework(user_script, capture_error):
assert model.optimizer == 'Adam'


@pytest.mark.parametrize('user_script, capture_error',
[[USER_SCRIPT_WITH_SAVE, False], [USER_SCRIPT_WITH_SAVE, True]])
def test_training_framework_network_isolation(user_script, capture_error):
with pytest.raises(ImportError):
importlib.import_module(modules.DEFAULT_MODULE_NAME)

channel_train = test.Channel.create(name='training')
channel_code = test.Channel.create(name='code')

features = [1, 2, 3, 4]
labels = [0, 1, 0, 1]
np.savez(os.path.join(channel_train.path, 'training_data'), features=features, labels=labels)

with open(os.path.join(channel_code.path, 'user_script.py'), 'w') as f:
f.write(user_script)
f.close()

module = test.UserModule(test.File(name='user_script.py', data="")) # dummy module for hyperparameters

submit_dir = env.input_dir + '/data/code'
hyperparameters = dict(training_data_file='training_data.npz', sagemaker_program='user_script.py',
sagemaker_submit_directory=submit_dir, epochs=10, batch_size=64, optimizer='Adam')

test.prepare(user_module=module, hyperparameters=hyperparameters, channels=[channel_train, channel_code],
local=True)

assert execute_an_wrap_exit(framework_training_fn) == trainer.SUCCESS_CODE

model_path = os.path.join(env.model_dir, 'saved_model')
model = fake_ml_framework.Model.load(model_path)

assert model.epochs == 10
assert model.batch_size == 64
assert model.optimizer == 'Adam'


@pytest.mark.parametrize('user_script, sagemaker_program', [
[USER_MODE_SCRIPT, 'user_script.py'],
[BASH_SCRIPT, 'bash_script']
Expand Down

0 comments on commit d782e86

Please sign in to comment.