diff --git a/safer/__init__.py b/safer/__init__.py index f32ed35..846f91c 100644 --- a/safer/__init__.py +++ b/safer/__init__.py @@ -1,5 +1,4 @@ -""" -# 🧿 `safer`: A safer writer 🧿 +"""# 🧿 `safer`: A safer writer 🧿 Avoid partial writes or corruption! @@ -42,6 +41,10 @@ It also has a useful `dry_run` setting to let you test your code without actually overwriting the target file. +NOTE: Just like plain old `open`, if a file that is already opened for writing +is opened again before the first write has completed, the results are +unpredictable: so don't do it! + * `safer.writer()` wraps an existing writer, socket or stream and writes a whole response or nothing @@ -67,7 +70,6 @@ does not work on Windows. (In fact, it's unclear if any of this works on Windows, but that certainly won't. Windows developer solicted!) - ### Example: `safer.writer()` `safer.writer()` wraps an existing stream - a writer, socket, or callback - @@ -143,6 +145,7 @@ for item in items: print(item) # Either the whole file is written, or nothing + """ import contextlib import functools @@ -657,14 +660,16 @@ def __init__( self.target_file = target_file self.dry_run = dry_run self.is_binary = is_binary + if temp_file is True: + parent, file = os.path.split(target_file) + temp_file = os.path.join(parent, f'.{file}.tmp-safer') + super().__init__(temp_file, delete_failures, parent) def _success(self): if not self.dry_run: if os.path.exists(self.target_file): shutil.copymode(self.target_file, self.temp_file) - else: - os.chmod(self.temp_file, 0o100644) os.replace(self.temp_file, self.target_file) elif callable(self.dry_run): diff --git a/test/test_open.py b/test/test_open.py index c6e191c..24291a2 100644 --- a/test/test_open.py +++ b/test/test_open.py @@ -102,15 +102,15 @@ def test_two_errors(self, safer_open): fp.write('OK!') if uses_files: after = set(os.listdir('.')) - assert len(before) + 2 == len(after) - assert len(after.difference(before)) == 2 + assert len(before) + 1 == len(after) + assert len(after.difference(before)) == 1 assert FILENAME.read_text() == 'OK!' if uses_files: after = set(os.listdir('.')) - assert len(before) + 1 == len(after) - assert len(after.difference(before)) == 1 + assert len(before) == len(after) + assert len(after.difference(before)) == 0 def test_error_with_copy(self, safer_open): FILENAME.write_text('hello') @@ -214,3 +214,17 @@ def test_file_exists_error(self, safer_open): with safer_open(FILENAME, 'wt') as fp: fp.write('goodbye') assert FILENAME.read_text() == 'goodbye' + + def test_tempfile_perms(self, safer_open): + temp_files = False, True, 'three' + perms = [] + for temp_file in temp_files: + filename = str(temp_file) + if isinstance(temp_file, str): + temp_file += '.tmpfile' + with safer_open(filename, 'w', temp_file=temp_file): + pass + perms.append(os.stat(filename).st_mode) + + assert perms == [perms[0]] * len(perms) + assert perms[0] in (0o100644, 0o100664)