Skip to content

Commit

Permalink
Fix ci path (NVIDIA#2927)
Browse files Browse the repository at this point in the history
* Fix path

* Fix path

* Remove invalid validator

---------

Co-authored-by: Sean Yang <[email protected]>
  • Loading branch information
YuanTingHsieh and SYangster authored Sep 6, 2024
1 parent 5d71c2c commit f5d2025
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from nvflare.app_common.app_constant import ModelName

# (optional) set a fixed location so we don't need to download everytime
CIFAR10_ROOT = "~/data"
MODEL_SAVE_PATH_ROOT = "~/data"
CIFAR10_ROOT = "/tmp/nvflare/data"
MODEL_SAVE_PATH_ROOT = "/tmp/nvflare/data"

# (optional) We change to use GPU to speed things up.
# if you want to use CPU, change DEVICE="cpu"
Expand All @@ -41,7 +41,6 @@ def define_parser():
parser.add_argument("--batch_size", type=int, default=4, nargs="?")
parser.add_argument("--num_workers", type=int, default=1, nargs="?")
parser.add_argument("--local_epochs", type=int, default=2, nargs="?")
parser.add_argument("--model_path", type=str, default=f"{MODEL_SAVE_PATH_ROOT}/cifar_net.pth", nargs="?")
return parser.parse_args()


Expand All @@ -53,7 +52,6 @@ def main():
batch_size = args.batch_size
num_workers = args.num_workers
local_epochs = args.local_epochs
model_path = args.model_path

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root=dataset_path, train=True, download=True, transform=transform)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ tests:
"data": { "run_finished": True }
validators:
- path: tests.integration_test.src.validators.PTModelValidator
- path: tests.integration_test.src.validators.CrossValResultValidator
args: { server_model_names: [ "server" ] }
setup:
- python -c "from torchvision.datasets import CIFAR10; CIFAR10(root='~/data/', download=True)"
- mkdir -p /tmp/nvflare/data/site-1
- mkdir -p /tmp/nvflare/data/site-2
- python -c "from torchvision.datasets import CIFAR10; CIFAR10(root='/tmp/nvflare/data', download=True)"
teardown:
- rm -rf ~/data
- rm -rf /tmp/nvflare/data

0 comments on commit f5d2025

Please sign in to comment.