Skip to content

Commit

Permalink
FEAT: Fix abs/rel path priority in result files. (#858)
Browse files Browse the repository at this point in the history
Co-authored-by: Alexandre <[email protected]>
  • Loading branch information
AlexandreGoettel and Alexandre authored Nov 14, 2024
1 parent 26ad7fb commit f4a813e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
28 changes: 19 additions & 9 deletions bilby/core/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,7 @@ def save_to_file(self, filename=None, overwrite=False, outdir=None,
default=False
outdir: str, optional
Path to the outdir. Default is the one stored in the result object.
If given, overwrite path prefix in 'filename'.
extension: str, optional {json, hdf5, pkl, pickle, True}
Determines the method to use to store the data (if True defaults
to json)
Expand All @@ -780,11 +781,20 @@ def save_to_file(self, filename=None, overwrite=False, outdir=None,
if extension is True:
extension = "json"

_outdir = None
if filename is not None:
_outdir, filename = os.path.split(filename)
_outdir = None if _outdir == "" else _outdir
filename = f"{os.path.splitext(filename)[0]}.{extension}"

outdir = _outdir if outdir is None else outdir
outdir = self._safe_outdir_creation(outdir, self.save_to_file)
if filename is None:
filename = result_file_name(outdir, self.label, extension, gzip)
output_path = result_file_name(outdir, self.label, extension, gzip)
else:
output_path = os.path.join(outdir, filename)

move_old_file(filename, overwrite)
move_old_file(output_path, overwrite)

# Convert the prior to a string representation for saving on disk
dictionary = self._get_save_data_dictionary()
Expand All @@ -803,27 +813,27 @@ def save_to_file(self, filename=None, overwrite=False, outdir=None,
import gzip
# encode to a string
json_str = json.dumps(dictionary, cls=BilbyJsonEncoder).encode('utf-8')
with gzip.GzipFile(filename, 'w') as file:
with gzip.GzipFile(output_path, 'w') as file:
file.write(json_str)
else:
with open(filename, 'w') as file:
with open(output_path, 'w') as file:
json.dump(dictionary, file, indent=2, cls=BilbyJsonEncoder)
elif extension == 'hdf5':
import h5py
dictionary["__module__"] = self.__module__
dictionary["__name__"] = self.__class__.__name__
with h5py.File(filename, 'w') as h5file:
with h5py.File(output_path, 'w') as h5file:
recursively_save_dict_contents_to_group(h5file, '/', dictionary)
elif extension == 'pkl':
safe_file_dump(self, filename, "dill")
safe_file_dump(self, output_path, "dill")
else:
raise ValueError("Extension type {} not understood".format(extension))
except Exception as e:
filename = ".".join(filename.split(".")[:-1]) + ".pkl"
safe_file_dump(self, filename, "dill")
output_path = f"{os.path.splitext(output_path)[0]}.pkl"
safe_file_dump(self, output_path, "dill")
logger.error(
"\n\nSaving the data has failed with the following message:\n"
"{}\nData has been dumped to {}.\n\n".format(e, filename)
"{}\nData has been dumped to {}.\n\n".format(e, output_path)
)

def save_posterior_samples(self, filename=None, outdir=None, label=None):
Expand Down
13 changes: 13 additions & 0 deletions test/core/result_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,19 @@ def _save_and_dont_overwrite_test(self, extension):
self.result.save_to_file(overwrite=False, extension=extension)
self.assertTrue(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))

def _save_with_outdir_and_filename(self, filename, outdir, template):
self.result.save_to_file(filename=filename, outdir=outdir, extension="json", gzip=False)
self.assertTrue(os.path.isfile(template))

def test_save_with_outdir_and_filename(self):
self._save_with_outdir_and_filename("out/result", "out2", "out2/result.json")
self._save_with_outdir_and_filename("out/result", None, "out/result.json")
self._save_with_outdir_and_filename("result", "out", "out/result.json")
self._save_with_outdir_and_filename(
"result", None, os.path.join(self.result.outdir, "result.json"))
self._save_with_outdir_and_filename(
None, "out", os.path.join("out", f"{self.result.label}_result.json"))

def test_save_and_overwrite_json(self):
self._save_and_overwrite_test(extension='json')

Expand Down

0 comments on commit f4a813e

Please sign in to comment.