Skip to content

Commit

Permalink
add support for new models
Browse files Browse the repository at this point in the history
  • Loading branch information
Olivier DEBAUCHE committed Jul 26, 2024
1 parent 98c0dab commit 16485b2
Show file tree
Hide file tree
Showing 10 changed files with 601 additions and 13 deletions.
3 changes: 2 additions & 1 deletion fl_common/dlcl_getfamilly.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,8 @@ def get_family_model_h(model_type, num_classes):
"hiera_base_plus_224": get_hiera_model,
"hiera_large_224": get_hiera_model,
"hiera_huge_224": get_hiera_model,
"hiera_small_abswin_256": get_hiera_model,
"hiera_base_abswin_256": get_hiera_model,
"hrnet_w18_small": get_hrnet_model,
"hrnet_w18_small_v2": get_hrnet_model,
"hrnet_w18": get_hrnet_model,
Expand Down Expand Up @@ -952,7 +954,6 @@ def get_family_model_m(model_type, num_classes):
"mobilenetv4_conv_aa_medium": get_mobilenet_model,
"mobilenetv4_conv_blur_medium": get_mobilenet_model,
"mobilenetv4_conv_aa_large": get_mobilenet_model,
"mobilenetv4_conv_blur_large": get_mobilenet_model,
"mobilenetv4_hybrid_medium_075": get_mobilenet_model,
"mobilenetv4_hybrid_large_075": get_mobilenet_model,
"repghostnetv2_conv_small": get_mobilenet_model,
Expand Down
2 changes: 2 additions & 0 deletions fl_common/models/hiera.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def get_hiera_model(hiera_type, num_classes):
"hiera_base_plus_224",
"hiera_large_224",
"hiera_huge_224",
'hiera_small_abswin_256',
'hiera_base_abswin_256',
}

if hiera_type not in valid_hiera_types:
Expand Down
Empty file.
1 change: 0 additions & 1 deletion fl_common/models/mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def get_mobilenet_model(mobilenet_type, num_classes):
"mobilenetv4_conv_aa_medium",
"mobilenetv4_conv_blur_medium",
"mobilenetv4_conv_aa_large",
"mobilenetv4_conv_blur_large",
"mobilenetv4_hybrid_medium_075",
"mobilenetv4_hybrid_large_075",
"repghostnetv2_conv_small",
Expand Down
Empty file.
586 changes: 586 additions & 0 deletions fl_common/models/segmentation/snp.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions fl_common/tests/tests_models_hiera.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def test_known_hiera_types(self):
"hiera_base_plus_224",
"hiera_large_224",
"hiera_huge_224",
"hiera_small_abswin_256",
"hiera_base_abswin_256",
]

num_classes = 1000 # Assuming 1000 classes for the test
Expand Down
1 change: 0 additions & 1 deletion fl_common/tests/tests_models_mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def test_get_mobilenet_model(self):
"mobilenetv4_conv_aa_medium",
"mobilenetv4_conv_blur_medium",
"mobilenetv4_conv_aa_large",
"mobilenetv4_conv_blur_large",
"mobilenetv4_hybrid_medium_075",
"mobilenetv4_hybrid_large_075",
# "repghostnetv2_conv_small",
Expand Down
8 changes: 4 additions & 4 deletions fl_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def register_user(node, email, password, name, institution, website):
Returns:
client: The registered client.
"""
client = node.login(email=email, password=password)
client = login(login_node=node, login_email=email, login_password=password)
client.register(
name=name,
email=email,
Expand Down Expand Up @@ -93,19 +93,19 @@ def launch_and_register(
return node, client


def login(node, login_email, login_password):
def login(login_node, login_email, login_password):
"""
Log in to the given node with the provided email and password.
Args:
node: The node to log in to.
login_node: The node to log in to.
login_email (str): The email to use for logging in.
login_password (str): The password to use for logging in.
Returns:
client: The client returned after successful login.
"""
return node.login(
return login_node.login(
email=login_email,
password=login_password,
)
Expand Down
11 changes: 5 additions & 6 deletions fl_server/tests/tests_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,6 @@ def test_launch_node_main(self):
"""
Test the main node launch functionality.
"""
ds_client = login(
"node_humani",
"[email protected]",
"abc123",
)

data_subjects = self.client_humani.data_subject_registry.get_all()
self.assertIsNotNone(data_subjects)

Expand Down Expand Up @@ -82,6 +76,11 @@ def test_launch_node_main(self):
)
self.client_humani.upload_dataset(dataset)

ds_client = login('node_humani',
"[email protected]",
"abc123",
)

asset = ds_client.datasets[-1].assets["ages"]
mock = asset.mock

Expand Down

0 comments on commit 16485b2

Please sign in to comment.