Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kylemumma committed Jan 15, 2025
1 parent d18ed92 commit daa210c
Show file tree
Hide file tree
Showing 2 changed files with 246 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ def aggregation_filter_to_expression(agg_filter: AggregationFilter) -> Expressio
return or_cond(
*(
aggregation_filter_to_expression(x)
for x in agg_filter.and_filter.filters
for x in agg_filter.or_filter.filters
)
)
case default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from google.protobuf.json_format import MessageToDict, ParseDict
from google.protobuf.timestamp_pb2 import Timestamp
from sentry_protos.snuba.v1.endpoint_trace_item_table_pb2 import (
AggregationAndFilter,
AggregationComparisonFilter,
AggregationFilter,
AggregationOrFilter,
Column,
TraceItemColumnValues,
TraceItemTableRequest,
Expand Down Expand Up @@ -1042,7 +1044,11 @@ def test_table_with_group_by_columns_without_aggregation(
with pytest.raises(BadSnubaRPCRequestException):
EndpointTraceItemTable().execute(message)

def test_aggregation_filter(self, setup_teardown: Any) -> None:
def test_aggregation_filter_basic(self, setup_teardown: Any) -> None:
"""
This test ensures that aggregates are properly filtered out
when using an aggregation filter `val > 350`.
"""
# first I write new messages with different value of kylestags,
# theres a different number of messages for each tag so that
# each will have a different sum value when i do aggregate
Expand All @@ -1057,6 +1063,7 @@ def test_aggregation_filter(self, setup_teardown: Any) -> None:

ts = Timestamp(seconds=int(BASE_TIME.timestamp()))
hour_ago = int((BASE_TIME - timedelta(hours=1)).timestamp())

message = TraceItemTableRequest(
meta=RequestMeta(
project_ids=[1, 2, 3],
Expand Down Expand Up @@ -1102,6 +1109,7 @@ def test_aggregation_filter(self, setup_teardown: Any) -> None:
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT, name="my.float.field"
),
label="this-doesnt-matter-and-can-be-left-out",
),
op=AggregationComparisonFilter.OP_GREATER_THAN,
val=350,
Expand All @@ -1126,6 +1134,242 @@ def test_aggregation_filter(self, setup_teardown: Any) -> None:
),
]

def test_aggregation_filter_and_or(self, setup_teardown: Any) -> None:
"""
This test ensures that aggregates are properly filtered out
when using an aggregation filter `val > 350 and val > 350`.
It also tests `val > 350 or val < 350`.
"""
# first I write new messages with different value of kylestags,
# theres a different number of messages for each tag so that
# each will have a different sum value when i do aggregate
spans_storage = get_storage(StorageKey("eap_spans"))
msg_timestamp = BASE_TIME - timedelta(minutes=1)
messages = (
[gen_message(msg_timestamp, tags={"kylestag": "val1"}) for i in range(3)]
+ [gen_message(msg_timestamp, tags={"kylestag": "val2"}) for i in range(12)]
+ [gen_message(msg_timestamp, tags={"kylestag": "val3"}) for i in range(30)]
)
write_raw_unprocessed_events(spans_storage, messages) # type: ignore

ts = Timestamp(seconds=int(BASE_TIME.timestamp()))
hour_ago = int((BASE_TIME - timedelta(hours=1)).timestamp())

base_message = TraceItemTableRequest(
meta=RequestMeta(
project_ids=[1, 2, 3],
organization_id=1,
cogs_category="something",
referrer="something",
start_timestamp=Timestamp(seconds=hour_ago),
end_timestamp=ts,
trace_item_name=TraceItemName.TRACE_ITEM_NAME_EAP_SPANS,
),
filter=TraceItemFilter(
exists_filter=ExistsFilter(
key=AttributeKey(type=AttributeKey.TYPE_STRING, name="kylestag")
)
),
columns=[
Column(
key=AttributeKey(type=AttributeKey.TYPE_STRING, name="kylestag")
),
Column(
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_SUM,
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT, name="my.float.field"
),
label="sum(my.float.field)",
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
),
),
],
group_by=[AttributeKey(type=AttributeKey.TYPE_STRING, name="kylestag")],
order_by=[
TraceItemTableRequest.OrderBy(
column=Column(
key=AttributeKey(type=AttributeKey.TYPE_STRING, name="kylestag")
)
),
],
aggregation_filter=AggregationFilter( # same filter on both sides of the and
and_filter=AggregationAndFilter(
filters=[
AggregationFilter(
comparison_filter=AggregationComparisonFilter(
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_SUM,
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT,
name="my.float.field",
),
),
op=AggregationComparisonFilter.OP_GREATER_THAN,
val=350,
)
),
AggregationFilter(
comparison_filter=AggregationComparisonFilter(
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_SUM,
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT,
name="my.float.field",
),
),
op=AggregationComparisonFilter.OP_GREATER_THAN,
val=350,
)
),
]
)
),
)
response = EndpointTraceItemTable().execute(base_message)
assert response.column_values == [
TraceItemColumnValues(
attribute_name="kylestag",
results=[
AttributeValue(val_str="val2"),
AttributeValue(val_str="val3"),
],
),
TraceItemColumnValues(
attribute_name="sum(my.float.field)",
results=[
AttributeValue(val_float=1214.4),
AttributeValue(val_float=3036),
],
),
]

# now try with an or filter
base_message.aggregation_filter.CopyFrom(
AggregationFilter(
or_filter=AggregationOrFilter(
filters=[
AggregationFilter(
comparison_filter=AggregationComparisonFilter(
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_SUM,
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT,
name="my.float.field",
),
),
op=AggregationComparisonFilter.OP_GREATER_THAN,
val=350,
)
),
AggregationFilter(
comparison_filter=AggregationComparisonFilter(
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_SUM,
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT,
name="my.float.field",
),
),
op=AggregationComparisonFilter.OP_LESS_THAN,
val=350,
)
),
]
)
)
)
response = EndpointTraceItemTable().execute(base_message)
assert response.column_values == [
TraceItemColumnValues(
attribute_name="kylestag",
results=[
AttributeValue(val_str="val1"),
AttributeValue(val_str="val2"),
AttributeValue(val_str="val3"),
],
),
TraceItemColumnValues(
attribute_name="sum(my.float.field)",
results=[
AttributeValue(val_float=303.6),
AttributeValue(val_float=1214.4),
AttributeValue(val_float=3036),
],
),
]

def test_bad_aggregation_filter(self, setup_teardown: Any) -> None:
"""
This test ensures that an error is raised when the aggregation filter is invalid.
"""
# first I write new messages with different value of kylestags,
# theres a different number of messages for each tag so that
# each will have a different sum value when i do aggregate
spans_storage = get_storage(StorageKey("eap_spans"))
msg_timestamp = BASE_TIME - timedelta(minutes=1)
messages = (
[gen_message(msg_timestamp, tags={"kylestag": "val1"}) for i in range(3)]
+ [gen_message(msg_timestamp, tags={"kylestag": "val2"}) for i in range(12)]
+ [gen_message(msg_timestamp, tags={"kylestag": "val3"}) for i in range(30)]
)
write_raw_unprocessed_events(spans_storage, messages) # type: ignore

ts = Timestamp(seconds=int(BASE_TIME.timestamp()))
hour_ago = int((BASE_TIME - timedelta(hours=1)).timestamp())

message = TraceItemTableRequest(
meta=RequestMeta(
project_ids=[1, 2, 3],
organization_id=1,
cogs_category="something",
referrer="something",
start_timestamp=Timestamp(seconds=hour_ago),
end_timestamp=ts,
trace_item_name=TraceItemName.TRACE_ITEM_NAME_EAP_SPANS,
),
filter=TraceItemFilter(
exists_filter=ExistsFilter(
key=AttributeKey(type=AttributeKey.TYPE_STRING, name="kylestag")
)
),
columns=[
Column(
key=AttributeKey(type=AttributeKey.TYPE_STRING, name="kylestag")
),
Column(
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_SUM,
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT, name="my.float.field"
),
label="sum(my.float.field)",
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
),
),
],
group_by=[AttributeKey(type=AttributeKey.TYPE_STRING, name="kylestag")],
order_by=[
TraceItemTableRequest.OrderBy(
column=Column(
key=AttributeKey(type=AttributeKey.TYPE_STRING, name="kylestag")
)
),
],
aggregation_filter=AggregationFilter(
comparison_filter=AggregationComparisonFilter(
aggregation=AttributeAggregation(
label="sum(my.float.field)",
),
op=AggregationComparisonFilter.OP_GREATER_THAN,
val=350,
)
),
)
with pytest.raises(BadSnubaRPCRequestException):
EndpointTraceItemTable().execute(message)


class TestUtils:
def test_apply_labels_to_columns(self) -> None:
Expand Down

0 comments on commit daa210c

Please sign in to comment.