Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Implement msgspec encoding #2539

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
def mypy(session: Session) -> None:
"""Check types with mypy."""
args = session.posargs or ["singer_sdk"]
session.install(".[faker,jwt,parquet,s3,testing]")
session.install(".[faker,jwt,msgspec,parquet,s3,testing]")
session.install(*typing_dependencies)
session.run("mypy", *args)
if not session.posargs:
Expand All @@ -61,7 +61,7 @@ def mypy(session: Session) -> None:
@session(python=python_versions)
def tests(session: Session) -> None:
"""Execute pytest tests and compute coverage."""
session.install(".[faker,jwt,parquet,s3]")
session.install(".[faker,jwt,msgspec,parquet,s3]")
session.install(*test_dependencies)

sqlalchemy_version = os.environ.get("SQLALCHEMY_VERSION")
Expand Down Expand Up @@ -94,7 +94,7 @@ def tests(session: Session) -> None:
@session(python=main_python_version)
def benches(session: Session) -> None:
"""Run benchmarks."""
session.install(".[jwt,s3]")
session.install(".[jwt,msgspec,s3]")
session.install(*test_dependencies)
sqlalchemy_version = os.environ.get("SQLALCHEMY_VERSION")
if sqlalchemy_version:
Expand All @@ -114,7 +114,7 @@ def benches(session: Session) -> None:
@session(name="deps", python=python_versions)
def dependencies(session: Session) -> None:
"""Check issues with dependencies."""
session.install(".[s3,testing]")
session.install(".[msgspec,s3,testing]")
session.install("deptry")
session.run("deptry", "singer_sdk", *session.posargs)

Expand All @@ -124,7 +124,7 @@ def update_snapshots(session: Session) -> None:
"""Update pytest snapshots."""
args = session.posargs or ["-m", "snapshot"]

session.install(".[faker,jwt,parquet]")
session.install(".[faker,jwt,msgspec,parquet]")
session.install(*test_dependencies)
session.run("pytest", "--snapshot-update", *args)

Expand Down
55 changes: 54 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ inflection = ">=0.5.1"
joblib = ">=1.3.0"
jsonpath-ng = ">=1.5.3"
jsonschema = ">=4.16.0"
msgspec = { version = ">=0.18.6", optional = true }
packaging = ">=23.1"
pendulum = ">=2.1.0,<4"
python-dateutil = ">=2.8.2"
Expand Down Expand Up @@ -111,6 +112,7 @@ docs = [
"sphinx-notfound-page",
"sphinx-reredirects",
]
msgspec = ["msgspec"]
s3 = ["fs-s3fs"]
testing = [
"pytest",
Expand Down
172 changes: 172 additions & 0 deletions singer_sdk/_singerlib/encoding/_msgspec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from __future__ import annotations

import datetime
import decimal
import logging
import sys
import typing as t

import msgspec

from singer_sdk._singerlib.exceptions import InvalidInputLine

from ._base import GenericSingerReader, GenericSingerWriter

logger = logging.getLogger(__name__)


class Message(msgspec.Struct, tag_field="type", tag=str.upper):
    """Base struct for Singer messages, tagged through the ``type`` field."""

    def to_dict(self):  # noqa: ANN202
        """Return a dictionary mapping each struct field name to its value.

        NOTE(review): the ``type`` tag is configured on the struct rather than
        declared as a field, so it is presumably absent from the result --
        confirm that callers do not need it.

        Returns:
            A dictionary built from ``__struct_fields__``.
        """
        field_names = self.__struct_fields__
        return {name: getattr(self, name) for name in field_names}

Check warning on line 22 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L22

Added line #L22 was not covered by tests


class RecordMessage(Message, tag="RECORD"):
    """Singer RECORD message."""

    stream: str
    record: t.Dict[str, t.Any]  # noqa: UP006
    version: t.Union[int, None] = None  # noqa: UP007
    time_extracted: t.Union[datetime.datetime, None] = None  # noqa: UP007

    def __post_init__(self) -> None:
        """Validate ``time_extracted`` and convert it to UTC.

        Nothing happens when ``time_extracted`` is unset.

        Raises:
            ValueError: If the time_extracted is not timezone-aware.
        """
        extracted_at = self.time_extracted
        if not extracted_at:
            return

        # A naive datetime cannot be converted to UTC unambiguously.
        if not extracted_at.tzinfo:
            msg = (
                "'time_extracted' must be either None or an aware datetime (with a "
                "time zone)"
            )
            raise ValueError(msg)

        self.time_extracted = extracted_at.astimezone(datetime.timezone.utc)

Check warning on line 47 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L47

Added line #L47 was not covered by tests


class SchemaMessage(Message, tag="SCHEMA"):
    """Singer SCHEMA message."""

    stream: str
    schema: t.Dict[str, t.Any]  # noqa: UP006
    key_properties: t.List[str]  # noqa: UP006
    bookmark_properties: t.Union[t.List[str], None] = None  # noqa: UP006, UP007

    def __post_init__(self) -> None:
        """Normalize and validate ``bookmark_properties``.

        A single string (or bytes) value is wrapped in a one-element list.

        Raises:
            ValueError: If bookmark_properties is a truthy value that is
                neither a string nor a list.
        """
        bookmarks = self.bookmark_properties
        if isinstance(bookmarks, (str, bytes)):
            bookmarks = [bookmarks]
            self.bookmark_properties = bookmarks

        # Falsy values (None, empty containers) and lists are accepted as-is.
        if not bookmarks or isinstance(bookmarks, list):
            return

        msg = "bookmark_properties must be a string or list of strings"
        raise ValueError(msg)

Check warning on line 68 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L67-L68

Added lines #L67 - L68 were not covered by tests


class StateMessage(Message, tag="STATE"):
    """Singer STATE message."""

    value: t.Dict[str, t.Any]  # noqa: UP006
    """The state value."""


class ActivateVersionMessage(Message, tag="ACTIVATE_VERSION"):
    """Singer ACTIVATE_VERSION message."""

    stream: str
    """The stream name."""

    version: int
    """The version to activate."""


def enc_hook(obj: t.Any) -> t.Any:  # noqa: ANN401
    """Convert an object the encoder cannot handle natively.

    Args:
        obj: The item to be encoded.

    Returns:
        The ISO 8601 string (with a ``T`` separator) for ``datetime.datetime``
        instances, otherwise ``str(obj)``.
    """
    if isinstance(obj, datetime.datetime):
        return obj.isoformat(sep="T")
    return str(obj)


def dec_hook(type: type, obj: t.Any) -> t.Any:  # noqa: ARG001, A002, ANN401
    """Convert a decoded value for a type the decoder cannot handle natively.

    Args:
        type: The target type requested by the decoder (ignored).
        obj: The item to be decoded.

    Returns:
        The string form of ``obj``, regardless of the requested type.
    """
    decoded = str(obj)
    return decoded


# Shared module-level encoder, used by MsgSpecWriter.serialize_message.
# Objects without native support are routed through enc_hook.
encoder = msgspec.json.Encoder(enc_hook=enc_hook, decimal_format="number")
# Shared module-level decoder, used by MsgSpecReader.deserialize_json.
# Each input line is decoded into one of the tagged message structs below.
# NOTE(review): no BATCH message struct is part of this union -- confirm
# whether BATCH lines are expected to be decoded here.
decoder = msgspec.json.Decoder(
    t.Union[
        RecordMessage,
        SchemaMessage,
        StateMessage,
        ActivateVersionMessage,
    ],
    dec_hook=dec_hook,
    # JSON floats are handed to decimal.Decimal instead of becoming floats.
    float_hook=decimal.Decimal,
)


class MsgSpecReader(GenericSingerReader[str]):
    """Singer message reader that decodes text lines with msgspec.

    Input defaults to ``sys.stdin``.
    """

    default_input = sys.stdin

    def deserialize_json(self, line: str) -> dict:  # noqa: PLR6301
        """Decode one line of JSON into a message dictionary.

        Args:
            line: A single line of json.

        Returns:
            A dictionary of the deserialized json.

        Raises:
            InvalidInputLine: If the line cannot be parsed
        """
        # Only the decode call can raise DecodeError, so it alone is guarded.
        try:
            message = decoder.decode(line)
        except msgspec.DecodeError as exc:
            logger.exception("Unable to parse:\n%s", line)
            msg = f"Unable to parse line as JSON: {line}"
            raise InvalidInputLine(msg) from exc

        return message.to_dict()  # type: ignore[no-any-return]

Check warning on line 148 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L143-L148

Added lines #L143 - L148 were not covered by tests


class MsgSpecWriter(GenericSingerWriter[bytes, Message]):
    """Singer message writer that emits msgspec-encoded JSON bytes on stdout."""

    def serialize_message(self, message: Message) -> bytes:  # noqa: PLR6301
        """Encode a Singer message as a line of JSON.

        Args:
            message: A Singer message object.

        Returns:
            The JSON-encoded message as bytes.
        """
        payload = encoder.encode(message)
        return payload

    def write_message(self, message: Message) -> None:
        """Write a message to stdout, followed by a newline, and flush.

        Args:
            message: The message to write.
        """
        line = self.format_message(message) + b"\n"
        # The formatted payload is bytes, so it goes to the underlying binary
        # buffer rather than the text layer.
        sys.stdout.buffer.write(line)
        sys.stdout.flush()

Check warning on line 172 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L171-L172

Added lines #L171 - L172 were not covered by tests
41 changes: 41 additions & 0 deletions tests/_singerlib/_encoding/test_msgspec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations # noqa: INP001

import pytest

from singer_sdk._singerlib.encoding._msgspec import dec_hook, enc_hook


@pytest.mark.parametrize(
    "test_type,test_value,expected_value,expected_type",
    [
        pytest.param(
            int,
            1,
            "1",
            str,
            id="int-to-str",
        ),
    ],
)
def test_dec_hook(test_type, test_value, expected_value, expected_type):
    """Check that dec_hook returns the expected value with the exact expected type."""
    decoded = dec_hook(type=test_type, obj=test_value)

    assert decoded == expected_value
    assert type(decoded) is expected_type


@pytest.mark.parametrize(
    "test_value,expected_value",
    [
        pytest.param(
            1,
            "1",
            id="int-to-str",
        ),
    ],
)
def test_enc_hook(test_value, expected_value):
    """Check that enc_hook converts the value to the expected encoded form."""
    encoded = enc_hook(obj=test_value)
    assert encoded == expected_value
20 changes: 19 additions & 1 deletion tests/core/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pytest

from singer_sdk._singerlib import RecordMessage
from singer_sdk._singerlib.encoding._msgspec import MsgSpecReader, MsgSpecWriter
from singer_sdk._singerlib.exceptions import InvalidInputLine
from singer_sdk.io_base import SingerReader, SingerWriter

Expand Down Expand Up @@ -104,6 +105,7 @@ def test_write_message():
def bench_record():
return {
"stream": "users",
"type": "RECORD",
"record": {
"Id": 1,
"created_at": "2021-01-01T00:08:00-07:00",
Expand Down Expand Up @@ -131,7 +133,7 @@ def test_bench_format_message(benchmark, bench_record_message):
"""Run benchmark for Sink._validator method validate."""
number_of_runs = 1000

writer = SingerWriter()
writer = MsgSpecWriter()

def run_format_message():
for record in itertools.repeat(bench_record_message, number_of_runs):
Expand All @@ -144,6 +146,22 @@ def test_bench_deserialize_json(benchmark, bench_encoded_record):
"""Run benchmark for Sink._validator method validate."""
number_of_runs = 1000

class DummyReader(MsgSpecReader):
def _process_activate_version_message(self, message_dict: dict) -> None:
pass

def _process_batch_message(self, message_dict: dict) -> None:
pass

def _process_record_message(self, message_dict: dict) -> None:
pass

def _process_schema_message(self, message_dict: dict) -> None:
pass

def _process_state_message(self, message_dict: dict) -> None:
pass

reader = DummyReader()

def run_deserialize_json():
Expand Down
Loading