Skip to content

Commit

Permalink
OK — this method is janky, but it's still functional.
Browse files Browse the repository at this point in the history
  • Loading branch information
Rachel Chen authored and Rachel Chen committed Jan 16, 2025
1 parent ca2cf30 commit 9ce32e8
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 42 deletions.
6 changes: 6 additions & 0 deletions snuba/web/rpc/common/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,8 @@ def trace_item_filters_to_expression(item_filter: TraceItemFilter) -> Expression
v_expression = literal(v.val_str)
case "val_float":
v_expression = literal(v.val_float)
case "val_double":
v_expression = literal(v.val_double)
case "val_int":
v_expression = literal(v.val_int)
case "val_null":
Expand All @@ -310,6 +312,10 @@ def trace_item_filters_to_expression(item_filter: TraceItemFilter) -> Expression
v_expression = literals_array(
None, list(map(lambda x: literal(x), v.val_float_array.values))
)
case "val_double_array":
v_expression = literals_array(
None, list(map(lambda x: literal(x), v.val_double_array.values))
)
case default:
raise NotImplementedError(
f"translation of AttributeValue type {default} is not implemented"
Expand Down
54 changes: 49 additions & 5 deletions snuba/web/rpc/v1/endpoint_get_traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@
_DEFAULT_ROW_LIMIT = 10_000
_BUFFER_WINDOW = 2 * 3600 # 2 hours


def _convert_key_to_support_doubles_and_floats_for_backward_compat(
    key: TraceAttribute.Key.ValueType,
) -> TraceAttribute.Key.ValueType:
    """Return the synthetic key used for a TYPE_FLOAT key's TYPE_DOUBLE twin.

    During the float -> double backward-compatibility window, each TYPE_FLOAT
    attribute gets a TYPE_DOUBLE counterpart registered under the negated key
    value. Negation keeps the two key spaces disjoint and makes the mapping
    its own inverse (applying it twice yields the original key).
    """
    return TraceAttribute.Key.ValueType(-key)


_ATTRIBUTES: dict[
TraceAttribute.Key.ValueType,
tuple[str, AttributeKey.Type.ValueType],
Expand All @@ -69,6 +76,16 @@
AttributeKey.Type.TYPE_STRING,
),
}
# Backward-compatibility shim: for every TYPE_FLOAT entry in _ATTRIBUTES, add a
# TYPE_DOUBLE twin under the negated key (see
# _convert_key_to_support_doubles_and_floats_for_backward_compat) so clients may
# request either type during the float -> double migration period.
_attributes_backward_compat = {
    _convert_key_to_support_doubles_and_floats_for_backward_compat(key): (
        alias,
        AttributeKey.Type.TYPE_DOUBLE,
    )
    for key, (alias, value_type) in _ATTRIBUTES.items()
    if value_type == AttributeKey.Type.TYPE_FLOAT
}
_ATTRIBUTES.update(_attributes_backward_compat)

_TYPES_TO_CLICKHOUSE: dict[
AttributeKey.Type.ValueType,
tuple[str, Callable[[Any], AttributeValue]],
Expand All @@ -85,6 +102,10 @@
"Float64",
lambda x: AttributeValue(val_float=float(x)),
),
AttributeKey.Type.TYPE_DOUBLE: (
"Float64",
lambda x: AttributeValue(val_double=float(x)),
),
}


Expand All @@ -102,11 +123,19 @@ def _attribute_to_expression(
alias=_ATTRIBUTES[trace_attribute.key][0],
)
if trace_attribute.key == TraceAttribute.Key.KEY_START_TIMESTAMP:
attribute = _ATTRIBUTES[trace_attribute.key]
attribute = (
_ATTRIBUTES[
_convert_key_to_support_doubles_and_floats_for_backward_compat(
trace_attribute.key
)
]
if trace_attribute.type == AttributeKey.Type.TYPE_DOUBLE
else _ATTRIBUTES[trace_attribute.key]
)
return f.cast(
f.min(column("start_timestamp")),
_TYPES_TO_CLICKHOUSE[attribute[1]][0],
alias=_ATTRIBUTES[trace_attribute.key][0],
alias=attribute[0],
)
if trace_attribute.key == TraceAttribute.Key.KEY_ROOT_SPAN_NAME:
# TODO: Change to return the root span name instead of the trace's first span's name.
Expand All @@ -116,7 +145,15 @@ def _attribute_to_expression(
alias=_ATTRIBUTES[trace_attribute.key][0],
)
if trace_attribute.key in _ATTRIBUTES:
attribute = _ATTRIBUTES[trace_attribute.key]
attribute = (
_ATTRIBUTES[
_convert_key_to_support_doubles_and_floats_for_backward_compat(
trace_attribute.key
)
]
if trace_attribute.type == AttributeKey.Type.TYPE_DOUBLE
else _ATTRIBUTES[trace_attribute.key]
)
return f.cast(
column(attribute[0]),
_TYPES_TO_CLICKHOUSE[attribute[1]][0],
Expand Down Expand Up @@ -165,8 +202,15 @@ def _convert_results(
TraceAttribute,
] = defaultdict(TraceAttribute)
for attribute in request.attributes:
value = row[_ATTRIBUTES[attribute.key][0]]
type = _ATTRIBUTES[attribute.key][1]
backward_compat_attribute_key = (
_convert_key_to_support_doubles_and_floats_for_backward_compat(
attribute.key
)
if attribute.type == AttributeKey.Type.TYPE_DOUBLE
else attribute.key
)
value = row[_ATTRIBUTES[backward_compat_attribute_key][0]]
type = _ATTRIBUTES[backward_compat_attribute_key][1]
values[attribute.key] = TraceAttribute(
key=attribute.key,
value=_TYPES_TO_CLICKHOUSE[type][1](value),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,10 @@ def _convert_results(
elif column.key.type == AttributeKey.TYPE_DOUBLE:
converters[column.label] = lambda x: AttributeValue(val_double=float(x))
elif column.HasField("aggregation"):
converters[column.label] = lambda x: AttributeValue(val_float=float(x))
if column.key.type == AttributeKey.TYPE_FLOAT:
converters[column.label] = lambda x: AttributeValue(val_float=float(x))
if column.key.type == AttributeKey.TYPE_DOUBLE:
converters[column.label] = lambda x: AttributeValue(val_double=float(x))
else:
raise BadSnubaRPCRequestException(
"column is neither an attribute or aggregation"
Expand Down
2 changes: 1 addition & 1 deletion tests/web/rpc/v1/test_endpoint_get_traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def test_with_data_and_aggregated_fields(self, setup_teardown: Any) -> None:
key=TraceAttribute.Key.KEY_START_TIMESTAMP,
type=AttributeKey.TYPE_FLOAT,
value=AttributeValue(
val_float=start_timestamp_per_trace_id[
val_double=start_timestamp_per_trace_id[
trace_id_per_start_timestamp[start_timestamp]
],
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,21 +249,21 @@ def test_booleans_and_number_compares(self, setup_teardown: Any) -> None:
TraceItemFilter(
comparison_filter=ComparisonFilter(
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT,
type=AttributeKey.TYPE_DOUBLE,
name="eap.measurement",
),
op=ComparisonFilter.OP_LESS_THAN_OR_EQUALS,
value=AttributeValue(val_float=101),
value=AttributeValue(val_double=101),
),
),
TraceItemFilter(
comparison_filter=ComparisonFilter(
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT,
type=AttributeKey.TYPE_DOUBLE,
name="eap.measurement",
),
op=ComparisonFilter.OP_GREATER_THAN,
value=AttributeValue(val_float=999),
value=AttributeValue(val_double=999),
),
),
]
Expand Down Expand Up @@ -486,7 +486,7 @@ def test_table_with_aggregates(self, setup_teardown: Any) -> None:
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_MAX,
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT, name="my.float.field"
type=AttributeKey.TYPE_DOUBLE, name="my.float.field"
),
label="max(my.float.field)",
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
Expand All @@ -496,7 +496,7 @@ def test_table_with_aggregates(self, setup_teardown: Any) -> None:
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_AVG,
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT, name="my.float.field"
type=AttributeKey.TYPE_DOUBLE, name="my.float.field"
),
label="avg(my.float.field)",
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
Expand Down Expand Up @@ -526,17 +526,17 @@ def test_table_with_aggregates(self, setup_teardown: Any) -> None:
TraceItemColumnValues(
attribute_name="max(my.float.field)",
results=[
AttributeValue(val_float=101.2),
AttributeValue(val_float=101.2),
AttributeValue(val_float=101.2),
AttributeValue(val_double=101.2),
AttributeValue(val_double=101.2),
AttributeValue(val_double=101.2),
],
),
TraceItemColumnValues(
attribute_name="avg(my.float.field)",
results=[
AttributeValue(val_float=101.2),
AttributeValue(val_float=101.2),
AttributeValue(val_float=101.2),
AttributeValue(val_double=101.2),
AttributeValue(val_double=101.2),
AttributeValue(val_double=101.2),
],
),
]
Expand All @@ -562,7 +562,7 @@ def test_table_with_columns_not_in_groupby(self, setup_teardown: Any) -> None:
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_MAX,
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT, name="my.float.field"
type=AttributeKey.TYPE_DOUBLE, name="my.float.field"
),
label="max(my.float.field)",
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
Expand Down Expand Up @@ -610,7 +610,7 @@ def test_order_by_non_selected(self) -> None:
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_AVG,
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT, name="eap.measurement"
type=AttributeKey.TYPE_DOUBLE, name="eap.measurement"
),
label="avg(eap.measurment)",
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
Expand All @@ -624,7 +624,7 @@ def test_order_by_non_selected(self) -> None:
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_MAX,
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT, name="my.float.field"
type=AttributeKey.TYPE_DOUBLE, name="my.float.field"
),
label="max(my.float.field)",
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
Expand Down Expand Up @@ -665,7 +665,7 @@ def test_order_by_aggregation(self, setup_teardown: Any) -> None:
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_AVG,
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT, name="eap.measurement"
type=AttributeKey.TYPE_DOUBLE, name="eap.measurement"
),
label="avg(eap.measurment)",
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
Expand All @@ -679,7 +679,7 @@ def test_order_by_aggregation(self, setup_teardown: Any) -> None:
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_AVG,
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT, name="eap.measurement"
type=AttributeKey.TYPE_DOUBLE, name="eap.measurement"
),
label="avg(eap.measurment)",
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
Expand All @@ -690,7 +690,7 @@ def test_order_by_aggregation(self, setup_teardown: Any) -> None:
limit=5,
)
response = EndpointTraceItemTable().execute(message)
measurements = [v.val_float for v in response.column_values[1].results]
measurements = [v.val_double for v in response.column_values[1].results]
assert sorted(measurements) == measurements

def test_aggregation_on_attribute_column(self) -> None:
Expand Down Expand Up @@ -727,7 +727,7 @@ def test_aggregation_on_attribute_column(self) -> None:
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_AVG,
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT, name="custom_measurement"
type=AttributeKey.TYPE_DOUBLE, name="custom_measurement"
),
label="avg(custom_measurement)",
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
Expand All @@ -738,7 +738,7 @@ def test_aggregation_on_attribute_column(self) -> None:
limit=5,
)
response = EndpointTraceItemTable().execute(message)
measurement_avg = [v.val_float for v in response.column_values[0].results][0]
measurement_avg = [v.val_double for v in response.column_values[0].results][0]
assert measurement_avg == 420

def test_different_column_label_and_attr_name(self, setup_teardown: Any) -> None:
Expand All @@ -763,7 +763,7 @@ def test_different_column_label_and_attr_name(self, setup_teardown: Any) -> None
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_COUNT,
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT, name="sentry.duration_ms"
type=AttributeKey.TYPE_DOUBLE, name="sentry.duration_ms"
),
),
label="count()",
Expand Down Expand Up @@ -1049,7 +1049,7 @@ def test_apply_labels_to_columns(self) -> None:
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_AVG,
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT, name="custom_measurement"
type=AttributeKey.TYPE_DOUBLE, name="custom_measurement"
),
label="avg(custom_measurement)",
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
Expand All @@ -1059,7 +1059,7 @@ def test_apply_labels_to_columns(self) -> None:
aggregation=AttributeAggregation(
aggregate=Function.FUNCTION_AVG,
key=AttributeKey(
type=AttributeKey.TYPE_FLOAT, name="custom_measurement"
type=AttributeKey.TYPE_DOUBLE, name="custom_measurement"
),
label="avg(custom_measurement_2)",
extrapolation_mode=ExtrapolationMode.EXTRAPOLATION_MODE_NONE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,15 @@ def test_aggregation_on_attribute_column(self) -> None:
limit=5,
)
response = EndpointTraceItemTable().execute(message)
measurement_sum = [v.val_float for v in response.column_values[0].results][0]
measurement_avg = [v.val_float for v in response.column_values[1].results][0]
measurement_sum = [v.val_double for v in response.column_values[0].results][0]
measurement_avg = [v.val_double for v in response.column_values[1].results][0]
measurement_count_custom_measurement = [
v.val_float for v in response.column_values[2].results
v.val_double for v in response.column_values[2].results
][0]
measurement_count_duration = [
v.val_float for v in response.column_values[3].results
v.val_double for v in response.column_values[3].results
][0]
measurement_p90 = [v.val_float for v in response.column_values[4].results][0]
measurement_p90 = [v.val_double for v in response.column_values[4].results][0]
assert measurement_sum == 98 # weighted sum - 0*1 + 1*2 + 2*4 + 3*8 + 4*16
assert (
abs(measurement_avg - 3.16129032) < 0.000001
Expand Down Expand Up @@ -277,7 +277,7 @@ def test_count_reliability(self) -> None:
limit=5,
)
response = EndpointTraceItemTable().execute(message)
measurement_count = [v.val_float for v in response.column_values[0].results][0]
measurement_count = [v.val_double for v in response.column_values[0].results][0]
measurement_reliability = [v for v in response.column_values[0].reliabilities][
0
]
Expand Down Expand Up @@ -380,15 +380,15 @@ def test_count_reliability_with_group_by(self) -> None:
measurement_tags = [v.val_str for v in response.column_values[0].results]
assert measurement_tags == ["foo", "bar"]

measurement_sums = [v.val_float for v in response.column_values[1].results]
measurement_sums = [v.val_double for v in response.column_values[1].results]
measurement_reliabilities = [v for v in response.column_values[1].reliabilities]
assert measurement_sums == [sum(range(5)), 0]
assert measurement_reliabilities == [
Reliability.RELIABILITY_LOW,
Reliability.RELIABILITY_UNSPECIFIED,
] # low reliability due to low sample count

measurement_avgs = [v.val_float for v in response.column_values[2].results]
measurement_avgs = [v.val_double for v in response.column_values[2].results]
measurement_reliabilities = [v for v in response.column_values[2].reliabilities]
assert len(measurement_avgs) == 2
assert measurement_avgs[0] == sum(range(5)) / 5
Expand All @@ -398,15 +398,15 @@ def test_count_reliability_with_group_by(self) -> None:
Reliability.RELIABILITY_UNSPECIFIED,
] # low reliability due to low sample count

measurement_counts = [v.val_float for v in response.column_values[3].results]
measurement_counts = [v.val_double for v in response.column_values[3].results]
measurement_reliabilities = [v for v in response.column_values[3].reliabilities]
assert measurement_counts == [5, 0]
assert measurement_reliabilities == [
Reliability.RELIABILITY_LOW,
Reliability.RELIABILITY_UNSPECIFIED,
] # low reliability due to low sample count

measurement_p90s = [v.val_float for v in response.column_values[4].results]
measurement_p90s = [v.val_double for v in response.column_values[4].results]
measurement_reliabilities = [v for v in response.column_values[4].reliabilities]
assert len(measurement_p90s) == 2
assert measurement_p90s[0] == 4
Expand Down
Loading

0 comments on commit 9ce32e8

Please sign in to comment.