Skip to content

Commit

Permalink
Add data source pyspark
Browse files Browse the repository at this point in the history
  • Loading branch information
ichuniq committed Nov 11, 2024
1 parent ceec377 commit 8b6f230
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
17 changes: 17 additions & 0 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ class QueryPostgresDTO(QueryDTO):
connection_info: ConnectionUrl | PostgresConnectionInfo = connection_info_field


class QueryPySparkDTO(QueryDTO):
connection_info: ConnectionUrl | PySparkConnectionInfo = connection_info_field


class QuerySnowflakeDTO(QueryDTO):
connection_info: SnowflakeConnectionInfo = connection_info_field

Expand Down Expand Up @@ -109,6 +113,18 @@ class PostgresConnectionInfo(BaseModel):
password: SecretStr


class PySparkConnectionInfo(BaseModel):
app_name: SecretStr = Field(examples=["wrenai"])
master: SecretStr = Field(
default="local[*]",
description="Spark master URL (e.g., 'local[*]', 'spark://master:7077')"
)
configs: dict[str, str] | None = Field(
default=None,
description="Additional Spark configurations"
)


class SnowflakeConnectionInfo(BaseModel):
user: SecretStr
password: SecretStr
Expand Down Expand Up @@ -137,6 +153,7 @@ class TrinoConnectionInfo(BaseModel):
| MSSqlConnectionInfo
| MySqlConnectionInfo
| PostgresConnectionInfo
| PySparkConnectionInfo
| SnowflakeConnectionInfo
| TrinoConnectionInfo
)
Expand Down
21 changes: 20 additions & 1 deletion ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from json import loads

import ibis
from pyspark.sql import SparkSession
from google.oauth2 import service_account
from ibis import BaseBackend

Expand All @@ -16,13 +17,15 @@
MSSqlConnectionInfo,
MySqlConnectionInfo,
PostgresConnectionInfo,
PySparkConnectionInfo,
QueryBigQueryDTO,
QueryCannerDTO,
QueryClickHouseDTO,
QueryDTO,
QueryMSSqlDTO,
QueryMySqlDTO,
QueryPostgresDTO,
QueryPySparkDTO,
QuerySnowflakeDTO,
QueryTrinoDTO,
SnowflakeConnectionInfo,
Expand All @@ -37,6 +40,7 @@ class DataSource(StrEnum):
mssql = auto()
mysql = auto()
postgres = auto()
pyspark = auto()
snowflake = auto()
trino = auto()

Expand All @@ -60,6 +64,7 @@ class DataSourceExtension(Enum):
mssql = QueryMSSqlDTO
mysql = QueryMySqlDTO
postgres = QueryPostgresDTO
pyspark = QueryPySparkDTO
snowflake = QuerySnowflakeDTO
trino = QueryTrinoDTO

Expand Down Expand Up @@ -142,7 +147,21 @@ def get_postgres_connection(info: PostgresConnectionInfo) -> BaseBackend:
user=info.user.get_secret_value(),
password=info.password.get_secret_value(),
)


@staticmethod
def get_pyspark_connection(info: PysparkConnectionInfo) -> BaseBackend:
builder = SparkSession.builder \
.appName(info.app_name.get_secret_value()) \
.master(info.master.get_secret_value())

if info.configs:
for key, value in info.configs.items():
builder = builder.config(key, value)

# Create or get existing Spark session
spark_session = builder.getOrCreate()
return ibis.pyspark.connect(session=spark_session)

@staticmethod
def get_snowflake_connection(info: SnowflakeConnectionInfo) -> BaseBackend:
return ibis.snowflake.connect(
Expand Down

0 comments on commit 8b6f230

Please sign in to comment.