diff --git a/repobee_sanitizer/_sanitize.py b/repobee_sanitizer/_sanitize.py index 7e0d3ae..cee44a9 100644 --- a/repobee_sanitizer/_sanitize.py +++ b/repobee_sanitizer/_sanitize.py @@ -3,7 +3,6 @@ .. module:: _sanitize_file :synopsis: Module for file sanitization functionality. """ -import pathlib from repobee_sanitizer import _syntax from repobee_sanitizer._syntax import Markers @@ -12,38 +11,15 @@ from typing import List, Iterable, Optional -def sanitize_file( - file_abs_path: pathlib.Path, strip: bool = False -) -> Optional[str]: - """Runs the sanitization protocol on a given file. This can either remove - the file or give back a sanitized version. File must be syntax checked - before running this. +def sanitize_text(lines: List[str], strip: bool = False) -> Optional[str]: + """A function to directly sanitize given content. Text must be syntax + checked first. Args: - file_abs_path: The absolute file path to the file you wish to - sanitize. - - Returns: - The sanitized output text, but only if the file was not - removed. + Content to be sanitized """ - text = file_abs_path.read_text() - lines = text.split("\n") if _syntax.contained_marker(lines[0]) == Markers.SHRED: return None - else: - sanitized_string = _sanitize(lines, strip=False) - return "\n".join(sanitized_string) - - -def sanitize_text(content: str, strip: bool = False) -> str: - """A function to directly sanitize given content. - - Args: - Content to be sanitized. - """ - lines = content.split("\n") - _syntax.check_syntax(lines) sanitized_string = _sanitize(lines, strip) return "\n".join(sanitized_string) diff --git a/repobee_sanitizer/_sanitize_repo.py b/repobee_sanitizer/_sanitize_repo.py index f5620b7..b036b76 100644 --- a/repobee_sanitizer/_sanitize_repo.py +++ b/repobee_sanitizer/_sanitize_repo.py @@ -74,10 +74,10 @@ def sanitize_files( return files_with_errors for relpath in file_relpaths: - file_abspath = basedir / str(relpath) - sanitized_text = _sanitize.sanitize_file(file_abspath) + content = relpath.read_text_relative_to(basedir).split("\n") + sanitized_text = _sanitize.sanitize_text(content) if sanitized_text is None: - file_abspath.unlink() + (basedir / str(relpath)).unlink() LOGGER.info(f"Shredded file {relpath}") else: relpath.write_text_relative_to( diff --git a/repobee_sanitizer/sanitizer.py b/repobee_sanitizer/sanitizer.py index f722cdd..7dff38c 100644 --- a/repobee_sanitizer/sanitizer.py +++ b/repobee_sanitizer/sanitizer.py @@ -15,6 +15,7 @@ _syntax, _format, _sanitize_repo, + _fileutils, ) PLUGIN_NAME = "sanitizer" @@ -105,7 +106,13 @@ def command(self, api) -> Optional[plug.Result]: Returns: Result if the syntax is invalid, otherwise nothing. """ - errors = _syntax.check_syntax(self.infile.read_text().split("\n")) + + infile_encoding = _fileutils.guess_encoding(self.infile) + infile_content = self.infile.read_text(encoding=infile_encoding).split( + "\n" + ) + + errors = _syntax.check_syntax(infile_content) if errors: file_errors = [_format.FileWithErrors(self.infile.name, errors)] msg = _format.format_error_string(file_errors) @@ -114,9 +121,9 @@ def command(self, api) -> Optional[plug.Result]: name="sanitize-file", msg=msg, status=plug.Status.ERROR, ) - result = _sanitize.sanitize_file(self.infile, strip=self.strip) + result = _sanitize.sanitize_text(infile_content, strip=self.strip) if result: - self.outfile.write_text(result) + self.outfile.write_text(result, encoding=infile_encoding) return plug.Result( name="sanitize-file", diff --git a/tests/helpers/testhelpers.py b/tests/helpers/testhelpers.py index b507925..b436b04 100644 --- a/tests/helpers/testhelpers.py +++ b/tests/helpers/testhelpers.py @@ -3,7 +3,7 @@ import pathlib import collections -from typing import Iterable, Tuple +from typing import Iterable INPUT_FILENAME = "input.in" OUTPUT_FILENAME = "output.out" @@ -17,7 +17,7 @@ RESOURCES_BASEDIR = pathlib.Path(__file__).parent.parent / "resources" -TestData = collections.namedtuple("TestData", "inp out inverse".split()) +TestData = collections.namedtuple("TestData", "inp out inverse") def discover_test_cases( @@ -61,7 +61,7 @@ def generate_invalid_test_cases(): ) -def read_valid_test_case_files(test_case_dir: pathlib.Path) -> Tuple[str, str]: +def read_valid_test_case_files(test_case_dir: pathlib.Path) -> TestData: inp = (test_case_dir / INPUT_FILENAME).read_text(encoding="utf8") out = (test_case_dir / OUTPUT_FILENAME).read_text(encoding="utf8") inverse = (test_case_dir / INVERSE_FILENAME).read_text(encoding="utf8") diff --git a/tests/test_sanitze.py b/tests/test_sanitze.py index 22b9692..626cd04 100644 --- a/tests/test_sanitze.py +++ b/tests/test_sanitze.py @@ -28,5 +28,6 @@ def test_sanitize_valid(data: testhelpers.TestData): for this test function is generated by the pytest_generate_tests hook. """ - assert _sanitize.sanitize_text(data.inp) == data.out - assert _sanitize.sanitize_text(data.inp, strip=True) == data.inverse + formated_inp = data.inp.split("\n") + assert _sanitize.sanitize_text(formated_inp) == data.out + assert _sanitize.sanitize_text(formated_inp, strip=True) == data.inverse