diff --git a/snuba/web/rpc/v1/endpoint_get_traces.py b/snuba/web/rpc/v1/endpoint_get_traces.py index f0a3e40a6d..464778d6a3 100644 --- a/snuba/web/rpc/v1/endpoint_get_traces.py +++ b/snuba/web/rpc/v1/endpoint_get_traces.py @@ -69,15 +69,19 @@ AttributeKey.Type.TYPE_STRING, ), } - -_NAME_TO_ATTRIBUTE: dict[str, TraceAttribute.Key.ValueType] = { - v[0]: k for k, v in _ATTRIBUTES.items() -} - -_TYPES_TO_CLICKHOUSE: dict[AttributeKey.Type.ValueType, str] = { - AttributeKey.Type.TYPE_STRING: "String", - AttributeKey.Type.TYPE_INT: "Int64", - AttributeKey.Type.TYPE_FLOAT: "Float64", +_TYPES_TO_CLICKHOUSE: dict[AttributeKey.Type.ValueType, tuple[str, Callable]] = { + AttributeKey.Type.TYPE_STRING: ( + "String", + lambda x: AttributeValue(val_str=str(x)), + ), + AttributeKey.Type.TYPE_INT: ( + "Int64", + lambda x: AttributeValue(val_int=int(x)), + ), + AttributeKey.Type.TYPE_FLOAT: ( + "Float64", + lambda x: AttributeValue(val_float=float(x)), + ), } @@ -98,7 +102,7 @@ def _attribute_to_expression( attribute = _ATTRIBUTES[trace_attribute.key] return f.cast( f.min(column("start_timestamp")), - _TYPES_TO_CLICKHOUSE[attribute[1]], + _TYPES_TO_CLICKHOUSE[attribute[1]][0], alias=_ATTRIBUTES[trace_attribute.key][0], ) if trace_attribute.key == TraceAttribute.Key.KEY_ROOT_SPAN_NAME: @@ -112,7 +116,7 @@ def _attribute_to_expression( attribute = _ATTRIBUTES[trace_attribute.key] return f.cast( column(attribute[0]), - _TYPES_TO_CLICKHOUSE[attribute[1]], + _TYPES_TO_CLICKHOUSE[attribute[1]][0], alias=attribute[0], ) raise BadSnubaRPCRequestException( @@ -147,27 +151,6 @@ def _build_snuba_request(request: GetTracesRequest, query: Query) -> SnubaReques def _convert_results( request: GetTracesRequest, data: Iterable[Dict[str, Any]] ) -> list[GetTracesResponse.Trace]: - converters: Dict[ - TraceAttribute.Key.ValueType, - Callable[ - [Any], - AttributeValue, - ], - ] = {} - - for trace_attribute in request.attributes: - attribute_type = _ATTRIBUTES[trace_attribute.key][1] - if attribute_type == AttributeKey.TYPE_BOOLEAN: - converters[trace_attribute.key] = lambda x: AttributeValue(val_bool=bool(x)) - elif attribute_type == AttributeKey.TYPE_STRING: - converters[trace_attribute.key] = lambda x: AttributeValue(val_str=str(x)) - elif attribute_type == AttributeKey.TYPE_INT: - converters[trace_attribute.key] = lambda x: AttributeValue(val_int=int(x)) - elif attribute_type == AttributeKey.TYPE_FLOAT: - converters[trace_attribute.key] = lambda x: AttributeValue( - val_float=float(x) - ) - res: list[GetTracesResponse.Trace] = [] column_ordering = { trace_attribute.key: i for i, trace_attribute in enumerate(request.attributes) @@ -178,12 +161,13 @@ def _convert_results( TraceAttribute.Key.ValueType, TraceAttribute, ] = defaultdict(TraceAttribute) - for column_name, value in row.items(): - key = _NAME_TO_ATTRIBUTE[column_name] - values[key] = TraceAttribute( - key=key, - value=converters[key](value), - type=_ATTRIBUTES[key][1], + for attribute in request.attributes: + value = row[_ATTRIBUTES[attribute.key][0]] + type = _ATTRIBUTES[attribute.key][1] + values[attribute.key] = TraceAttribute( + key=attribute.key, + value=_TYPES_TO_CLICKHOUSE[type][1](value), + type=type, ) res.append( GetTracesResponse.Trace( @@ -345,7 +329,24 @@ def _get_metadata_for_traces( ) selected_columns: list[SelectedExpression] = [] + start_timestamp_requested = False for trace_attribute in request.attributes: + if trace_attribute.key == TraceAttribute.Key.KEY_START_TIMESTAMP: + start_timestamp_requested = True + selected_columns.append( + SelectedExpression( + name=_ATTRIBUTES[trace_attribute.key][0], + expression=_attribute_to_expression( + trace_attribute, + trace_item_filters_expression, + ), + ) + ) + + # Since we're always ordering by start_timestamp, we need to request + # the field unless it's already been requested. + if not start_timestamp_requested: + trace_attribute = TraceAttribute(key=TraceAttribute.Key.KEY_START_TIMESTAMP) selected_columns.append( SelectedExpression( name=_ATTRIBUTES[trace_attribute.key][0], diff --git a/tests/web/rpc/v1/test_endpoint_get_traces.py b/tests/web/rpc/v1/test_endpoint_get_traces.py index 46d28c0442..64e2c6c6e2 100644 --- a/tests/web/rpc/v1/test_endpoint_get_traces.py +++ b/tests/web/rpc/v1/test_endpoint_get_traces.py @@ -186,6 +186,17 @@ def test_without_data(self) -> None: def test_with_data(self, setup_teardown: Any) -> None: ts = Timestamp(seconds=int(_BASE_TIME.timestamp())) three_hours_later = int((_BASE_TIME + timedelta(hours=3)).timestamp()) + start_timestamp_per_trace_id: dict[str, float] = defaultdict(lambda: 2 * 1e10) + for s in _SPANS: + start_timestamp_per_trace_id[s["trace_id"]] = min( + start_timestamp_per_trace_id[s["trace_id"]], + s["start_timestamp_precise"], + ) + trace_id_per_start_timestamp: dict[float, str] = { + timestamp: trace_id + for trace_id, timestamp in start_timestamp_per_trace_id.items() + } + message = GetTracesRequest( meta=RequestMeta( project_ids=[1, 2, 3], @@ -211,12 +222,14 @@ def test_with_data(self, setup_teardown: Any) -> None: key=TraceAttribute.Key.KEY_TRACE_ID, type=AttributeKey.TYPE_STRING, value=AttributeValue( - val_str=trace_id, + val_str=trace_id_per_start_timestamp[start_timestamp], ), ), ], ) - for trace_id in sorted(_TRACE_IDS) + for start_timestamp in reversed( + sorted(trace_id_per_start_timestamp.keys()) + ) ], page_token=PageToken(offset=len(_TRACE_IDS)), meta=ResponseMeta(request_id=_REQUEST_ID),