Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Use csv.DictReader to parse header fields (msto#19) #20

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions dataclass_io/_lib/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,19 @@ def assert_file_header_matches_dataclass(
dataclass_type: type[DataclassInstance],
delimiter: str,
comment_prefix: str,
quoting: int,
) -> None:
"""
Check that the specified file has a header and its fields match those of the provided dataclass.
"""
header: FileHeader | None
if isinstance(file, Path):
with file.open("r") as fin:
header = get_header(fin, delimiter=delimiter, comment_prefix=comment_prefix)
header = get_header(fin, delimiter=delimiter, comment_prefix=comment_prefix, quoting=quoting)
else:
pos = file.tell()
try:
header = get_header(file, delimiter=delimiter, comment_prefix=comment_prefix)
header = get_header(file, delimiter=delimiter, comment_prefix=comment_prefix, quoting=quoting)
finally:
file.seek(pos)

Expand Down
11 changes: 10 additions & 1 deletion dataclass_io/_lib/file.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from csv import DictReader
from dataclasses import dataclass
from enum import Enum
from enum import unique
Expand Down Expand Up @@ -68,6 +69,7 @@ def get_header(
reader: ReadableFileHandle,
delimiter: str,
comment_prefix: str,
quoting: int,
) -> Optional[FileHeader]:
"""
Read the header from an open file.
Expand All @@ -85,6 +87,7 @@ def get_header(
Args:
reader: An open, readable file handle.
comment_char: The character which indicates the start of a comment line.
quoting: Quoting style (enum value from Python csv package).

Returns:
A `FileHeader` containing the field names and any preceding lines.
Expand All @@ -103,6 +106,12 @@ def get_header(
else:
return None

fieldnames = line.strip().split(delimiter)
'''
msto#19 Read header fields

Use csv.DictReader because RFC4180 is tricky to implement correctly
'''
header_reader = DictReader([line], delimiter=delimiter, quoting=quoting)
fieldnames = header_reader.fieldnames

return FileHeader(preface=preface, fieldnames=fieldnames)
7 changes: 6 additions & 1 deletion dataclass_io/reader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import csv
from contextlib import contextmanager
from csv import DictReader
from pathlib import Path
Expand Down Expand Up @@ -27,6 +28,7 @@ def __init__(
dataclass_type: type[DataclassInstance],
delimiter: str = "\t",
comment_prefix: str = "#",
quoting: int = csv.QUOTE_MINIMAL,
**kwds: Any,
) -> None:
"""
Expand All @@ -35,6 +37,7 @@ def __init__(
dataclass_type: Dataclass type.
delimiter: The input file delimiter.
comment_prefix: The prefix for any comment/preface rows preceding the header row.
quoting: Quoting style (enum value from Python csv package).
dataclass_type: Dataclass type.

Raises:
Expand All @@ -46,17 +49,19 @@ def __init__(
dataclass_type=dataclass_type,
delimiter=delimiter,
comment_prefix=comment_prefix,
quoting=quoting,
)

self._dataclass_type = dataclass_type
self._fin = fin
self._header = get_header(
reader=self._fin, delimiter=delimiter, comment_prefix=comment_prefix
reader=self._fin, delimiter=delimiter, comment_prefix=comment_prefix, quoting=quoting
)
self._reader = DictReader(
f=self._fin,
fieldnames=fieldnames(dataclass_type),
delimiter=delimiter,
quoting=quoting,
)

def __iter__(self) -> "DataclassReader":
Expand Down
4 changes: 4 additions & 0 deletions dataclass_io/writer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import csv
from contextlib import contextmanager
from csv import DictWriter
from dataclasses import asdict
Expand Down Expand Up @@ -126,6 +127,7 @@ def open(
overwrite: bool = True,
delimiter: str = "\t",
comment_prefix: str = "#",
quoting: int = csv.QUOTE_MINIMAL,
**kwds: Any,
) -> Iterator["DataclassWriter"]:
"""
Expand All @@ -146,6 +148,7 @@ def open(
comment_prefix: The prefix for any comment/preface rows preceding the header row.
(This argument is ignored when `mode="write"`. It is used when `mode="append"` to
validate that the existing file's header matches the specified dataclass.)
quoting: Quoting style (enum value from Python csv package).
**kwds: Additional keyword arguments to be passed to the `DataclassWriter` constructor.

Yields:
Expand Down Expand Up @@ -178,6 +181,7 @@ def open(
dataclass_type=dataclass_type,
delimiter=delimiter,
comment_prefix=comment_prefix,
quoting=quoting,
)

fout = filepath.open(write_mode.abbreviation)
Expand Down
27 changes: 27 additions & 0 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,30 @@ class FakeDataclass:
assert isinstance(rows[0], FakeDataclass)
assert rows[0].foo == "abc"
assert rows[0].bar == 1


def test_read_csv_with_header_quotes(tmp_path: Path) -> None:
"""
Test that having quotes around column names in header row doesn't break anything
https://github.com/msto/dataclass_io/issues/19
"""
fpath = tmp_path / "test.txt"

@dataclass
class FakeDataclass:
id: str
title: str

test_csv = [
'"id"\t"title"\n',
'"fake"\t"A fake object"\n',
'"also_fake"\t"Another fake object"\n',
]

with fpath.open("w") as f:
f.writelines(test_csv)

# Parse CSV using DataclassReader
with DataclassReader.open(fpath, FakeDataclass) as reader:
for fake_object in reader:
print(fake_object)