From e3c7e41906a5cbfc7b38ee33dfbd0ae4c585ad4e Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Fri, 1 Nov 2024 11:58:47 -0700 Subject: [PATCH] /* PR_START p--cte 05 */ Add `eq=False` to node classes. This sets `eq=False` on the node dataclasses. Nodes represent recursive data structures, so each node instance should be treated as unique (compared by identity), and the explicit equivalence methods (e.g. `matches()` on SQL expression nodes) should be used to check whether two nodes are equivalent. Without this, comparing nodes or using them as keys can be slow because the dataclass-generated `__eq__` (and, for frozen dataclasses, `__hash__`) traverses the recursive structure. `order=True` is dropped from `SqlExpressionNode` because `dataclass` does not allow ordering when `eq=False`. --- .../metricflow_semantics/dag/mf_dag.py | 2 +- metricflow/dataflow/dataflow_plan.py | 2 +- .../dataflow/nodes/add_generated_uuid.py | 2 +- .../dataflow/nodes/aggregate_measures.py | 2 +- .../nodes/combine_aggregated_outputs.py | 2 +- metricflow/dataflow/nodes/compute_metrics.py | 2 +- metricflow/dataflow/nodes/constrain_time.py | 2 +- metricflow/dataflow/nodes/filter_elements.py | 2 +- .../dataflow/nodes/join_conversion_events.py | 2 +- metricflow/dataflow/nodes/join_over_time.py | 2 +- metricflow/dataflow/nodes/join_to_base.py | 2 +- .../nodes/join_to_custom_granularity.py | 2 +- .../dataflow/nodes/join_to_time_spine.py | 2 +- .../dataflow/nodes/metric_time_transform.py | 2 +- metricflow/dataflow/nodes/min_max.py | 2 +- metricflow/dataflow/nodes/order_by_limit.py | 2 +- metricflow/dataflow/nodes/read_sql_source.py | 2 +- .../dataflow/nodes/semi_additive_join.py | 2 +- metricflow/dataflow/nodes/where_filter.py | 2 +- .../nodes/window_reaggregation_node.py | 2 +- .../dataflow/nodes/write_to_data_table.py | 2 +- metricflow/dataflow/nodes/write_to_table.py | 2 +- metricflow/sql/sql_exprs.py | 39 ++++++++++--------- metricflow/sql/sql_plan.py | 12 +++--- 24 files changed, 48 insertions(+), 47 deletions(-) diff --git a/metricflow-semantics/metricflow_semantics/dag/mf_dag.py b/metricflow-semantics/metricflow_semantics/dag/mf_dag.py index a3a24774a2..74b90ca94b 100644 --- a/metricflow-semantics/metricflow_semantics/dag/mf_dag.py +++ 
b/metricflow-semantics/metricflow_semantics/dag/mf_dag.py @@ -57,7 +57,7 @@ def visit_node(self, node: DagNode) -> VisitorOutputT: # noqa: D102 DagNodeT = TypeVar("DagNodeT", bound="DagNode") -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class DagNode(MetricFlowPrettyFormattable, Generic[DagNodeT], ABC): """A node in a DAG. These should be immutable.""" diff --git a/metricflow/dataflow/dataflow_plan.py b/metricflow/dataflow/dataflow_plan.py index b986d7b8f6..dd5f1088b9 100644 --- a/metricflow/dataflow/dataflow_plan.py +++ b/metricflow/dataflow/dataflow_plan.py @@ -44,7 +44,7 @@ NodeSelfT = TypeVar("NodeSelfT", bound="DataflowPlanNode") -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class DataflowPlanNode(DagNode["DataflowPlanNode"], Visitable, ABC): """A node in the graph representation of the dataflow. diff --git a/metricflow/dataflow/nodes/add_generated_uuid.py b/metricflow/dataflow/nodes/add_generated_uuid.py index 6a5a1c2b9f..96df59388d 100644 --- a/metricflow/dataflow/nodes/add_generated_uuid.py +++ b/metricflow/dataflow/nodes/add_generated_uuid.py @@ -10,7 +10,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class AddGeneratedUuidColumnNode(DataflowPlanNode): """Adds a UUID column.""" diff --git a/metricflow/dataflow/nodes/aggregate_measures.py b/metricflow/dataflow/nodes/aggregate_measures.py index 9128f6a0a6..7fa8c153ce 100644 --- a/metricflow/dataflow/nodes/aggregate_measures.py +++ b/metricflow/dataflow/nodes/aggregate_measures.py @@ -10,7 +10,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class AggregateMeasuresNode(DataflowPlanNode): """A node that aggregates the measures by the associated group by elements. 
diff --git a/metricflow/dataflow/nodes/combine_aggregated_outputs.py b/metricflow/dataflow/nodes/combine_aggregated_outputs.py index 0f022ec8a2..c1c3ad2e3f 100644 --- a/metricflow/dataflow/nodes/combine_aggregated_outputs.py +++ b/metricflow/dataflow/nodes/combine_aggregated_outputs.py @@ -12,7 +12,7 @@ ) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class CombineAggregatedOutputsNode(DataflowPlanNode): """Combines metrics from different nodes into a single output.""" diff --git a/metricflow/dataflow/nodes/compute_metrics.py b/metricflow/dataflow/nodes/compute_metrics.py index 9d4ad3dd92..3220d4374d 100644 --- a/metricflow/dataflow/nodes/compute_metrics.py +++ b/metricflow/dataflow/nodes/compute_metrics.py @@ -16,7 +16,7 @@ ) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class ComputeMetricsNode(DataflowPlanNode): """A node that computes metrics from input measures. Dimensions / entities are passed through. diff --git a/metricflow/dataflow/nodes/constrain_time.py b/metricflow/dataflow/nodes/constrain_time.py index 7ca0ace50b..9694039a44 100644 --- a/metricflow/dataflow/nodes/constrain_time.py +++ b/metricflow/dataflow/nodes/constrain_time.py @@ -12,7 +12,7 @@ from metricflow.dataflow.nodes.aggregate_measures import DataflowPlanNode -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class ConstrainTimeRangeNode(DataflowPlanNode): """Constrains the time range of the input data set. diff --git a/metricflow/dataflow/nodes/filter_elements.py b/metricflow/dataflow/nodes/filter_elements.py index e38ef4e0c2..9605f58bdd 100644 --- a/metricflow/dataflow/nodes/filter_elements.py +++ b/metricflow/dataflow/nodes/filter_elements.py @@ -12,7 +12,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class FilterElementsNode(DataflowPlanNode): """Only passes the listed elements. 
diff --git a/metricflow/dataflow/nodes/join_conversion_events.py b/metricflow/dataflow/nodes/join_conversion_events.py index ed42661222..cc6530996b 100644 --- a/metricflow/dataflow/nodes/join_conversion_events.py +++ b/metricflow/dataflow/nodes/join_conversion_events.py @@ -16,7 +16,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class JoinConversionEventsNode(DataflowPlanNode): """Builds a data set containing successful conversion events. diff --git a/metricflow/dataflow/nodes/join_over_time.py b/metricflow/dataflow/nodes/join_over_time.py index b137175472..92087cf97b 100644 --- a/metricflow/dataflow/nodes/join_over_time.py +++ b/metricflow/dataflow/nodes/join_over_time.py @@ -14,7 +14,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class JoinOverTimeRangeNode(DataflowPlanNode): """A node that allows for cumulative metric computation by doing a self join across a cumulative date range. diff --git a/metricflow/dataflow/nodes/join_to_base.py b/metricflow/dataflow/nodes/join_to_base.py index 2a6b2c458e..50791c580b 100644 --- a/metricflow/dataflow/nodes/join_to_base.py +++ b/metricflow/dataflow/nodes/join_to_base.py @@ -43,7 +43,7 @@ def __post_init__(self) -> None: # noqa: D105 raise RuntimeError("`join_on_entity` is required unless using CROSS JOIN.") -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class JoinOnEntitiesNode(DataflowPlanNode): """A node that joins data from other nodes via the entities in the inputs. 
diff --git a/metricflow/dataflow/nodes/join_to_custom_granularity.py b/metricflow/dataflow/nodes/join_to_custom_granularity.py index 6f13f6ece4..2e70b36037 100644 --- a/metricflow/dataflow/nodes/join_to_custom_granularity.py +++ b/metricflow/dataflow/nodes/join_to_custom_granularity.py @@ -12,7 +12,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class JoinToCustomGranularityNode(DataflowPlanNode, ABC): """Join parent dataset to time spine dataset to convert time dimension to a custom granularity. diff --git a/metricflow/dataflow/nodes/join_to_time_spine.py b/metricflow/dataflow/nodes/join_to_time_spine.py index 00633a0fa0..dfc0f10151 100644 --- a/metricflow/dataflow/nodes/join_to_time_spine.py +++ b/metricflow/dataflow/nodes/join_to_time_spine.py @@ -17,7 +17,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class JoinToTimeSpineNode(DataflowPlanNode, ABC): """Join parent dataset to time spine dataset. diff --git a/metricflow/dataflow/nodes/metric_time_transform.py b/metricflow/dataflow/nodes/metric_time_transform.py index 47e5df2ffd..5687904d76 100644 --- a/metricflow/dataflow/nodes/metric_time_transform.py +++ b/metricflow/dataflow/nodes/metric_time_transform.py @@ -11,7 +11,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class MetricTimeDimensionTransformNode(DataflowPlanNode): """A node transforms the input data set so that it contains the metric time dimension and relevant measures. 
diff --git a/metricflow/dataflow/nodes/min_max.py b/metricflow/dataflow/nodes/min_max.py index 40fa160739..c7713185f5 100644 --- a/metricflow/dataflow/nodes/min_max.py +++ b/metricflow/dataflow/nodes/min_max.py @@ -9,7 +9,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class MinMaxNode(DataflowPlanNode): """Calculate the min and max of a single instance data set.""" diff --git a/metricflow/dataflow/nodes/order_by_limit.py b/metricflow/dataflow/nodes/order_by_limit.py index 0bb1c77b99..f7cbacdf0c 100644 --- a/metricflow/dataflow/nodes/order_by_limit.py +++ b/metricflow/dataflow/nodes/order_by_limit.py @@ -14,7 +14,7 @@ ) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class OrderByLimitNode(DataflowPlanNode): """A node that re-orders the input data with a limit. diff --git a/metricflow/dataflow/nodes/read_sql_source.py b/metricflow/dataflow/nodes/read_sql_source.py index de1da2f604..57a272dffc 100644 --- a/metricflow/dataflow/nodes/read_sql_source.py +++ b/metricflow/dataflow/nodes/read_sql_source.py @@ -15,7 +15,7 @@ from metricflow.dataset.sql_dataset import SqlDataSet -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class ReadSqlSourceNode(DataflowPlanNode): """A source node where data from an SQL table or SQL query is read and output. diff --git a/metricflow/dataflow/nodes/semi_additive_join.py b/metricflow/dataflow/nodes/semi_additive_join.py index 3eaff4d88c..1334cde336 100644 --- a/metricflow/dataflow/nodes/semi_additive_join.py +++ b/metricflow/dataflow/nodes/semi_additive_join.py @@ -13,7 +13,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SemiAdditiveJoinNode(DataflowPlanNode): """A node that performs a row filter by aggregating a given non-additive dimension. 
diff --git a/metricflow/dataflow/nodes/where_filter.py b/metricflow/dataflow/nodes/where_filter.py index 7b0bef6cda..1152376a1d 100644 --- a/metricflow/dataflow/nodes/where_filter.py +++ b/metricflow/dataflow/nodes/where_filter.py @@ -11,7 +11,7 @@ from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class WhereConstraintNode(DataflowPlanNode): """Remove rows using a WHERE clause. diff --git a/metricflow/dataflow/nodes/window_reaggregation_node.py b/metricflow/dataflow/nodes/window_reaggregation_node.py index 93dbf950ba..3bbe202c9f 100644 --- a/metricflow/dataflow/nodes/window_reaggregation_node.py +++ b/metricflow/dataflow/nodes/window_reaggregation_node.py @@ -17,7 +17,7 @@ from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class WindowReaggregationNode(DataflowPlanNode): """A node that re-aggregates metrics using window functions. diff --git a/metricflow/dataflow/nodes/write_to_data_table.py b/metricflow/dataflow/nodes/write_to_data_table.py index 39f6eb0fb0..66701d7399 100644 --- a/metricflow/dataflow/nodes/write_to_data_table.py +++ b/metricflow/dataflow/nodes/write_to_data_table.py @@ -12,7 +12,7 @@ ) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class WriteToResultDataTableNode(DataflowPlanNode): """A node where incoming data gets written to a data_table.""" diff --git a/metricflow/dataflow/nodes/write_to_table.py b/metricflow/dataflow/nodes/write_to_table.py index a17c4bd7a7..7a55a5724d 100644 --- a/metricflow/dataflow/nodes/write_to_table.py +++ b/metricflow/dataflow/nodes/write_to_table.py @@ -13,7 +13,7 @@ ) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class WriteToResultTableNode(DataflowPlanNode): """A node where incoming data gets written to a table. 
diff --git a/metricflow/sql/sql_exprs.py b/metricflow/sql/sql_exprs.py index 775e962a8b..62d8e96874 100644 --- a/metricflow/sql/sql_exprs.py +++ b/metricflow/sql/sql_exprs.py @@ -22,7 +22,7 @@ from typing_extensions import override -@dataclass(frozen=True, order=True) +@dataclass(frozen=True, eq=False) class SqlExpressionNode(DagNode["SqlExpressionNode"], Visitable, ABC): """An SQL expression like my_table.my_column, CONCAT(a, b) or 1 + 1 that evaluates to a value.""" @@ -230,7 +230,7 @@ def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> VisitorOu pass -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlStringExpression(SqlExpressionNode): """An SQL expression in a string format, so it lacks information about the structure. @@ -314,7 +314,7 @@ def as_string_expression(self) -> Optional[SqlStringExpression]: return self -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlStringLiteralExpression(SqlExpressionNode): """A string literal like 'foo'. It shouldn't include delimiters as it should be added during rendering.""" @@ -375,7 +375,7 @@ class SqlColumnReference: column_name: str -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlColumnReferenceExpression(SqlExpressionNode): """An expression that evaluates to the value of a column in one of the sources in the select query. @@ -475,7 +475,7 @@ def from_table_and_column_names(table_alias: str, column_name: str) -> SqlColumn return SqlColumnReferenceExpression.create(SqlColumnReference(table_alias=table_alias, column_name=column_name)) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlColumnAliasReferenceExpression(SqlExpressionNode): """An expression that evaluates to the alias of a column, but is not qualified with a table alias. 
@@ -544,7 +544,7 @@ class SqlComparison(Enum): # noqa: D101 EQUALS = "=" -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlComparisonExpression(SqlExpressionNode): """A comparison using >, <, <=, >=, =. @@ -698,6 +698,7 @@ def from_aggregation_type(aggregation_type: AggregationType) -> SqlFunction: assert_values_exhausted(aggregation_type) +@dataclass(frozen=True, eq=False) class SqlFunctionExpression(SqlExpressionNode): """Denotes a function expression in SQL.""" @@ -723,7 +724,7 @@ def build_expression_from_aggregation_type( return SqlAggregateFunctionExpression.from_aggregation_type(aggregation_type, sql_column_expression) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlAggregateFunctionExpression(SqlFunctionExpression): """An aggregate function expression like SUM(1). @@ -857,7 +858,7 @@ def from_aggregation_parameters(agg_params: MeasureAggregationParameters) -> Sql ) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlPercentileExpression(SqlFunctionExpression): """A percentile aggregation expression. @@ -984,7 +985,7 @@ def suffix(self) -> str: return " ".join(result) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlWindowFunctionExpression(SqlFunctionExpression): """A window function expression like SUM(foo) OVER bar. 
@@ -1101,7 +1102,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 ) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlNullExpression(SqlExpressionNode): """Represents NULL.""" @@ -1151,7 +1152,7 @@ class SqlLogicalOperator(Enum): OR = "OR" -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlLogicalExpression(SqlExpressionNode): """A logical expression like "a AND b AND c".""" @@ -1203,7 +1204,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self.operator == other.operator and self._parents_match(other) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlIsNullExpression(SqlExpressionNode): """An IS NULL expression like "foo IS NULL".""" @@ -1248,7 +1249,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self._parents_match(other) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlSubtractTimeIntervalExpression(SqlExpressionNode): """Represents an interval subtraction from a given timestamp. 
@@ -1313,7 +1314,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self.count == other.count and self.granularity == other.granularity and self._parents_match(other) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlCastToTimestampExpression(SqlExpressionNode): """Cast to the timestamp type like CAST('2020-01-01' AS TIMESTAMP).""" @@ -1360,7 +1361,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self._parents_match(other) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlDateTruncExpression(SqlExpressionNode): """Apply a date trunc to a column like CAST('2020-01-01' AS TIMESTAMP).""" @@ -1411,7 +1412,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self.time_granularity == other.time_granularity and self._parents_match(other) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlExtractExpression(SqlExpressionNode): """Extract a date part from a time expression. @@ -1470,7 +1471,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self.date_part == other.date_part and self._parents_match(other) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlRatioComputationExpression(SqlExpressionNode): """Node for expressing Ratio metrics to allow for appropriate casting to float/double in each engine. @@ -1535,7 +1536,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self._parents_match(other) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlBetweenExpression(SqlExpressionNode): """A BETWEEN clause like `column BETWEEN val1 AND val2`. 
@@ -1600,7 +1601,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self._parents_match(other) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlGenerateUuidExpression(SqlExpressionNode): """Renders a SQL to generate a random UUID, which is non-deterministic.""" diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py index ea4a61756b..52a7390429 100644 --- a/metricflow/sql/sql_plan.py +++ b/metricflow/sql/sql_plan.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlQueryPlanNode(DagNode["SqlQueryPlanNode"], ABC): """Modeling a SQL query plan like a data flow plan as well. @@ -105,7 +105,7 @@ class SqlOrderByDescription: # noqa: D101 desc: bool -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlSelectStatementNode(SqlQueryPlanNode): """Represents an SQL Select statement. @@ -197,7 +197,7 @@ def description(self) -> str: return self._description -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlTableNode(SqlQueryPlanNode): """An SQL table that can go in the FROM clause or the JOIN clause.""" @@ -234,7 +234,7 @@ def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D102 return None -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlSelectQueryFromClauseNode(SqlQueryPlanNode): """An SQL select query that can go in the FROM clause. @@ -271,7 +271,7 @@ def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D102 return None -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlCreateTableAsNode(SqlQueryPlanNode): """An SQL node representing a CREATE TABLE AS statement. @@ -343,7 +343,7 @@ def render_node(self) -> SqlQueryPlanNode: # noqa: D102 return self._render_node -@dataclass(frozen=True) +@dataclass(frozen=True, eq=False) class SqlCteNode(SqlQueryPlanNode): """Represents a single common table expression."""