Skip to content

Commit

Permalink
pathlib for script/automate_training.py (ivadomed#881)
Browse files Browse the repository at this point in the history
* pathlib for script/automate_training.py

* fix pathlib

Co-authored-by: Yang Ding <[email protected]>
  • Loading branch information
cakester and dyt811 authored Aug 14, 2021
1 parent 5daf38a commit 93b9e73
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions ivadomed/scripts/automate_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from functools import partial
import json
import logging
import os
import random
import collections.abc
import shutil
Expand All @@ -27,6 +26,7 @@
import numpy as np
import torch.multiprocessing as mp
import ivadomed.scripts.visualize_and_compare_testing_models as violin_plots
from pathlib import Path
from ivadomed import main as ivado
from ivadomed import config_manager as imed_config_manager
from ivadomed.loader import utils as imed_loader_utils
Expand Down Expand Up @@ -153,14 +153,14 @@ def split_dataset(initial_config):
}
"""
loader_parameters = initial_config["loader_parameters"]
path_output = initial_config["path_output"]
if not os.path.isdir(path_output):
path_output = Path(initial_config["path_output"])
if not path_output.is_dir():
print('Creating output path: {}'.format(path_output))
os.makedirs(path_output)
path_output.mkdir(parents=True)
else:
print('Output path already exists: {}'.format(path_output))

bids_df = imed_loader_utils.BidsDataframe(loader_parameters, path_output, derivatives=True)
bids_df = imed_loader_utils.BidsDataframe(loader_parameters, str(path_output), derivatives=True)

train_lst, valid_lst, test_lst = imed_loader_utils.get_new_subject_file_split(
df=bids_df.df,
Expand Down Expand Up @@ -654,8 +654,8 @@ def automate_training(file_config, file_config_hyper, fixed_split, all_combin, p
applied, then all the second, etc.
output_dir (str): Path to where the results will be saved.
"""
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
if output_dir and not Path(output_dir).exists():
Path(output_dir).mkdir(parents=True)
if not output_dir:
output_dir = ""

Expand All @@ -670,7 +670,7 @@ def automate_training(file_config, file_config_hyper, fixed_split, all_combin, p
initial_config = split_dataset(initial_config)

# Hyperparameters values to experiment
with open(file_config_hyper, "r") as fhandle:
with Path(file_config_hyper).open(mode="r") as fhandle:
config_hyper = json.load(fhandle)

param_list = get_param_list(config_hyper, [], [])
Expand Down Expand Up @@ -711,16 +711,16 @@ def automate_training(file_config, file_config_hyper, fixed_split, all_combin, p
new_config_list = []
for config in config_list:
# Delete path_pred
path_pred = os.path.join(config['path_output'], 'pred_masks')
if os.path.isdir(path_pred) and n_iterations > 1:
path_pred = Path(config['path_output'], 'pred_masks')
if path_pred.is_dir() and n_iterations > 1:
try:
shutil.rmtree(path_pred)
shutil.rmtree(str(path_pred))
except OSError as e:
logging.info("Error: %s - %s." % (e.filename, e.strerror))

# Take the config file within the path_output because binarize_prediction may have been updated
json_path = os.path.join(config['path_output'], 'config_file.json')
new_config = imed_config_manager.ConfigurationManager(json_path).get_config()
json_path = Path(config['path_output'], 'config_file.json')
new_config = imed_config_manager.ConfigurationManager(str(json_path)).get_config()
new_config["gpu_ids"] = config["gpu_ids"]
new_config_list.append(new_config)

Expand Down Expand Up @@ -761,11 +761,11 @@ def automate_training(file_config, file_config_hyper, fixed_split, all_combin, p
combined_df = val_df

results_df = pd.concat([results_df, combined_df])
results_df.to_csv(os.path.join(output_dir, "temporary_results.csv"))
eval_df.to_csv(os.path.join(output_dir, "average_eval.csv"))
results_df.to_csv(str(Path(output_dir, "temporary_results.csv")))
eval_df.to_csv(str(Path(output_dir, "average_eval.csv")))

results_df = format_results(results_df, config_list, param_list)
results_df.to_csv(os.path.join(output_dir, "detailed_results.csv"))
results_df.to_csv(str(Path(output_dir, "detailed_results.csv")))

logging.info("Detailed results")
logging.info(results_df)
Expand Down

0 comments on commit 93b9e73

Please sign in to comment.