Skip to content

Commit

Permalink
refactor: Allow loading stream schemas from `importlib.resources.abc.…
Browse files Browse the repository at this point in the history
…Traversable` types
  • Loading branch information
edgarrmondragon committed Jan 3, 2024
1 parent b442647 commit a131a52
Show file tree
Hide file tree
Showing 11 changed files with 60 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,38 @@
{% if cookiecutter.auth_method in ("OAuth2", "JWT") -%}
import sys
{% endif -%}
from pathlib import Path
from typing import Any, Callable, Iterable

import requests
{% if cookiecutter.auth_method == "API Key" -%}
from singer_sdk.authenticators import APIKeyAuthenticator
from singer_sdk.helpers.compat import importlib_resources
from singer_sdk.helpers.jsonpath import extract_jsonpath
from singer_sdk.pagination import BaseAPIPaginator # noqa: TCH002
from singer_sdk.streams import {{ cookiecutter.stream_type }}Stream

{% elif cookiecutter.auth_method == "Bearer Token" -%}
from singer_sdk.authenticators import BearerTokenAuthenticator
from singer_sdk.helpers.compat import importlib_resources
from singer_sdk.helpers.jsonpath import extract_jsonpath
from singer_sdk.pagination import BaseAPIPaginator # noqa: TCH002
from singer_sdk.streams import {{ cookiecutter.stream_type }}Stream

{% elif cookiecutter.auth_method == "Basic Auth" -%}
from singer_sdk.authenticators import BasicAuthenticator
from singer_sdk.helpers.compat import importlib_resources
from singer_sdk.helpers.jsonpath import extract_jsonpath
from singer_sdk.pagination import BaseAPIPaginator # noqa: TCH002
from singer_sdk.streams import {{ cookiecutter.stream_type }}Stream

{% elif cookiecutter.auth_method == "Custom or N/A" -%}
from singer_sdk.helpers.compat import importlib_resources
from singer_sdk.helpers.jsonpath import extract_jsonpath
from singer_sdk.pagination import BaseAPIPaginator # noqa: TCH002
from singer_sdk.streams import {{ cookiecutter.stream_type }}Stream

{% elif cookiecutter.auth_method in ("OAuth2", "JWT") -%}
from singer_sdk.helpers.compat import importlib_resources
from singer_sdk.helpers.jsonpath import extract_jsonpath
from singer_sdk.pagination import BaseAPIPaginator # noqa: TCH002
from singer_sdk.streams import {{ cookiecutter.stream_type }}Stream
Expand All @@ -50,7 +54,9 @@
{% endif -%}

_Auth = Callable[[requests.PreparedRequest], requests.PreparedRequest]
SCHEMAS_DIR = Path(__file__).parent / Path("./schemas")

# TODO: Delete this is if not using json files for schema definition
SCHEMAS_DIR = importlib_resources.files(__package__) / "schemas"


class {{ cookiecutter.source_name }}Stream({{ cookiecutter.stream_type }}Stream):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from __future__ import annotations

import typing as t
from pathlib import Path

from singer_sdk import typing as th # JSON Schema typing helpers
from singer_sdk.helpers.compat import importlib_resources

from {{ cookiecutter.library_name }}.client import {{ cookiecutter.source_name }}Stream

# TODO: Delete this is if not using json files for schema definition
SCHEMAS_DIR = Path(__file__).parent / Path("./schemas")
SCHEMAS_DIR = importlib_resources.files(__package__) / "schemas"


{%- if cookiecutter.stream_type == "GraphQL" %}
Expand Down
4 changes: 2 additions & 2 deletions samples/sample_tap_countries/countries_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from __future__ import annotations

import abc
from pathlib import Path

from singer_sdk import typing as th
from singer_sdk.helpers.compat import importlib_resources
from singer_sdk.streams.graphql import GraphQLStream

SCHEMAS_DIR = Path(__file__).parent / Path("./schemas")
SCHEMAS_DIR = importlib_resources.files(__package__) / "schemas"


class CountriesAPIStream(GraphQLStream, metaclass=abc.ABCMeta):
Expand Down
5 changes: 2 additions & 3 deletions samples/sample_tap_gitlab/gitlab_graphql_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@

from __future__ import annotations

from pathlib import Path

from singer_sdk.helpers.compat import importlib_resources
from singer_sdk.streams import GraphQLStream

SITE_URL = "https://gitlab.com/graphql"

SCHEMAS_DIR = Path(__file__).parent / Path("./schemas")
SCHEMAS_DIR = importlib_resources.files(__package__) / "schemas"


class GitlabGraphQLStream(GraphQLStream):
Expand Down
4 changes: 2 additions & 2 deletions samples/sample_tap_gitlab/gitlab_rest_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from __future__ import annotations

import typing as t
from pathlib import Path

from singer_sdk.authenticators import SimpleAuthenticator
from singer_sdk.helpers.compat import importlib_resources
from singer_sdk.pagination import SimpleHeaderPaginator
from singer_sdk.streams.rest import RESTStream
from singer_sdk.typing import (
Expand All @@ -17,7 +17,7 @@
StringType,
)

SCHEMAS_DIR = Path(__file__).parent / Path("./schemas")
SCHEMAS_DIR = importlib_resources.files(__package__) / "schemas"

DEFAULT_URL_BASE = "https://gitlab.com/api/v4"

Expand Down
4 changes: 2 additions & 2 deletions samples/sample_tap_google_analytics/ga_tap_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

import datetime
import typing as t
from pathlib import Path

from singer_sdk.authenticators import OAuthJWTAuthenticator
from singer_sdk.helpers.compat import importlib_resources
from singer_sdk.streams import RESTStream

GOOGLE_OAUTH_ENDPOINT = "https://oauth2.googleapis.com/token"
GA_OAUTH_SCOPES = "https://www.googleapis.com/auth/analytics.readonly"
SCHEMAS_DIR = Path(__file__).parent / Path("./schemas")
SCHEMAS_DIR = importlib_resources.files(__package__) / "schemas"


class GoogleJWTAuthenticator(OAuthJWTAuthenticator):
Expand Down
6 changes: 0 additions & 6 deletions singer_sdk/helpers/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@
else:
from importlib.metadata import entry_points

if sys.version_info < (3, 9):
import importlib_resources as resources
else:
from importlib import resources

if sys.version_info < (3, 11):
from backports.datetime_fromisoformat import MonkeyPatch

Expand All @@ -34,7 +29,6 @@
__all__ = [
"metadata",
"final",
"resources",
"entry_points",
"datetime_fromisoformat",
"date_fromisoformat",
Expand Down
24 changes: 24 additions & 0 deletions singer_sdk/helpers/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Public compatibility helpers for the SDK."""

from __future__ import annotations

import sys

if sys.version_info < (3, 9):
import importlib_resources
else:
from importlib import resources as importlib_resources


if sys.version_info < (3, 9):
from importlib_resources.abc import Traversable
elif sys.version_info < (3, 12):
from importlib.abc import Traversable
else:
from importlib.resources.abc import Traversable


__all__ = [
"importlib_resources",
"Traversable",
]
7 changes: 4 additions & 3 deletions singer_sdk/streams/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
if t.TYPE_CHECKING:
import logging

from singer_sdk.helpers.compat import Traversable
from singer_sdk.tap_base import Tap

# Replication methods
Expand Down Expand Up @@ -136,7 +137,7 @@ def __init__(
self._replication_key: str | None = None
self._primary_keys: t.Sequence[str] | None = None
self._state_partitioning_keys: list[str] | None = None
self._schema_filepath: Path | None = None
self._schema_filepath: Path | Traversable | None = None
self._metadata: singer.MetadataMapping | None = None
self._mask: singer.SelectionMask | None = None
self._schema: dict
Expand All @@ -160,7 +161,7 @@ def __init__(
raise ValueError(msg)

if self.schema_filepath:
self._schema = json.loads(Path(self.schema_filepath).read_text())
self._schema = json.loads(self.schema_filepath.read_text())

if not self.schema:
msg = (
Expand Down Expand Up @@ -421,7 +422,7 @@ def get_replication_key_signpost(
return utc_now() if self.is_timestamp_replication_key else None

@property
def schema_filepath(self) -> Path | None:
def schema_filepath(self) -> Path | Traversable | None:
"""Get path to schema file.
Returns:
Expand Down
12 changes: 7 additions & 5 deletions singer_sdk/testing/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@
import typing as t
from collections import defaultdict
from contextlib import redirect_stderr, redirect_stdout
from pathlib import Path

from singer_sdk import Tap, Target
from singer_sdk.testing.config import SuiteConfig

if t.TYPE_CHECKING:
from pathlib import Path

from singer_sdk.helpers.compat import Traversable


class SingerTestRunner(metaclass=abc.ABCMeta):
"""Base Singer Test Runner."""
Expand Down Expand Up @@ -197,7 +201,7 @@ def __init__(
target_class: type[Target],
config: dict | None = None,
suite_config: SuiteConfig | None = None,
input_filepath: Path | None = None,
input_filepath: Path | Traversable | None = None,
input_io: io.StringIO | None = None,
**kwargs: t.Any,
) -> None:
Expand Down Expand Up @@ -242,9 +246,7 @@ def target_input(self) -> t.IO[str]:
if self.input_io:
self._input = self.input_io
elif self.input_filepath:
self._input = Path(self.input_filepath).open( # noqa: SIM115
encoding="utf8",
)
self._input = self.input_filepath.open(encoding="utf8")
return t.cast(t.IO[str], self._input)

@target_input.setter
Expand Down
14 changes: 7 additions & 7 deletions singer_sdk/testing/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import contextlib
import typing as t
import warnings
from pathlib import Path

from singer_sdk.helpers._compat import resources
from singer_sdk.helpers.compat import importlib_resources
from singer_sdk.testing import target_test_streams

if t.TYPE_CHECKING:
from singer_sdk.helpers.compat import Traversable
from singer_sdk.streams import Stream

from .config import SuiteConfig
Expand Down Expand Up @@ -322,19 +322,19 @@ def run( # type: ignore[override]
"""
# get input from file
if getattr(self, "singer_filepath", None):
assert Path(
self.singer_filepath,
).exists(), f"Singer file {self.singer_filepath} does not exist."
assert (
self.singer_filepath.is_file()
), f"Singer file {self.singer_filepath} does not exist."
runner.input_filepath = self.singer_filepath
super().run(config, resource, runner)

@property
def singer_filepath(self) -> Path:
def singer_filepath(self) -> Traversable:
"""Get path to singer JSONL formatted messages file.
Files will be sourced from `./target_test_streams/<test name>.singer`.
Returns:
The expected Path to this tests singer file.
"""
return resources.files(target_test_streams).joinpath(f"{self.name}.singer") # type: ignore[no-any-return]
return importlib_resources.files(target_test_streams) / f"{self.name}.singer" # type: ignore[no-any-return]

0 comments on commit a131a52

Please sign in to comment.