diff --git a/pyproject.toml b/pyproject.toml index 620a7e8..fedaaf4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pysaql" -version = "0.8.0" +version = "0.9.0" description = "Python SAQL query builder" authors = ["Jonathan Drake "] license = "BSD-3-Clause" diff --git a/pysaql/__init__.py b/pysaql/__init__.py index 8f16e39..4f4eccc 100644 --- a/pysaql/__init__.py +++ b/pysaql/__init__.py @@ -1,3 +1,3 @@ """Python SAQL query builder""" -__version__ = "0.8.0" +__version__ = "0.9.0" diff --git a/pysaql/expression.py b/pysaql/expression.py index 8c612f9..07d80ec 100644 --- a/pysaql/expression.py +++ b/pysaql/expression.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod +from copy import deepcopy from typing import Optional from .util import escape_identifier @@ -19,15 +20,19 @@ class Expression(ABC): def alias(self, name: str) -> "Expression": """Set the alias name + This creates and returns a new expression object so a single field can be + aliased multiple times. + Args: name: Alias name Returns: - self + new expression object with alias """ - self._alias = name - return self + new_expr = deepcopy(self) + new_expr._alias = name + return new_expr @abstractmethod def to_string(self) -> str: diff --git a/pysaql/function.py b/pysaql/function.py index 378688e..8322ac9 100644 --- a/pysaql/function.py +++ b/pysaql/function.py @@ -147,6 +147,22 @@ def __init__(self, n: Scalar, m: int) -> None: super().__init__(n, m) +class sign(Function): + """Returns sign of the number (-1, 0, 1) + + See: https://developer.salesforce.com/docs/atlas.en-us.bi_dev_guide_saql.meta/bi_dev_guide_saql/bi_saql_functions_math_sign.htm + """ + + def __init__(self, value: Scalar) -> None: + """Returns sign of the number (-1, 0, 1) + + Args: + value: Number to determine the sign + + """ + super().__init__(value) + + class trunc(Function): """Returns the value of the numeric expression n truncated to m decimal places diff --git a/pysaql/scalar.py b/pysaql/scalar.py index a6e0486..80176f0 100644 --- a/pysaql/scalar.py +++ b/pysaql/scalar.py @@ -4,7 +4,7 @@ from abc import ABC import operator -from typing import Any, Callable, Sequence, Union +from typing import Any, Callable, Optional, Protocol, Sequence, Union from .expression import Expression from .util import escape_identifier, stringify @@ -271,22 +271,39 @@ def to_string(self) -> str: return f"{OPERATOR_STRINGS[self.op]} {stringify(self.value)}" +class StreamProtocol(Protocol): + """Protocol definition for a stream interface + + This is defined to prevent recursive dependencies + """ + + @property + def ref(self) -> str: + """Stream reference in the SAQL query""" + pass + + class field(Scalar): """Represents a field (column) in the data stream""" - def __init__(self, name: str) -> None: + def __init__(self, name: str, stream: Optional[StreamProtocol] = None) -> None: """Represents a field (column) in the data stream Args: name: Name of the field + stream: Optional stream. Providing a stream indicates the field + reference string should include a stream prefix to distinguish them from + fields in other streams. """ super().__init__() self.name = name + self.stream = stream def to_string(self) -> str: """Cast the field to a string""" - return escape_identifier(self.name) + prefix = f"{self.stream.ref}." if self.stream else "" + return prefix + escape_identifier(self.name) class literal(Scalar): diff --git a/pysaql/stream.py b/pysaql/stream.py index 68c98d2..acbfbb2 100644 --- a/pysaql/stream.py +++ b/pysaql/stream.py @@ -82,6 +82,18 @@ def add_statement(self, statement: StreamStatement) -> None: """ self._statements.append(statement) + def field(self, name: str) -> field: + """Create a new field object scoped to this stream + + Args: + name: Name of the field + + Returns: + field object + + """ + return field(name, stream=self) + def foreach(self, *fields: Scalar) -> Stream: """Applies a set of expressions to every row in a dataset. @@ -156,9 +168,9 @@ def limit(self, limit: int) -> Stream: def fill( self, - date_cols: Sequence[field], + date_cols: Sequence[Scalar], date_type_string: FillDateTypeString, - partition: Optional[field] = None, + partition: Optional[Scalar] = None, ) -> Stream: """Fills missing date values by adding rows in data stream @@ -392,9 +404,9 @@ class FillStatement(StreamStatement): def __init__( self, stream: Stream, - date_cols: Sequence[field], + date_cols: Sequence[Scalar], date_type_string: FillDateTypeString, - partition: Optional[field] = None, + partition: Optional[Scalar] = None, ) -> None: """Initializer @@ -440,7 +452,8 @@ def load(name: str) -> Stream: def cogroup( - *streams: Tuple[Stream, Scalar], join_type: JoinType = JoinType.inner + *streams: Tuple[Stream, Union[Scalar, Sequence[Scalar], str]], + join_type: JoinType = JoinType.inner, ) -> Stream: """Combine data from two or more data streams into a single data stream diff --git a/tests/unit/test_stream.py b/tests/unit/test_stream.py index 675a1fb..4b82220 100644 --- a/tests/unit/test_stream.py +++ b/tests/unit/test_stream.py @@ -90,6 +90,20 @@ def test_foreach(): assert str(stream) == "q0 = foreach q0 generate 'name', 'number' as 'n';" +def test_foreach__cogroup(): + """Should generate field projections on top of a cogroup""" + q0 = load("q0_dataset") + q1 = load("q1_dataset") + c0 = cogroup((q0, [field("a"), field("b")]), (q1, [field("a"), field("b")])) + c0.foreach(q0.field("a"), q1.field("b")) + assert str(c0).split("\n") == [ + """q0 = load "q0_dataset";""", + """q1 = load "q1_dataset";""", + """q2 = cogroup q0 by ('a', 'b'), q1 by ('a', 'b');""", + """q2 = foreach q2 generate q0.'a', q1.'b';""", + ] + + def test_group__all(): """Should group by all when no fields are provided""" stream = Stream()