Skip to content

Commit

Permalink
Add DLSIA prefix to model names
Browse files Browse the repository at this point in the history
  • Loading branch information
Wiebke committed Jul 10, 2024
1 parent e76c22a commit 1ec91dd
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 20 deletions.
2 changes: 1 addition & 1 deletion example_yamls/example_msdnet_customdil.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ io_parameters:
models_dir: .

model_parameters:
network: "MSDNet"
network: "DLSIA MSDNet"
num_classes: 3
num_epochs: 3
optimizer: "Adam"
Expand Down
2 changes: 1 addition & 1 deletion example_yamls/example_msdnet_maxdil.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ io_parameters:
models_dir: .

model_parameters:
network: "MSDNet"
network: "DLSIA MSDNet"
num_classes: 3
num_epochs: 3
optimizer: "Adam"
Expand Down
2 changes: 1 addition & 1 deletion example_yamls/example_smsnet_ensemble.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ io_parameters:
models_dir: .

model_parameters:
network: "SMSNetEnsemble"
network: "DLSIA SMSNetEnsemble"
num_classes: 3
num_epochs: 3
optimizer: "Adam"
Expand Down
2 changes: 1 addition & 1 deletion example_yamls/example_tunet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ io_parameters:
models_dir: .

model_parameters:
network: "TUNet"
network: "DLSIA TUNet"
num_classes: 3
num_epochs: 3
optimizer: "Adam"
Expand Down
2 changes: 1 addition & 1 deletion example_yamls/example_tunet3plus.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ io_parameters:
models_dir: .

model_parameters:
network: "TUNet3+"
network: "DLSIA TUNet3+"
num_classes: 3
num_epochs: 3
optimizer: "Adam"
Expand Down
14 changes: 7 additions & 7 deletions src/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def build_network(
if parameters.convolution is not None:
convolution = getattr(nn, parameters.convolution)

if network == "MSDNet":
if network == "DLSIA MSDNet":
network = build_msdnet(
in_channels,
out_channels,
Expand All @@ -233,7 +233,7 @@ def build_network(
convolution,
)

elif network == "TUNet":
elif network == "DLSIA TUNet":
network = build_tunet(
in_channels,
out_channels,
Expand All @@ -243,7 +243,7 @@ def build_network(
normalization,
)

elif network == "TUNet3+":
elif network == "DLSIA TUNet3+":
network = build_tunet3plus(
in_channels,
out_channels,
Expand All @@ -253,7 +253,7 @@ def build_network(
normalization,
)

elif network == "SMSNetEnsemble":
elif network == "DLSIA SMSNetEnsemble":
network = build_smsnet_ensemble(
in_channels,
out_channels,
Expand All @@ -268,13 +268,13 @@ def load_network(
params_path,
):

if network == "MSDNet":
if network == "DLSIA MSDNet":
network = msdnet.MSDNetwork_from_file(params_path)

elif network == "TUNet":
elif network == "DLSIA TUNet":
network = tunet.TUNetwork_from_file(params_path)

elif network == "TUNet3+":
elif network == "DLSIA TUNet3+":
network = tunet3plus.TUNetwork3Plus_from_file(params_path)

return network
Expand Down
8 changes: 4 additions & 4 deletions src/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@
network = raw_parameters["network"]

model_parameters = None
if network == "MSDNet":
if network == "DLSIA MSDNet":
model_parameters = MSDNetParameters(**raw_parameters)
elif network == "TUNet":
elif network == "DLSIA TUNet":
model_parameters = TUNetParameters(**raw_parameters)
elif network == "TUNet3+":
elif network == "DLSIA TUNet3+":
model_parameters = TUNet3PlusParameters(**raw_parameters)
elif network == "SMSNetEnsemble":
elif network == "DLSIA SMSNetEnsemble":
model_parameters = SMSNetEnsembleParameters(**raw_parameters)

assert model_parameters, f"Received Unsupported Network: {network}"
Expand Down
8 changes: 4 additions & 4 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ def train(args):
network = raw_parameters["network"]

model_parameters = None
if network == "MSDNet":
if network == "DLSIA MSDNet":
model_parameters = MSDNetParameters(**raw_parameters)
elif network == "TUNet":
elif network == "DLSIA TUNet":
model_parameters = TUNetParameters(**raw_parameters)
elif network == "TUNet3+":
elif network == "DLSIA TUNet3+":
model_parameters = TUNet3PlusParameters(**raw_parameters)
elif network == "SMSNetEnsemble":
elif network == "DLSIA SMSNetEnsemble":
model_parameters = SMSNetEnsembleParameters(**raw_parameters)

assert model_parameters, f"Received Unsupported Network: {network}"
Expand Down

0 comments on commit 1ec91dd

Please sign in to comment.