Skip to content

Commit

Permalink
refactor: Implement msgspec encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon committed Jul 17, 2024
1 parent 32c059b commit 750b114
Show file tree
Hide file tree
Showing 6 changed files with 293 additions and 7 deletions.
10 changes: 5 additions & 5 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
def mypy(session: Session) -> None:
"""Check types with mypy."""
args = session.posargs or ["singer_sdk"]
session.install(".[faker,jwt,parquet,s3,testing]")
session.install(".[faker,jwt,msgspec,parquet,s3,testing]")
session.install(*typing_dependencies)
session.run("mypy", *args)
if not session.posargs:
Expand All @@ -61,7 +61,7 @@ def mypy(session: Session) -> None:
@session(python=python_versions)
def tests(session: Session) -> None:
"""Execute pytest tests and compute coverage."""
session.install(".[faker,jwt,parquet,s3]")
session.install(".[faker,jwt,msgspec,parquet,s3]")
session.install(*test_dependencies)

sqlalchemy_version = os.environ.get("SQLALCHEMY_VERSION")
Expand Down Expand Up @@ -94,7 +94,7 @@ def tests(session: Session) -> None:
@session(python=main_python_version)
def benches(session: Session) -> None:
"""Run benchmarks."""
session.install(".[jwt,s3]")
session.install(".[jwt,msgspec,s3]")
session.install(*test_dependencies)
sqlalchemy_version = os.environ.get("SQLALCHEMY_VERSION")
if sqlalchemy_version:
Expand All @@ -114,7 +114,7 @@ def benches(session: Session) -> None:
@session(name="deps", python=python_versions)
def dependencies(session: Session) -> None:
"""Check issues with dependencies."""
session.install(".[s3,testing]")
session.install(".[msgspec,s3,testing]")
session.install("deptry")
session.run("deptry", "singer_sdk", *session.posargs)

Expand All @@ -124,7 +124,7 @@ def update_snapshots(session: Session) -> None:
"""Update pytest snapshots."""
args = session.posargs or ["-m", "snapshot"]

session.install(".[faker,jwt,parquet]")
session.install(".[faker,jwt,msgspec,parquet]")
session.install(*test_dependencies)
session.run("pytest", "--snapshot-update", *args)

Expand Down
55 changes: 54 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ inflection = ">=0.5.1"
joblib = ">=1.3.0"
jsonpath-ng = ">=1.5.3"
jsonschema = ">=4.16.0"
msgspec = { version = ">=0.18.6", optional = true }
packaging = ">=23.1"
pendulum = ">=2.1.0,<4"
python-dateutil = ">=2.8.2"
Expand Down Expand Up @@ -111,6 +112,7 @@ docs = [
"sphinx-notfound-page",
"sphinx-reredirects",
]
msgspec = ["msgspec"]
s3 = ["fs-s3fs"]
testing = [
"pytest",
Expand Down
172 changes: 172 additions & 0 deletions singer_sdk/_singerlib/encoding/_msgspec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from __future__ import annotations

import datetime
import decimal
import logging
import sys
import typing as t

import msgspec

from singer_sdk._singerlib.exceptions import InvalidInputLine

from ._base import GenericSingerReader, GenericSingerWriter

logger = logging.getLogger(__name__)


class Message(msgspec.Struct, tag_field="type", tag=str.upper):
    """Singer base message.

    Subclasses are msgspec tagged structs: the message kind is carried in the
    ``type`` tag rather than in a regular struct field.
    """

    def to_dict(self) -> dict[str, t.Any]:
        """Convert the message to a plain dictionary.

        Returns:
            A dictionary holding the message's tag (under the configured tag
            field, ``type``) followed by every struct field and its value.
        """
        config = self.__struct_config__
        result: dict[str, t.Any] = {}

        # The tag is struct metadata, not a field, so it never shows up in
        # ``__struct_fields__``. Add it explicitly so the dictionary still
        # identifies which kind of message it represents.
        if config.tag_field is not None:
            result[config.tag_field] = config.tag

        result.update((name, getattr(self, name)) for name in self.__struct_fields__)
        return result

Check warning on line 22 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L22

Added line #L22 was not covered by tests


class RecordMessage(Message, tag="RECORD"):
    """Singer RECORD message.

    Carries one record for a stream, with an optional stream version and an
    optional extraction timestamp.
    """

    stream: str
    record: t.Dict[str, t.Any]  # noqa: UP006
    version: t.Union[int, None] = None  # noqa: UP007
    time_extracted: t.Union[datetime.datetime, None] = None  # noqa: UP007

    def __post_init__(self) -> None:
        """Validate and normalize ``time_extracted``.

        A timestamp without time zone information is rejected; any other
        timestamp is converted to UTC.

        Raises:
            ValueError: If the time_extracted is not timezone-aware.
        """
        extracted = self.time_extracted
        if not extracted:
            # Nothing to validate or normalize.
            return

        if not extracted.tzinfo:
            msg = (
                "'time_extracted' must be either None or an aware datetime (with a "
                "time zone)"
            )
            raise ValueError(msg)

        self.time_extracted = extracted.astimezone(datetime.timezone.utc)

Check warning on line 47 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L47

Added line #L47 was not covered by tests


class SchemaMessage(Message, tag="SCHEMA"):
    """Singer SCHEMA message.

    Describes a stream: its JSON schema, its key properties and, optionally,
    its bookmark properties.
    """

    stream: str
    schema: t.Dict[str, t.Any]  # noqa: UP006
    key_properties: t.List[str]  # noqa: UP006
    bookmark_properties: t.Union[t.List[str], None] = None  # noqa: UP006, UP007

    def __post_init__(self) -> None:
        """Normalize and validate ``bookmark_properties``.

        A single string (or bytes) value is wrapped in a one-element list.

        Raises:
            ValueError: If bookmark_properties is not a string or list of strings.
        """
        bookmarks = self.bookmark_properties
        if isinstance(bookmarks, (str, bytes)):
            self.bookmark_properties = [bookmarks]
        elif bookmarks and not isinstance(bookmarks, list):
            # Any other truthy, non-list value is unsupported.
            msg = "bookmark_properties must be a string or list of strings"
            raise ValueError(msg)

Check warning on line 68 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L67-L68

Added lines #L67 - L68 were not covered by tests


class StateMessage(Message, tag="STATE"):
    """Singer STATE message.

    Tagged ``STATE``; its only field is the state payload, a dictionary with
    string keys.
    """

    value: t.Dict[str, t.Any]  # noqa: UP006
    """The state value."""


class ActivateVersionMessage(Message, tag="ACTIVATE_VERSION"):
    """Singer ACTIVATE_VERSION message.

    Tagged ``ACTIVATE_VERSION``; names a stream and the integer version to
    activate for it.
    """

    stream: str
    """The stream name."""

    version: int
    """The version to activate."""


def enc_hook(obj: t.Any) -> t.Any:  # noqa: ANN401
    """Convert an object to a string for the JSON encoder.

    Args:
        obj: The item to be encoded.

    Returns:
        The ISO 8601 representation (with a ``T`` separator) for
        ``datetime.datetime`` instances; ``str(obj)`` for anything else.
    """
    if isinstance(obj, datetime.datetime):
        return obj.isoformat(sep="T")
    return str(obj)


def dec_hook(type: type, obj: t.Any) -> t.Any:  # noqa: ARG001, A002, ANN401
    """Convert a decoded object to its string form.

    Args:
        type: The type the decoder asked for. It is accepted to satisfy the
            hook signature but is not consulted.
        obj: The item to be decoded.

    Returns:
        ``str(obj)``, regardless of the requested type.
    """
    as_text = str(obj)
    return as_text


# Shared module-level JSON encoder. ``enc_hook`` is the fallback for objects the
# encoder does not support natively; ``decimal_format="number"`` asks msgspec to
# write ``decimal.Decimal`` values as JSON numbers rather than strings.
encoder = msgspec.json.Encoder(enc_hook=enc_hook, decimal_format="number")

# Shared module-level JSON decoder. It is typed with the union of the message
# structs below, so msgspec picks the struct from each line's "type" tag.
# NOTE(review): no BATCH message struct is part of this union, so BATCH lines
# presumably fail to decode — confirm whether that is intended.
decoder = msgspec.json.Decoder(
    t.Union[
        RecordMessage,
        SchemaMessage,
        StateMessage,
        ActivateVersionMessage,
    ],
    dec_hook=dec_hook,
    # Build JSON floats as Decimal instead of float.
    float_hook=decimal.Decimal,
)


class MsgSpecReader(GenericSingerReader[str]):
    """Reader that parses Singer message lines with the msgspec decoder.

    Input defaults to ``sys.stdin``.
    """

    default_input = sys.stdin

    def deserialize_json(self, line: str) -> dict:  # noqa: PLR6301
        """Decode one line of JSON into a message dictionary.

        Args:
            line: A single line of json.

        Returns:
            The decoded message converted to a dictionary via ``to_dict()``.

        Raises:
            InvalidInputLine: If the line cannot be parsed
        """
        try:
            message = decoder.decode(line)
        except msgspec.DecodeError as exc:
            # Log the offending line with the traceback before re-raising as
            # the SDK's own exception type.
            logger.exception("Unable to parse:\n%s", line)
            msg = f"Unable to parse line as JSON: {line}"
            raise InvalidInputLine(msg) from exc

        return message.to_dict()

Check warning on line 148 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L143-L148

Added lines #L143 - L148 were not covered by tests


class MsgSpecWriter(GenericSingerWriter[bytes, Message]):
    """Writer that emits msgspec-encoded Singer messages on stdout."""

    def serialize_message(self, message: Message) -> bytes:  # noqa: PLR6301
        """Encode a Singer message as JSON.

        Args:
            message: A Singer message object.

        Returns:
            The JSON-encoded message as bytes.
        """
        encoded: bytes = encoder.encode(message)
        return encoded

    def write_message(self, message: Message) -> None:
        """Write a message to stdout, followed by a newline.

        Args:
            message: The message to write.
        """
        # The payload is bytes, so it goes to the binary buffer underneath the
        # text-mode ``sys.stdout`` wrapper.
        write_bytes = sys.stdout.buffer.write
        write_bytes(self.format_message(message) + b"\n")
        sys.stdout.flush()

Check warning on line 172 in singer_sdk/_singerlib/encoding/_msgspec.py

View check run for this annotation

Codecov / codecov/patch

singer_sdk/_singerlib/encoding/_msgspec.py#L171-L172

Added lines #L171 - L172 were not covered by tests
41 changes: 41 additions & 0 deletions tests/_singerlib/_encoding/test_msgspec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations # noqa: INP001

import pytest

from singer_sdk._singerlib.encoding._msgspec import dec_hook, enc_hook


@pytest.mark.parametrize(
    "test_type,test_value,expected_value,expected_type",
    [
        pytest.param(int, 1, "1", str, id="int-to-str"),
    ],
)
def test_dec_hook(test_type, test_value, expected_value, expected_type):
    """Check the value and the type returned by ``dec_hook``."""
    decoded = dec_hook(type=test_type, obj=test_value)

    assert decoded == expected_value
    assert type(decoded) is expected_type


@pytest.mark.parametrize(
    "test_value,expected_value",
    [
        pytest.param(1, "1", id="int-to-str"),
    ],
)
def test_enc_hook(test_value, expected_value):
    """Check the value returned by ``enc_hook``."""
    encoded = enc_hook(obj=test_value)

    assert encoded == expected_value
20 changes: 19 additions & 1 deletion tests/core/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pytest

from singer_sdk._singerlib import RecordMessage
from singer_sdk._singerlib.encoding._msgspec import MsgSpecReader, MsgSpecWriter
from singer_sdk._singerlib.exceptions import InvalidInputLine
from singer_sdk.io_base import SingerReader, SingerWriter

Expand Down Expand Up @@ -104,6 +105,7 @@ def test_write_message():
def bench_record():
return {
"stream": "users",
"type": "RECORD",
"record": {
"Id": 1,
"created_at": "2021-01-01T00:08:00-07:00",
Expand Down Expand Up @@ -131,7 +133,7 @@ def test_bench_format_message(benchmark, bench_record_message):
"""Run benchmark for Sink._validator method validate."""
number_of_runs = 1000

writer = SingerWriter()
writer = MsgSpecWriter()

def run_format_message():
for record in itertools.repeat(bench_record_message, number_of_runs):
Expand All @@ -144,6 +146,22 @@ def test_bench_deserialize_json(benchmark, bench_encoded_record):
"""Run benchmark for Sink._validator method validate."""
number_of_runs = 1000

class DummyReader(MsgSpecReader):
def _process_activate_version_message(self, message_dict: dict) -> None:
pass

def _process_batch_message(self, message_dict: dict) -> None:
pass

def _process_record_message(self, message_dict: dict) -> None:
pass

def _process_schema_message(self, message_dict: dict) -> None:
pass

def _process_state_message(self, message_dict: dict) -> None:
pass

reader = DummyReader()

def run_deserialize_json():
Expand Down

0 comments on commit 750b114

Please sign in to comment.