feat(sparkle): structurizing the inputs field in pipeline decorator
kkiani committed Aug 7, 2024
1 parent 7d97a11 commit c16c97b
Showing 4 changed files with 50 additions and 27 deletions.
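In short: the `pipeline` decorator no longer accepts a plain `dict[str, str]` of input topics; it now takes a list of structured `InputField` objects, each carrying the input's name, the `DataReader` class that should load it, and reader-specific options. A hedged before/after sketch of a decorated pipeline follows — the `app` registry object, the topic and table names, and the `IcebergReader` class name are placeholders (only `KafkaReader` is named in this diff):

from damavand.sparkle.models import InputField
from damavand.sparkle.data_reader import IcebergReader, KafkaReader

# Before this commit: inputs were an opaque name -> topic mapping.
# @app.pipeline(name="orders", inputs={"kafka_orders": "orders_topic"})

# After this commit: each input is a structured InputField.
@app.pipeline(
    name="orders",
    inputs=[
        InputField("orders", KafkaReader, topic="orders_topic"),
        InputField("customers", IcebergReader, table="customers"),
    ],
)
def build_orders(orders, customers):
    """How the loaded DataFrames reach the function body is not shown in this diff."""
    ...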
10 changes: 5 additions & 5 deletions src/damavand/sparkle/core.py
@@ -2,7 +2,7 @@
 from typing import Optional, Callable
 from pyspark.sql import SparkSession
 
-from .models import Pipeline, Trigger
+from .models import Pipeline, Trigger, InputField
 from .data_reader import DataReader
 from .data_writer import DataWriter

@@ -21,7 +21,7 @@ def add_pipeline_rule(
         pipeline_name: str,
         description: Optional[str],
         method: Callable,
-        input_topics: dict[str, str],
+        inputs: list[InputField],
     ):
         """Add a trigger rule for the given pipeline."""

@@ -31,19 +31,19 @@
         self.__pipelines[pipeline_name] = Pipeline(
             name=pipeline_name,
             description=description,
-            inputs=input_topics,
+            inputs=inputs,
             func=method,
         )
 
-    def pipeline(self, name: str, inputs: dict[str, str], **options) -> Callable:
+    def pipeline(self, name: str, inputs: list[InputField], **options) -> Callable:
         """A decorator to define an processing job for the given pipeline."""
 
         def decorator(func):
             self.add_pipeline_rule(
                 pipeline_name=name,
                 description=func.__doc__,
                 method=func,
-                input_topics=inputs,
+                inputs=inputs,
             )
 
             return func
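The decorator above just forwards `inputs` into the stored `Pipeline` record, so whatever runs the pipeline can hand the same `InputField` list straight to a `DataReader`. A hedged sketch of that hand-off — the runner function and the `**dataframes` calling convention are assumptions, not part of this diff:

from pyspark.sql import SparkSession


def run_pipeline(pipeline, reader, spark_session: SparkSession):
    # DataReader.read returns dict[str, DataFrame]; here it is assumed to be
    # keyed by the InputField names declared on the pipeline.
    dataframes = reader.read(pipeline.inputs, spark_session)
    return pipeline.func(**dataframes)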
44 changes: 24 additions & 20 deletions src/damavand/sparkle/data_reader.py
@@ -1,11 +1,12 @@
 from dataclasses import dataclass
-from typing import Any, Protocol
+from typing import Protocol
 from pyspark.sql import DataFrame, SparkSession
 
+from .models import InputField
 
 
 class DataReader(Protocol):
     def read(
-        self, inputs: dict[str, Any], spark_session: SparkSession
+        self, inputs: list[InputField], spark_session: SparkSession
     ) -> dict[str, DataFrame]: ...


@@ -15,7 +16,7 @@ def __init__(self, database_name: str) -> None:
         self.database_name = database_name
 
     def read(
-        self, inputs: dict[str, Any], spark_session: SparkSession
+        self, inputs: list[InputField], spark_session: SparkSession
     ) -> dict[str, DataFrame]:
         raise NotImplementedError

@@ -33,7 +34,7 @@ def __init__(
         self.server_name = server_name
 
     def read(
-        self, inputs: dict[str, Any], spark_session: SparkSession
+        self, inputs: list[InputField], spark_session: SparkSession
     ) -> dict[str, DataFrame]:
         raise NotImplementedError

@@ -44,31 +45,34 @@ def __init__(self, server: str) -> None:
         self.server = server
 
     def read(
-        self, inputs: dict[str, Any], spark_session: SparkSession
+        self, inputs: list[InputField], spark_session: SparkSession
     ) -> dict[str, DataFrame]:
-        raise NotImplementedError
+        for input in inputs:
+            if _ := input.options.get("topic"):
+                raise NotImplementedError
+            else:
+                raise ValueError(
+                    "Option `topic` must be provided in the `InputField` with KafkaReader types."
+                )
 
 
-@dataclass
-class PrefixedDataReader:
-    reader: DataReader
-    prefix: str
-    raise NotImplementedError
 
 
-class MultiReader(DataReader):
-    def __init__(self, *readers: PrefixedDataReader) -> None:
+class MultiDataReader(DataReader):
+    def __init__(self, *readers: DataReader) -> None:
         super().__init__()
         self._readers = list(readers)
 
     def read(
-        self, inputs: dict[str, Any], spark_session: SparkSession
+        self, inputs: list[InputField], spark_session: SparkSession
     ) -> dict[str, DataFrame]:
         dataframes = {}
 
-        for key in inputs.keys():
-            for prefixed_reader in self._readers:
-                if key.startswith(prefixed_reader.prefix):
-                    dataframes[key] = prefixed_reader.reader.read(inputs, spark_session)
-                    break
+        for reader in self._readers:
+            reader_inputs = [
+                input for input in inputs if input.type == reader.__class__
+            ]
+            dfs = reader.read(reader_inputs, spark_session)
+            dataframes.update(dfs)
 
         return dataframes
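The new `MultiDataReader` no longer matches input keys against a prefix; it routes each `InputField` to the reader whose concrete class equals `InputField.type`. A self-contained illustration of that dispatch rule, with stub readers and plain strings standing in for real DataFrames (all names here are invented for the example):

from typing import Any


class FakeKafkaReader:
    def read(self, inputs, spark_session):
        return {i.name: f"kafka:{i.options['topic']}" for i in inputs}


class FakeIcebergReader:
    def read(self, inputs, spark_session):
        return {i.name: f"iceberg:{i.options['table']}" for i in inputs}


class InputField:
    def __init__(self, name: str, type: type, **options: Any) -> None:
        self.name, self.type, self.options = name, type, options


def multi_read(readers, inputs, spark_session=None):
    # Same grouping rule as MultiDataReader.read: an input is handled by the
    # reader whose class matches InputField.type.
    dataframes = {}
    for reader in readers:
        reader_inputs = [i for i in inputs if i.type == reader.__class__]
        dataframes.update(reader.read(reader_inputs, spark_session))
    return dataframes


print(
    multi_read(
        [FakeKafkaReader(), FakeIcebergReader()],
        [
            InputField("orders", FakeKafkaReader, topic="orders_topic"),
            InputField("customers", FakeIcebergReader, table="db.customers"),
        ],
    )
)
# -> {'orders': 'kafka:orders_topic', 'customers': 'iceberg:db.customers'}

Note that with this dispatch an `InputField` whose `type` matches none of the configured readers is silently dropped, and each reader only sees the fields addressed to it.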
10 changes: 10 additions & 0 deletions src/damavand/sparkle/data_writer.py
@@ -6,6 +6,16 @@ class DataWriter(Protocol):
     def write(self, df: DataFrame, spark_session: SparkSession) -> None: ...
 
 
+class KafkaWriter(DataWriter):
+    def write(self, df: DataFrame, spark_session: SparkSession) -> None:
+        df.write.format("kafka").save()
+
+
+class IcebergWriter(DataWriter):
+    def write(self, df: DataFrame, spark_session: SparkSession) -> None:
+        df.write.format("iceberg").save()
+
+
 class MultiDataWriter(DataWriter):
     def __init__(self, *writers: DataWriter) -> None:
         super().__init__()
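`KafkaWriter` and `IcebergWriter` above call `.save()` with no configuration; in practice Spark's Kafka sink needs at least a bootstrap server and a target topic (and a `value` column on the DataFrame), and an Iceberg write needs a table identifier. A hedged sketch of fuller write calls — the broker address, topic, and catalog/table names are placeholders:

from pyspark.sql import DataFrame


def write_orders_to_kafka(df: DataFrame) -> None:
    # The DataFrame is expected to expose a `value` column (and optionally
    # `key`/`topic` columns) for the Kafka sink.
    (
        df.write.format("kafka")
        .option("kafka.bootstrap.servers", "localhost:9092")
        .option("topic", "orders_topic")
        .save()
    )


def write_orders_to_iceberg(df: DataFrame) -> None:
    # DataFrameWriterV2 append to an existing Iceberg table.
    df.writeTo("glue_catalog.db.orders").append()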
13 changes: 11 additions & 2 deletions src/damavand/sparkle/models.py
@@ -1,8 +1,10 @@
 import os
 from enum import Enum, StrEnum
-from typing import Optional, Callable, Any
+from typing import Optional, Callable, Any, Type
 from dataclasses import dataclass
 
+from damavand.sparkle.data_reader import DataReader
+
 
 class Environment(StrEnum):
     PRODUCTION = "production"
@@ -45,12 +47,19 @@ def all_values(cls) -> list[str]:
         return [e.value for e in cls]
 
 
+class InputField:
+    def __init__(self, name: str, type: Type[DataReader], **options: Any) -> None:
+        self.name = name
+        self.type = type
+        self.options = options
+
+
 @dataclass
 class Pipeline:
     name: str
     func: Callable
     description: Optional[str] = None
-    inputs: dict[str, str] = {}
+    inputs: list[InputField] = []
 
 
 @dataclass
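A small caveat on the `Pipeline` dataclass above: `dataclasses` rejects a plain mutable default such as `[]` at class-definition time, so a runnable variant needs `field(default_factory=list)`. A self-contained sketch of the new models with that adjustment (the `DataReader` bound is relaxed to a bare `type` so the snippet stands alone):

from dataclasses import dataclass, field
from typing import Any, Callable, Optional


class InputField:
    def __init__(self, name: str, type: type, **options: Any) -> None:
        self.name = name
        self.type = type
        self.options = options


@dataclass
class Pipeline:
    name: str
    func: Callable
    description: Optional[str] = None
    inputs: list[InputField] = field(default_factory=list)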
