Skip to content

Commit

Permalink
[feat] Make sanitize repo and file respect encodings
Browse files (browse the repository at this point in the history)
  • Loading branch information
tohanss authored Aug 7, 2020
1 parent 7fe6181 commit 3f2849d
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 39 deletions.
32 changes: 4 additions & 28 deletions repobee_sanitizer/_sanitize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
.. module:: _sanitize_file
:synopsis: Module for file sanitization functionality.
"""
import pathlib

from repobee_sanitizer import _syntax
from repobee_sanitizer._syntax import Markers
Expand All @@ -12,38 +11,15 @@
from typing import List, Iterable, Optional


def sanitize_file(
file_abs_path: pathlib.Path, strip: bool = False
) -> Optional[str]:
"""Runs the sanitization protocol on a given file. This can either remove
the file or give back a sanitized version. File must be syntax checked
before running this.
def sanitize_text(lines: List[str], strip: bool = False) -> Optional[str]:
"""A function to directly sanitize given content. Text must be syntax
checked first.
Args:
file_abs_path: The absolute file path to the file you wish to
sanitize.
Returns:
The sanitized output text, but only if the file was not
removed.
Content to be sanitized
"""
text = file_abs_path.read_text()
lines = text.split("\n")
if _syntax.contained_marker(lines[0]) == Markers.SHRED:
return None
else:
sanitized_string = _sanitize(lines, strip=False)
return "\n".join(sanitized_string)


def sanitize_text(content: str, strip: bool = False) -> str:
"""A function to directly sanitize given content.
Args:
Content to be sanitized.
"""
lines = content.split("\n")
_syntax.check_syntax(lines)
sanitized_string = _sanitize(lines, strip)
return "\n".join(sanitized_string)

Expand Down
6 changes: 3 additions & 3 deletions repobee_sanitizer/_sanitize_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ def sanitize_files(
return files_with_errors

for relpath in file_relpaths:
file_abspath = basedir / str(relpath)
sanitized_text = _sanitize.sanitize_file(file_abspath)
content = relpath.read_text_relative_to(basedir).split("\n")
sanitized_text = _sanitize.sanitize_text(content)
if sanitized_text is None:
file_abspath.unlink()
(basedir / str(relpath)).unlink()
LOGGER.info(f"Shredded file {relpath}")
else:
relpath.write_text_relative_to(
Expand Down
13 changes: 10 additions & 3 deletions repobee_sanitizer/sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
_syntax,
_format,
_sanitize_repo,
_fileutils,
)

PLUGIN_NAME = "sanitizer"
Expand Down Expand Up @@ -105,7 +106,13 @@ def command(self, api) -> Optional[plug.Result]:
Returns:
Result if the syntax is invalid, otherwise nothing.
"""
errors = _syntax.check_syntax(self.infile.read_text().split("\n"))

infile_encoding = _fileutils.guess_encoding(self.infile)
infile_content = self.infile.read_text(encoding=infile_encoding).split(
"\n"
)

errors = _syntax.check_syntax(infile_content)
if errors:
file_errors = [_format.FileWithErrors(self.infile.name, errors)]
msg = _format.format_error_string(file_errors)
Expand All @@ -114,9 +121,9 @@ def command(self, api) -> Optional[plug.Result]:
name="sanitize-file", msg=msg, status=plug.Status.ERROR,
)

result = _sanitize.sanitize_file(self.infile, strip=self.strip)
result = _sanitize.sanitize_text(infile_content, strip=self.strip)
if result:
self.outfile.write_text(result)
self.outfile.write_text(result, encoding=infile_encoding)

return plug.Result(
name="sanitize-file",
Expand Down
6 changes: 3 additions & 3 deletions tests/helpers/testhelpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pathlib
import collections

from typing import Iterable, Tuple
from typing import Iterable

INPUT_FILENAME = "input.in"
OUTPUT_FILENAME = "output.out"
Expand All @@ -17,7 +17,7 @@
RESOURCES_BASEDIR = pathlib.Path(__file__).parent.parent / "resources"


TestData = collections.namedtuple("TestData", "inp out inverse".split())
TestData = collections.namedtuple("TestData", "inp out inverse")


def discover_test_cases(
Expand Down Expand Up @@ -61,7 +61,7 @@ def generate_invalid_test_cases():
)


def read_valid_test_case_files(test_case_dir: pathlib.Path) -> Tuple[str, str]:
def read_valid_test_case_files(test_case_dir: pathlib.Path) -> TestData:
inp = (test_case_dir / INPUT_FILENAME).read_text(encoding="utf8")
out = (test_case_dir / OUTPUT_FILENAME).read_text(encoding="utf8")
inverse = (test_case_dir / INVERSE_FILENAME).read_text(encoding="utf8")
Expand Down
5 changes: 3 additions & 2 deletions tests/test_sanitze.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ def test_sanitize_valid(data: testhelpers.TestData):
for this test function is generated by the pytest_generate_tests
hook.
"""
assert _sanitize.sanitize_text(data.inp) == data.out
assert _sanitize.sanitize_text(data.inp, strip=True) == data.inverse
formated_inp = data.inp.split("\n")
assert _sanitize.sanitize_text(formated_inp) == data.out
assert _sanitize.sanitize_text(formated_inp, strip=True) == data.inverse

0 comments on commit 3f2849d

Please sign in to comment.