diff --git a/tests/integration_test/data/jobs/hello-pt-cse/app/custom/train.py b/tests/integration_test/data/jobs/hello-pt-cse/app/custom/train.py index c16b1cccbe..028b2a4bb3 100644 --- a/tests/integration_test/data/jobs/hello-pt-cse/app/custom/train.py +++ b/tests/integration_test/data/jobs/hello-pt-cse/app/custom/train.py @@ -27,8 +27,8 @@ from nvflare.app_common.app_constant import ModelName # (optional) set a fixed location so we don't need to download everytime -CIFAR10_ROOT = "~/data" -MODEL_SAVE_PATH_ROOT = "~/data" +CIFAR10_ROOT = "/tmp/nvflare/data" +MODEL_SAVE_PATH_ROOT = "/tmp/nvflare/data" # (optional) We change to use GPU to speed things up. # if you want to use CPU, change DEVICE="cpu" @@ -41,7 +41,6 @@ def define_parser(): parser.add_argument("--batch_size", type=int, default=4, nargs="?") parser.add_argument("--num_workers", type=int, default=1, nargs="?") parser.add_argument("--local_epochs", type=int, default=2, nargs="?") - parser.add_argument("--model_path", type=str, default=f"{MODEL_SAVE_PATH_ROOT}/cifar_net.pth", nargs="?") return parser.parse_args() @@ -53,7 +52,6 @@ def main(): batch_size = args.batch_size num_workers = args.num_workers local_epochs = args.local_epochs - model_path = args.model_path transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) trainset = torchvision.datasets.CIFAR10(root=dataset_path, train=True, download=True, transform=transform) diff --git a/tests/integration_test/data/test_configs/standalone_job/pt_job.yml b/tests/integration_test/data/test_configs/standalone_job/pt_job.yml index 86964c2c1d..1147d50f2d 100644 --- a/tests/integration_test/data/test_configs/standalone_job/pt_job.yml +++ b/tests/integration_test/data/test_configs/standalone_job/pt_job.yml @@ -45,9 +45,9 @@ tests: "data": { "run_finished": True } validators: - path: tests.integration_test.src.validators.PTModelValidator - - path: tests.integration_test.src.validators.CrossValResultValidator - args: { server_model_names: [ "server" ] } setup: - - python -c "from torchvision.datasets import CIFAR10; CIFAR10(root='~/data/', download=True)" + - mkdir -p /tmp/nvflare/data/site-1 + - mkdir -p /tmp/nvflare/data/site-2 + - python -c "from torchvision.datasets import CIFAR10; CIFAR10(root='/tmp/nvflare/data', download=True)" teardown: - - rm -rf ~/data + - rm -rf /tmp/nvflare/data