Skip to content

Commit

Permalink
add comments and type hinting in places where it may have been ambigu…
Browse files Browse the repository at this point in the history
…ous (#138)
  • Loading branch information
aopatric authored Nov 5, 2024
1 parent 1bfe422 commit 481c367
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 13 deletions.
8 changes: 6 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
import argparse
import logging

from typing import Optional
from scheduler import Scheduler

logging.getLogger("PIL").setLevel(logging.INFO)
logging.basicConfig(level=logging.DEBUG) # Enable detailed logging

# Default config file paths
B_DEFAULT: str = "./configs/algo_config.py"
S_DEFAULT: str = "./configs/sys_config.py"

parser = argparse.ArgumentParser(description="Run collaborative learning experiments")
# Parse args
parser : argparse.ArgumentParser = argparse.ArgumentParser(description="Run collaborative learning experiments")
parser.add_argument(
"-b",
nargs="?",
Expand Down Expand Up @@ -50,7 +53,7 @@
print("Config loaded")

# Log and check key configuration values to prevent errors like division by zero
num_users = scheduler.config.get("num_users", None)
num_users : Optional[int] = scheduler.config.get("num_users", None)
if num_users is None:
logging.error(
"The number of users (num_users) is not defined in the configuration."
Expand All @@ -64,6 +67,7 @@

logging.info(f"Running experiment with {num_users} users.")

# Start the scheduler
scheduler.install_config()
scheduler.initialize()

Expand Down
8 changes: 5 additions & 3 deletions src/main_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import subprocess
from typing import List

parser = argparse.ArgumentParser(description="Number of nodes to run on this machine")
# Parse args
parser : argparse.ArgumentParser = argparse.ArgumentParser(description="Number of nodes to run on this machine")
parser.add_argument(
"-n",
nargs="?",
Expand All @@ -22,11 +23,12 @@
help=f"host address of the nodes",
)

args = parser.parse_args()
args : argparse.Namespace = parser.parse_args()

# Command for opening each process
command_list: List[str] = ["python", "main.py", "-host", args.host]
# if the super-node is to be started on this machine

# Start process for each user
for i in range(args.n):
print(f"Starting process for user {i}")
# start a Popen process
Expand Down
16 changes: 8 additions & 8 deletions src/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,28 +82,28 @@ def assign_config_by_path(
is_super_node: bool | None = None,
host: str | None = None,
) -> None:
self.sys_config = load_config(sys_config_path)
self.sys_config : Dict[str, Any]= load_config(sys_config_path)
if is_super_node:
self.sys_config["comm"]["rank"] = 0
else:
self.sys_config["comm"]["host"] = host
self.sys_config["comm"]["rank"] = None
self.config = {}
self.config : Dict[str, Any]= {}
self.config.update(self.sys_config)

def merge_configs(self) -> None:
self.config.update(self.sys_config)
node_name = "node_{}".format(self.communication.get_rank())
self.algo_config = self.sys_config["algos"][node_name]
node_name : str = "node_{}".format(self.communication.get_rank())
self.algo_config : Dict[str, Any] = self.sys_config["algos"][node_name]
self.config.update(self.algo_config)
self.config["dropout_dicts"] = self.sys_config.get("dropout_dicts", {}).get(node_name, {})

def initialize(self, copy_souce_code: bool = True) -> None:
assert self.config is not None, "Config should be set when initializing"
self.communication = CommunicationManager(self.config)
self.communication : CommunicationManager = CommunicationManager(self.config)
self.config["comm"]["rank"] = self.communication.get_rank()
# Base clients modify the seed later on
seed = self.config["seed"]
seed : int = self.config["seed"]
torch.manual_seed(seed) # type: ignore
random.seed(seed)
numpy.random.seed(seed)
Expand All @@ -113,7 +113,7 @@ def initialize(self, copy_souce_code: bool = True) -> None:
if copy_souce_code:
copy_source_code(self.config)
else:
path = self.config["results_path"]
path : str = self.config["results_path"]
check_and_create_path(path)
os.mkdir(self.config["saved_models"])
os.mkdir(self.config["log_path"])
Expand All @@ -125,7 +125,7 @@ def initialize(self, copy_souce_code: bool = True) -> None:
print("Waiting for 10 seconds for the super node to create directories")
time.sleep(10)

self.node = get_node(
self.node : BaseNode = get_node(
self.config,
rank=self.communication.get_rank(),
comm_utils=self.communication,
Expand Down

0 comments on commit 481c367

Please sign in to comment.