Skip to content

Commit

Permalink
Added type hinting in main.py. Couldn't run nvidia, so took out (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
zaeeshah authored Sep 6, 2024
1 parent 4ab89af commit f3dd666
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
26 changes: 13 additions & 13 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,18 @@ nest-asyncio==1.6.0
networkx==3.3
notebook_shim==0.2.4
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.5.40
nvidia-nvtx-cu12==12.1.105
#nvidia-cublas-cu12==12.1.3.1
#nvidia-cuda-cupti-cu12==12.1.105
#nvidia-cuda-nvrtc-cu12==12.1.105
#nvidia-cuda-runtime-cu12==12.1.105
#nvidia-cudnn-cu12==8.9.2.26
#nvidia-cufft-cu12==11.0.2.54
#nvidia-curand-cu12==10.3.2.106
#nvidia-cusolver-cu12==11.4.5.107
#nvidia-cusparse-cu12==12.1.0.106
#nvidia-nccl-cu12==2.20.5
#nvidia-nvjitlink-cu12==12.5.40
#nvidia-nvtx-cu12==12.1.105
ogb==1.3.6
outdated==0.2.2
overrides==7.7.0
Expand Down Expand Up @@ -145,7 +145,7 @@ torchvision==0.18.0
tornado==6.4
tqdm==4.66.4
traitlets==5.14.3
triton==2.3.0
#triton==2.3.0
types-python-dateutil==2.9.0.20240316
typing_extensions==4.12.0
tzdata==2024.1
Expand Down
8 changes: 4 additions & 4 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

logging.getLogger("PIL").setLevel(logging.INFO)

B_DEFAULT = "./configs/algo_config.py"
S_DEFAULT = "./configs/sys_config.py"
B_DEFAULT: str = "./configs/algo_config.py"
S_DEFAULT: str = "./configs/sys_config.py"

parser = argparse.ArgumentParser(description="Run collaborative learning experiments")
parser.add_argument(
Expand Down Expand Up @@ -40,9 +40,9 @@
help=f"host address of the nodes",
)

args = parser.parse_args()
args: argparse.Namespace = parser.parse_args()

scheduler = Scheduler()
scheduler: Scheduler = Scheduler()
scheduler.assign_config_by_path(args.s, args.b, args.super, args.host)
print("Config loaded")

Expand Down

0 comments on commit f3dd666

Please sign in to comment.