Skip to content

Commit

Permalink
Fix incorrect checks for csv saving
Browse files Browse the repository at this point in the history
  • Loading branch information
C-Achard committed Mar 25, 2024
1 parent 315f73f commit 816978d
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions napari_cellseg3d/code_plugins/plugin_model_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,23 +1368,40 @@ def on_yield(self, report: TrainingReport):
def _make_csv(self):
size_column = range(1, self.worker_config.max_epochs + 1)
# this assumption does not hold when training is stopped
if len(size_column) != len(self.loss_1_values):
logger.info(
f"Training was stopped, setting epochs for csv to {len(self.loss_1_values)}"
)
size_column = range(1, len(self.loss_1_values) + 1)

if len(self.loss_1_values) == 0 or self.loss_1_values is None:
logger.warning("No loss values to add to csv !")
return

try:
self.loss_1_values["Loss"]
supervised = True
if len(size_column) != len(self.loss_1_values["Loss"]):
logger.info(
f"Training was stopped, setting epochs for csv to {len(self.loss_1_values['Loss'])}"
)
size_column = range(1, len(self.loss_1_values["Loss"]) + 1)

if (
len(self.loss_1_values["Loss"]) == 0
or self.loss_1_values["Loss"] is None
):
logger.warning("No loss values to add to csv !")
return
except KeyError:
try:
self.loss_1_values["SoftNCuts"]
supervised = False
if len(size_column) != len(self.loss_1_values["SoftNCuts"]):
logger.info(
f"Training was stopped, setting epochs for csv to {len(self.loss_1_values['SoftNCuts'])}"
)
size_column = range(
1, len(self.loss_1_values["SoftNCuts"]) + 1
)

if (
len(self.loss_1_values["SoftNCuts"]) == 0
or self.loss_1_values["SoftNCuts"] is None
):
logger.warning("No loss values to add to csv !")
return
except KeyError as e:
raise KeyError(
"Error when making csv. Check loss dict keys ?"
Expand Down

0 comments on commit 816978d

Please sign in to comment.