
Commit

Added 'table' alias in JdbcReader for 'dbtable' param. Formatting
riccamini committed Sep 27, 2024
1 parent 45b6fe6 commit 253f7f7
Showing 4 changed files with 18 additions and 8 deletions.
1 change: 1 addition & 0 deletions src/koheesio/integrations/spark/tableau/hyper.py
@@ -301,6 +301,7 @@ class HyperFileDataFrameWriter(HyperFileWriter):
    hw.hyper_path
    ```
    """

    df: DataFrame = Field(default=..., description="Spark DataFrame to write to the Hyper file")
    table_definition: Optional[TableDefinition] = None  # table_definition is not required for this class

2 changes: 2 additions & 0 deletions src/koheesio/integrations/spark/tableau/server.py
@@ -23,6 +23,7 @@ class TableauServer(Step):
"""
Base class for Tableau server interactions. Class provides authentication and project identification functionality.
"""

url: str = Field(
default=...,
alias="url",
@@ -190,6 +191,7 @@ class TableauHyperPublisher(TableauServer):
"""
Publish the given Hyper file to the Tableau server. Hyper file will be treated by Tableau server as a datasource.
"""

datasource_name: str = Field(default=..., description="Name of the datasource to publish")
hyper_path: PurePath = Field(default=..., description="Path to Hyper file")
publish_mode: TableauHyperPublishMode = Field(
4 changes: 3 additions & 1 deletion src/koheesio/spark/readers/jdbc.py
@@ -67,7 +67,9 @@ class JdbcReader(Reader):
    )
    user: str = Field(default=..., description="User to authenticate to the server")
    password: SecretStr = Field(default=..., description="Password belonging to the username")
    dbtable: Optional[str] = Field(default=None, description="Database table name, also include schema name")
    dbtable: Optional[str] = Field(
        default=None, description="Database table name, also include schema name", alias="table"
    )
    query: Optional[str] = Field(default=None, description="Query")
    options: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Extra options to pass to spark reader")

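With this change, `JdbcReader` (and subclasses such as `SnowflakeReader`) accepts `table` as an alias for the `dbtable` field. A minimal usage sketch, assuming a JDBC driver and connection URL; the driver class, connection string, and credentials below are illustrative placeholders, not taken from this commit:

```python
from koheesio.spark.readers.jdbc import JdbcReader

# Hypothetical connection values for illustration only.
reader = JdbcReader(
    driver="org.postgresql.Driver",  # assumed JDBC driver class
    url="jdbc:postgresql://localhost:5432/mydb",  # assumed connection URL
    user="user",
    password="secret",
    table="public.customers",  # new alias; equivalent to dbtable="public.customers"
)

output = reader.execute()  # same Reader API as exercised in the tests below
df = output.df
```

Because the alias only changes how the field is populated, the attribute is still read as `reader.dbtable` regardless of which keyword was used.
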
19 changes: 12 additions & 7 deletions tests/spark/integrations/snowflake/test_snowflake.py
@@ -39,29 +39,34 @@
"warehouse": "warehouse",
}


def test_snowflake_module_import():
# test that the pass-through imports in the koheesio.spark snowflake modules are working
from koheesio.spark.writers import snowflake as snowflake_readers
from koheesio.spark.readers import snowflake as snowflake_writers
from koheesio.spark.writers import snowflake as snowflake_readers


class TestSnowflakeReader:
reader_options = {"dbtable": "table", **COMMON_OPTIONS}

def test_get_options(self):
sf = SnowflakeReader(**(self.reader_options | {"authenticator": None}))
@pytest.mark.parametrize(
"reader_options", [{"dbtable": "table", **COMMON_OPTIONS}, {"table": "table", **COMMON_OPTIONS}]
)
def test_get_options(self, reader_options):
sf = SnowflakeReader(**(reader_options | {"authenticator": None}))
o = sf.get_options()
assert sf.format == "snowflake"
assert o["sfUser"] == "user"
assert o["sfCompress"] == "on"
assert "authenticator" not in o

def test_execute(self, dummy_spark):
@pytest.mark.parametrize(
"reader_options", [{"dbtable": "table", **COMMON_OPTIONS}, {"table": "table", **COMMON_OPTIONS}]
)
def test_execute(self, dummy_spark, reader_options):
"""Method should be callable from parent class"""
with mock.patch.object(SparkSession, "getActiveSession") as mock_spark:
mock_spark.return_value = dummy_spark

k = SnowflakeReader(**self.reader_options).execute()
k = SnowflakeReader(**reader_options).execute()
assert k.df.count() == 1


