Skip to content

Commit

Permalink
WIP - bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Sep 27, 2024
1 parent f175b9b commit df66cc9
Show file tree
Hide file tree
Showing 26 changed files with 15,394 additions and 195 deletions.
2 changes: 1 addition & 1 deletion dbt-metricflow/dbt_metricflow/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def query(
limit=limit,
time_constraint_start=start_time,
time_constraint_end=end_time,
where_constraint=where,
where_constraints=[where] if where else None,
order_by_names=order,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _resolve_dependencies(self, saved_query_name: str) -> SavedQueryDependencySe

parse_result = self._query_parser.parse_and_validate_saved_query(
saved_query_parameter=SavedQueryParameter(saved_query_name),
where_filter=None,
where_filters=None,
limit=None,
time_constraint_start=None,
time_constraint_end=None,
Expand Down
48 changes: 29 additions & 19 deletions metricflow-semantics/metricflow_semantics/query/query_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from dbt_semantic_interfaces.type_enums import TimeGranularity

from metricflow_semantics.assert_one_arg import assert_at_most_one_arg_set
from metricflow_semantics.filters.merge_where import merge_to_single_where_filter
from metricflow_semantics.filters.time_constraint import TimeRangeConstraint
from metricflow_semantics.mf_logging.formatting import indent
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
Expand Down Expand Up @@ -93,7 +92,7 @@ def __init__( # noqa: D107
def parse_and_validate_saved_query(
self,
saved_query_parameter: SavedQueryParameter,
where_filter: Optional[WhereFilter],
where_filters: Optional[Sequence[WhereFilter]],
limit: Optional[int],
time_constraint_start: Optional[datetime.datetime],
time_constraint_end: Optional[datetime.datetime],
Expand All @@ -107,19 +106,19 @@ def parse_and_validate_saved_query(
saved_query = self._get_saved_query(saved_query_parameter)

# Merge interface could streamline this.
where_filters: List[WhereFilter] = []
parsed_where_filters: List[WhereFilter] = []
if saved_query.query_params.where is not None:
where_filters.extend(saved_query.query_params.where.where_filters)
if where_filter is not None:
where_filters.append(where_filter)
parsed_where_filters.extend(saved_query.query_params.where.where_filters)
if where_filters is not None:
parsed_where_filters.extend(where_filters)

return self._parse_and_validate_query(
metric_names=saved_query.query_params.metrics,
metrics=None,
group_by_names=saved_query.query_params.group_by,
group_by=None,
where_constraint=merge_to_single_where_filter(PydanticWhereFilterIntersection(where_filters=where_filters)),
where_constraint_str=None,
where_constraints=PydanticWhereFilterIntersection(where_filters=parsed_where_filters),
where_constraint_strs=None,
time_constraint_start=time_constraint_start,
time_constraint_end=time_constraint_end,
limit=limit,
Expand All @@ -129,6 +128,7 @@ def parse_and_validate_saved_query(
)

def _get_saved_query(self, saved_query_parameter: SavedQueryParameter) -> SavedQuery:
        # TODO(review): confirm whether the saved query's where filter is parsed properly in core —
        # it may already be missing by the time it reaches this lookup.
matching_saved_queries = [
saved_query
for saved_query in self._manifest_lookup.semantic_manifest.saved_queries
Expand Down Expand Up @@ -309,8 +309,8 @@ def parse_and_validate_query(
limit: Optional[int] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
where_constraint: Optional[WhereFilter] = None,
where_constraint_str: Optional[str] = None,
where_constraints: Optional[Sequence[WhereFilter]] = None,
where_constraint_strs: Optional[Sequence[str]] = None,
order_by_names: Optional[Sequence[str]] = None,
order_by: Optional[Sequence[OrderByQueryParameter]] = None,
min_max_only: bool = False,
Expand All @@ -329,8 +329,8 @@ def parse_and_validate_query(
limit=limit,
time_constraint_start=time_constraint_start,
time_constraint_end=time_constraint_end,
where_constraint=where_constraint,
where_constraint_str=where_constraint_str,
where_constraints=where_constraints,
where_constraint_strs=where_constraint_strs,
order_by_names=order_by_names,
order_by=order_by,
min_max_only=min_max_only,
Expand All @@ -346,8 +346,8 @@ def _parse_and_validate_query(
limit: Optional[int],
time_constraint_start: Optional[datetime.datetime],
time_constraint_end: Optional[datetime.datetime],
where_constraint: Optional[WhereFilter],
where_constraint_str: Optional[str],
where_constraints: Optional[Sequence[WhereFilter]],
where_constraint_strs: Optional[Sequence[str]],
order_by_names: Optional[Sequence[str]],
order_by: Optional[Sequence[OrderByQueryParameter]],
min_max_only: bool,
Expand All @@ -357,7 +357,7 @@ def _parse_and_validate_query(
assert_at_most_one_arg_set(metric_names=metric_names, metrics=metrics)
assert_at_most_one_arg_set(group_by_names=group_by_names, group_by=group_by)
assert_at_most_one_arg_set(order_by_names=order_by_names, order_by=order_by)
assert_at_most_one_arg_set(where_constraint=where_constraint, where_constraint_str=where_constraint_str)
assert_at_most_one_arg_set(where_constraints=where_constraints, where_constraint_strs=where_constraint_strs)

metric_names = metric_names or ()
metrics = metrics or ()
Expand Down Expand Up @@ -455,10 +455,20 @@ def _parse_and_validate_query(

where_filters: List[PydanticWhereFilter] = []

if where_constraint is not None:
where_filters.append(PydanticWhereFilter(where_sql_template=where_constraint.where_sql_template))
if where_constraint_str is not None:
where_filters.append(PydanticWhereFilter(where_sql_template=where_constraint_str))
if where_constraints is not None:
where_filters.extend(
[
PydanticWhereFilter(where_sql_template=constraint.where_sql_template)
for constraint in where_constraints
]
)
if where_constraint_strs is not None:
where_filters.extend(
[
PydanticWhereFilter(where_sql_template=where_constraint_str)
for where_constraint_str in where_constraint_strs
]
)

resolver_input_for_filter = ResolverInputForQueryLevelWhereFilterIntersection(
where_filter_intersection=PydanticWhereFilterIntersection(where_filters=where_filters)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def test_parse_and_validate_where_constraint_dims(
group_by_names=[MTD],
time_constraint_start=as_datetime("2020-01-15"),
time_constraint_end=as_datetime("2020-02-15"),
where_constraint_str="{{ Dimension('booking__invalid_dim') }} = '1'",
where_constraint_strs=["{{ Dimension('booking__invalid_dim') }} = '1'"],
)

with pytest.raises(InvalidQueryException, match="Error parsing where filter"):
Expand All @@ -337,15 +337,15 @@ def test_parse_and_validate_where_constraint_dims(
group_by_names=[MTD],
time_constraint_start=as_datetime("2020-01-15"),
time_constraint_end=as_datetime("2020-02-15"),
where_constraint_str="{{ Dimension('invalid_format') }} = '1'",
where_constraint_strs=["{{ Dimension('invalid_format') }} = '1'"],
)

result = bookings_query_parser.parse_and_validate_query(
metric_names=["bookings"],
group_by_names=[MTD],
time_constraint_start=as_datetime("2020-01-15"),
time_constraint_end=as_datetime("2020-02-15"),
where_constraint_str="{{ Dimension('booking__is_instant') }} = '1'",
where_constraint_strs=["{{ Dimension('booking__is_instant') }} = '1'"],
)
assert_object_snapshot_equal(request=request, mf_test_configuration=mf_test_configuration, obj=result)
assert (
Expand All @@ -366,7 +366,7 @@ def test_parse_and_validate_where_constraint_metric_time(
query_parser.parse_and_validate_query(
metric_names=["revenue"],
group_by_names=[MTD],
where_constraint_str="{{ TimeDimension('metric_time', 'day') }} > '2020-01-15'",
where_constraint_strs=["{{ TimeDimension('metric_time', 'day') }} > '2020-01-15'"],
)


Expand Down Expand Up @@ -622,5 +622,5 @@ def test_invalid_group_by_metric(bookings_query_parser: MetricFlowQueryParser) -
"""Tests that a query for an invalid group by metric gives an appropriate group by metric suggestion."""
with pytest.raises(InvalidQueryException, match="Metric\\('bookings', group_by=\\['listing'\\]\\)"):
bookings_query_parser.parse_and_validate_query(
metric_names=("bookings",), where_constraint_str="{{ Metric('listings', ['garbage']) }} > 1"
metric_names=("bookings",), where_constraint_strs=["{{ Metric('listings', ['garbage']) }} > 1"]
)
12 changes: 9 additions & 3 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1715,14 +1715,20 @@ def _build_aggregated_measure_from_measure_source_node(
time_range_constraint=predicate_pushdown_state.time_range_constraint,
offset_window=after_aggregation_time_spine_join_description.offset_window,
offset_to_grain=after_aggregation_time_spine_join_description.offset_to_grain,
            # TODO: Add any constraints on metric_time / agg time dim to the time spine join node, where we can properly filter.
            # TODO: Handle filters on both larger and smaller granularities than what was queried.
            # TODO: Decide how to handle a filter that uses both metric_time and another dimension.
            # filters=metric_input_measure_spec.filter_specs,
)

# Since new rows might have been added due to time spine join, re-apply constraints here. Only re-apply filters
# for specs that are also in the queried specs, since those are the only ones that might have changed after the
# time spine join.
        # Since new rows might have been added due to time spine join, re-apply constraints here. Re-apply filters
        # for specs that were requested in the group by, since those are the only ones that might have changed
        # after the time spine join, and they are the only specs still available after filtering.
queried_filter_specs = [
filter_spec
for filter_spec in metric_input_measure_spec.filter_specs
            # TODO: Currently only applying filters on queried linkable specs, but filters should apply on all
            # linkable specs (or at least metric_time). Consider whether the filter can be split up by AND clause.
if set(filter_spec.linkable_specs).issubset(set(queried_linkable_specs.as_tuple))
]
if len(queried_filter_specs) > 0:
Expand Down
18 changes: 8 additions & 10 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class MetricFlowQueryRequest:
limit: Limit the result to this many rows.
time_constraint_start: Get data for the start of this time range.
time_constraint_end: Get data for the end of this time range.
where_constraint: A SQL string using group by names that can be used like a where clause on the output data.
where_constraints: A sequence of SQL strings that can be used like a where clause on the output data.
order_by_names: metric and group by names to order by. A "-" can be used to specify reverse order e.g. "-ds".
order_by: metric, dimension, or entity objects to order by.
output_table: If specified, output the result data to this table instead of a result data_table.
Expand All @@ -106,7 +106,7 @@ class MetricFlowQueryRequest:
limit: Optional[int] = None
time_constraint_start: Optional[datetime.datetime] = None
time_constraint_end: Optional[datetime.datetime] = None
where_constraint: Optional[str] = None
where_constraints: Optional[Sequence[str]] = None
order_by_names: Optional[Sequence[str]] = None
order_by: Optional[Sequence[OrderByQueryParameter]] = None
min_max_only: bool = False
Expand All @@ -124,7 +124,7 @@ def create_with_random_request_id( # noqa: D102
limit: Optional[int] = None,
time_constraint_start: Optional[datetime.datetime] = None,
time_constraint_end: Optional[datetime.datetime] = None,
where_constraint: Optional[str] = None,
where_constraints: Optional[Sequence[str]] = None,
order_by_names: Optional[Sequence[str]] = None,
order_by: Optional[Sequence[OrderByQueryParameter]] = None,
sql_optimization_level: SqlQueryOptimizationLevel = SqlQueryOptimizationLevel.O4,
Expand All @@ -144,7 +144,7 @@ def create_with_random_request_id( # noqa: D102
limit=limit,
time_constraint_start=time_constraint_start,
time_constraint_end=time_constraint_end,
where_constraint=where_constraint,
where_constraints=where_constraints,
order_by_names=order_by_names,
order_by=order_by,
sql_optimization_level=sql_optimization_level,
Expand Down Expand Up @@ -468,11 +468,9 @@ def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> Me
raise InvalidQueryException("Group by items can't be specified with a saved query.")
query_spec = self._query_parser.parse_and_validate_saved_query(
saved_query_parameter=SavedQueryParameter(mf_query_request.saved_query_name),
where_filter=(
PydanticWhereFilter(where_sql_template=mf_query_request.where_constraint)
if mf_query_request.where_constraint is not None
else None
),
            # Build one filter per constraint string; passing the whole sequence as a single
            # where_sql_template would produce an invalid filter.
            where_filters=[
                PydanticWhereFilter(where_sql_template=constraint)
                for constraint in mf_query_request.where_constraints
            ]
            if mf_query_request.where_constraints is not None
            else None,
limit=mf_query_request.limit,
time_constraint_start=mf_query_request.time_constraint_start,
time_constraint_end=mf_query_request.time_constraint_end,
Expand All @@ -488,7 +486,7 @@ def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> Me
limit=mf_query_request.limit,
time_constraint_start=mf_query_request.time_constraint_start,
time_constraint_end=mf_query_request.time_constraint_end,
where_constraint_str=mf_query_request.where_constraint,
where_constraint_strs=mf_query_request.where_constraints,
order_by_names=mf_query_request.order_by_names,
order_by=mf_query_request.order_by,
min_max_only=mf_query_request.min_max_only,
Expand Down
Loading

0 comments on commit df66cc9

Please sign in to comment.