Skip to content

Commit

Permalink
⚡ Implement cardinality and percentile aggregations (#14)
Browse files Browse the repository at this point in the history
Also refactors the `FieldDefinitionBuilder` in order to support specifically the percentile aggregation as it does not return a single value like the existing aggregation fields.
  • Loading branch information
mewil authored Dec 30, 2020
1 parent 93edf4e commit 5bc841f
Show file tree
Hide file tree
Showing 11 changed files with 177 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ public SearchResponse queryWithAggregation(final String index, final List<String
case AVG -> sourceBuilder.aggregation(AggregationBuilders.avg(field).field(field));
case MAX -> sourceBuilder.aggregation(AggregationBuilders.max(field).field(field));
case MIN -> sourceBuilder.aggregation(AggregationBuilders.min(field).field(field));
case CARDINALITY -> sourceBuilder.aggregation(AggregationBuilders.cardinality(field).field(field));
case PERCENTILES -> sourceBuilder.aggregation(AggregationBuilders.percentiles(field).field(field));
}
});
return doQueryFromSearchSourceBuilder(index, sourceBuilder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import io.mewil.sturgeon.schema.argument.BooleanQueryArgumentBuilder;
import io.mewil.sturgeon.schema.argument.IdArgumentBuilder;
import io.mewil.sturgeon.schema.argument.SizeArgumentBuilder;
import io.mewil.sturgeon.schema.resolver.IndexByIdDataFetcherBuilder;
import io.mewil.sturgeon.schema.resolver.IndexDataFetcherBuilder;
import io.mewil.sturgeon.schema.resolver.DocumentByIdDataFetcherBuilder;
import io.mewil.sturgeon.schema.resolver.DocumentDataFetcherBuilder;
import io.mewil.sturgeon.schema.types.DocumentAggregationTypeBuilder;
import io.mewil.sturgeon.schema.types.DocumentTypeBuilder;
import io.mewil.sturgeon.schema.util.NameNormalizer;
Expand Down Expand Up @@ -65,13 +65,13 @@ private static Stream<GraphQLFieldDefinition> buildSchemasFromIndexMapping(
.type(GraphQLList.list(documentType))
.argument(new SizeArgumentBuilder().build())
.argument(booleanQueryArguments)
.dataFetcher(new IndexDataFetcherBuilder(normalizedIndexName).build())
.dataFetcher(new DocumentDataFetcherBuilder(normalizedIndexName).build())
.build(),
new GraphQLFieldDefinition.Builder()
.name(String.format("%s_by_id", normalizedIndexName))
.type(documentType)
.argument(new IdArgumentBuilder().build())
.dataFetcher(new IndexByIdDataFetcherBuilder(normalizedIndexName).build())
.dataFetcher(new DocumentByIdDataFetcherBuilder(normalizedIndexName).build())
.build());

return Configuration.getInstance().getEnableAggregations()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package io.mewil.sturgeon.schema.resolver;

import com.google.common.collect.Streams;
import graphql.schema.DataFetcher;
import io.mewil.sturgeon.elasticsearch.ElasticsearchClient;
import io.mewil.sturgeon.elasticsearch.QueryAdapter;
import io.mewil.sturgeon.schema.types.AggregationType;
import io.mewil.sturgeon.schema.types.KeyedResponse;
import io.mewil.sturgeon.schema.util.NameNormalizer;
import io.mewil.sturgeon.schema.util.QueryFieldSelector;
import io.mewil.sturgeon.schema.util.QueryFieldSelectorResult;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.metrics.ParsedAvg;
import org.elasticsearch.search.aggregations.metrics.ParsedCardinality;
import org.elasticsearch.search.aggregations.metrics.ParsedMax;
import org.elasticsearch.search.aggregations.metrics.ParsedMin;
import org.elasticsearch.search.aggregations.metrics.ParsedPercentiles;

import java.util.AbstractMap;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Builds a {@link DataFetcher} that runs a single aggregation type (avg, max, min, cardinality,
 * or percentiles) over every numeric field selected in the GraphQL query, and returns the results
 * keyed by the GraphQL-normalized field name.
 */
public class DocumentAggregationDataFetcherBuilder extends DataFetcherBuilder {

  private final String index;
  private final AggregationType aggregationType;

  /**
   * @param index           Elasticsearch index to aggregate over
   * @param aggregationType which aggregation to run for each selected field
   */
  public DocumentAggregationDataFetcherBuilder(String index, AggregationType aggregationType) {
    this.index = index;
    this.aggregationType = aggregationType;
  }

  @Override
  public DataFetcher<Map<String, Object>> build() {
    return dataFetchingEnvironment -> {
      // Only the fields actually requested in the query are aggregated.
      final QueryFieldSelectorResult selectorResult =
          QueryFieldSelector.getSelectedFieldsFromQuery(dataFetchingEnvironment);
      // Query arguments (filters, etc.) live on the parent execution step, i.e. the
      // index-level field this aggregation field is nested under.
      final SearchResponse response = ElasticsearchClient.getInstance().queryWithAggregation(
          index,
          selectorResult.getFields(),
          QueryAdapter.buildQueryFromArguments(
              dataFetchingEnvironment.getExecutionStepInfo().getParent().getArguments()),
          aggregationType);
      return response.getAggregations().asMap().entrySet().stream()
          .map(DocumentAggregationDataFetcherBuilder::normalizeAggregationTypes)
          .map(DocumentAggregationDataFetcherBuilder::changeInfinityToNull)
          // Fix to avoid null values here causing exceptions
          // https://stackoverflow.com/questions/24630963/java-8-nullpointerexception-in-collectors-tomap
          .collect(
              HashMap::new,
              (m, e) -> m.put(NameNormalizer.getInstance().getGraphQLName(e.getKey()), e.getValue()),
              HashMap::putAll);
    };
  }

  // Unwraps each parsed aggregation into a plain value: a number for the single-value
  // aggregations, and a List<KeyedResponse> (percent -> value pairs) for percentiles.
  // Uses AbstractMap.SimpleEntry directly — the previous HashMap.SimpleEntry spelling
  // referred to the same class through the HashMap subclass, obscuring its origin.
  private static Map.Entry<String, Object> normalizeAggregationTypes(final Map.Entry<String, Aggregation> field) {
    return switch (AggregationType.fromAggregationType(field.getValue().getType())) {
      case AVG -> new AbstractMap.SimpleEntry<>(field.getKey(), ((ParsedAvg) field.getValue()).getValue());
      case MAX -> new AbstractMap.SimpleEntry<>(field.getKey(), ((ParsedMax) field.getValue()).getValue());
      case MIN -> new AbstractMap.SimpleEntry<>(field.getKey(), ((ParsedMin) field.getValue()).getValue());
      case CARDINALITY -> new AbstractMap.SimpleEntry<>(field.getKey(), ((ParsedCardinality) field.getValue()).getValue());
      case PERCENTILES -> new AbstractMap.SimpleEntry<>(field.getKey(), Streams.stream((ParsedPercentiles) field.getValue())
          .map(KeyedResponse::new)
          .collect(Collectors.toList()));
    };
  }

  // TODO: make this more elegant
  // Replaces +/-Infinity with null; presumably min/max over an empty result set yields an
  // infinite sentinel that GraphQL Float cannot serialize — confirm against ES behavior.
  // Map.entry() is not usable here because it rejects null values.
  private static Map.Entry<String, Object> changeInfinityToNull(final Map.Entry<String, Object> entry) {
    if (entry.getValue() instanceof Double && ((Double) entry.getValue()).isInfinite()) {
      return new AbstractMap.SimpleEntry<>(entry.getKey(), null);
    }
    return entry;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

import java.util.Map;

public class IndexByIdDataFetcherBuilder extends DataFetcherBuilder {
public IndexByIdDataFetcherBuilder(String index) {
public class DocumentByIdDataFetcherBuilder extends DataFetcherBuilder {
public DocumentByIdDataFetcherBuilder(String index) {
this.index = index;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import static io.mewil.sturgeon.elasticsearch.QueryAdapter.buildQueryFromArguments;
import static io.mewil.sturgeon.elasticsearch.ElasticsearchDecoder.decodeElasticsearchDoc;

public class IndexDataFetcherBuilder extends DataFetcherBuilder {
public IndexDataFetcherBuilder(String index) {
public class DocumentDataFetcherBuilder extends DataFetcherBuilder {
public DocumentDataFetcherBuilder(String index) {
this.index = index;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,20 @@
/**
 * Aggregations supported by the generated GraphQL schema. {@link #getName()} yields the
 * lower-case form used both for schema field names and Elasticsearch aggregation types.
 */
public enum AggregationType {
  AVG,
  MAX,
  MIN,
  CARDINALITY,
  PERCENTILES;

  // Elasticsearch reports percentile aggregations under the algorithm-specific type
  // name "tdigest_percentiles" rather than "percentiles", so it needs special-casing.
  private static final String PERCENTILE_TYPE = "tdigest_percentiles";

  /** Lower-case name of this aggregation. Locale.ROOT keeps the mapping locale-independent. */
  public String getName() {
    return name().toLowerCase(java.util.Locale.ROOT);
  }

  /**
   * Maps an Elasticsearch aggregation type string back to the enum constant.
   *
   * @throws IllegalArgumentException if the name matches no known aggregation
   */
  public static AggregationType fromAggregationType(final String name) {
    if (PERCENTILE_TYPE.equals(name)) {
      return PERCENTILES;
    }
    return AggregationType.valueOf(name.toUpperCase(java.util.Locale.ROOT));
  }
}
Original file line number Diff line number Diff line change
@@ -1,19 +1,8 @@
package io.mewil.sturgeon.schema.types;

import graphql.Scalars;
import graphql.schema.DataFetcher;
import graphql.schema.GraphQLFieldDefinition;
import graphql.schema.GraphQLObjectType;
import io.mewil.sturgeon.elasticsearch.ElasticsearchClient;
import io.mewil.sturgeon.elasticsearch.QueryAdapter;
import io.mewil.sturgeon.schema.util.NameNormalizer;
import io.mewil.sturgeon.schema.util.QueryFieldSelector;
import io.mewil.sturgeon.schema.util.QueryFieldSelectorResult;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.metrics.ParsedAvg;
import org.elasticsearch.search.aggregations.metrics.ParsedMax;
import org.elasticsearch.search.aggregations.metrics.ParsedMin;
import io.mewil.sturgeon.schema.resolver.DocumentAggregationDataFetcherBuilder;

import java.util.Arrays;
import java.util.HashMap;
Expand Down Expand Up @@ -57,7 +46,7 @@ private static GraphQLFieldDefinition buildFieldDefinitionsForAggregateType(fina
final Map<String, Object> properties = (HashMap<String, Object>) entry.getValue();
return properties.entrySet().stream();
})
.map(field -> FieldDefinitionBuilder.fromIndexMappingField(Scalars.GraphQLFloat, field))
.map(field -> FieldDefinitionBuilder.aggregateFieldDefinitionForDocumentField(field, aggregationType))
.filter(Optional::isPresent)
.map(Optional::get)
// .sorted() TODO: implement comparator
Expand All @@ -67,43 +56,7 @@ private static GraphQLFieldDefinition buildFieldDefinitionsForAggregateType(fina
return new GraphQLFieldDefinition.Builder()
.name(aggregationType.getName())
.type(doc)
.dataFetcher(getDataFetcherForIndexAggregation(index, aggregationType))
.dataFetcher(new DocumentAggregationDataFetcherBuilder(index, aggregationType).build())
.build();
}


// Pre-refactor data fetcher for index aggregations; this commit removes it in favor of
// DocumentAggregationDataFetcherBuilder in schema/resolver, which adds cardinality and
// percentile support. Kept here only as the deleted side of the diff.
// Runs the aggregation query for the selected fields and re-keys each per-field result
// by its GraphQL-normalized name.
private static DataFetcher<Map<String, Object>> getDataFetcherForIndexAggregation(final String index,
final AggregationType aggregationType) {
return dataFetchingEnvironment -> {
// Only fields actually requested in the GraphQL query are aggregated.
final QueryFieldSelectorResult selectorResult =
QueryFieldSelector.getSelectedFieldsFromQuery(dataFetchingEnvironment);
// Arguments are read from the parent execution step (the index-level field).
final SearchResponse response = ElasticsearchClient.getInstance().queryWithAggregation(
index,
selectorResult.getFields(),
QueryAdapter.buildQueryFromArguments(dataFetchingEnvironment.getExecutionStepInfo().getParent().getArguments()),
aggregationType);
return response.getAggregations().asMap().entrySet().stream()
.map(DocumentAggregationTypeBuilder::normalizeAggregationTypes)
.map(DocumentAggregationTypeBuilder::changeInfinityToNull)
// Fix to avoid null values here causing exceptions
// https://stackoverflow.com/questions/24630963/java-8-nullpointerexception-in-collectors-tomap
.collect(HashMap::new, (m, e) -> m.put(NameNormalizer.getInstance().getGraphQLName(e.getKey()), e.getValue()), HashMap::putAll);
};
}

// Pre-refactor version (removed by this commit): unwraps a parsed aggregation into its
// numeric value, switching on the raw aggregation-type string.
// NOTE(review): the default branch casts to ParsedMin, so any aggregation type other
// than avg/max/min would throw ClassCastException here — the replacement in
// DocumentAggregationDataFetcherBuilder switches exhaustively on the enum instead.
private static Map.Entry<String, Object> normalizeAggregationTypes(final Map.Entry<String, Aggregation> field) {
return switch (field.getValue().getType()) {
case "avg" -> new HashMap.SimpleEntry<>(field.getKey(), ((ParsedAvg) field.getValue()).getValue());
case "max" -> new HashMap.SimpleEntry<>(field.getKey(), ((ParsedMax) field.getValue()).getValue());
default -> new HashMap.SimpleEntry<>(field.getKey(), ((ParsedMin) field.getValue()).getValue());
};
}

// TODO: make this more elegant
// Pre-refactor version (removed by this commit, re-homed in
// DocumentAggregationDataFetcherBuilder). Replaces +/-Infinity doubles with null;
// presumably min/max over an empty document set yields an infinite sentinel that
// GraphQL Float cannot serialize — confirm against Elasticsearch behavior.
private static Map.Entry<String, Object> changeInfinityToNull(final Map.Entry<String, Object> entry) {
if (entry.getValue() instanceof Double && ((Double) entry.getValue()).isInfinite()) {
return new HashMap.SimpleEntry<>(entry.getKey(), null);
}
return entry;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ private List<GraphQLFieldDefinition> buildFieldDefinitionsFromIndexMapping(
final Map<String, Object> properties = (HashMap<String, Object>) entry.getValue();
return properties.entrySet().stream();
})
.map(field -> FieldDefinitionBuilder.fromIndexMappingField(null, field))
.map(FieldDefinitionBuilder::fieldDefinitionForDocumentField)
.filter(Optional::isPresent)
.map(Optional::get)
// .sorted() TODO: implement comparator
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package io.mewil.sturgeon.schema.types;

import com.google.common.collect.ImmutableSet;
import graphql.Scalars;
import graphql.schema.GraphQLFieldDefinition;
import graphql.schema.GraphQLList;
import graphql.schema.GraphQLObjectType;
import graphql.schema.GraphQLScalarType;
import io.mewil.sturgeon.schema.SchemaConstants;
import io.mewil.sturgeon.elasticsearch.ElasticsearchDecoder;
import io.mewil.sturgeon.schema.SchemaConstants;
import io.mewil.sturgeon.schema.util.NameNormalizer;

import java.util.HashMap;
Expand All @@ -12,22 +16,72 @@

public final class FieldDefinitionBuilder {

public static Optional<GraphQLFieldDefinition> fromIndexMappingField(
final GraphQLScalarType typeFilter, final Map.Entry<String, Object> field) {
// TODO: Add support for nested types. Type will be null if we have a nested type, ignore for now.
private static Optional<GraphQLScalarType> getScalarTypeFromField(
final Map.Entry<String, Object> field) {
// TODO: Refactor unchecked cast
// TODO: Add support for nested types. Type will be null if we have a nested type, ignore for
// now.
final Object type = ((HashMap<String, Object>) field.getValue()).get(SchemaConstants.TYPE);
if (type == null) {
return Optional.empty();
}
GraphQLScalarType scalarType = ElasticsearchDecoder.mapToGraphQLScalarType(type.toString());
if (typeFilter != null && scalarType != typeFilter) {
return Optional.ofNullable(ElasticsearchDecoder.mapToGraphQLScalarType(type.toString()));
}

public static Optional<GraphQLFieldDefinition> fieldDefinitionForDocumentField(
final Map.Entry<String, Object> field) {
final Optional<GraphQLScalarType> scalarType = getScalarTypeFromField(field);
if (scalarType.isEmpty()) {
return Optional.empty();
}
NameNormalizer.getInstance().addName(field.getKey());
final String normalizedName = NameNormalizer.getInstance().getGraphQLName(field.getKey());
return Optional.of(
GraphQLFieldDefinition.newFieldDefinition().name(normalizedName).type(scalarType).build());
GraphQLFieldDefinition.newFieldDefinition()
.name(NameNormalizer.getInstance().getGraphQLName(field.getKey()))
.type(scalarType.get())
.build());
}

private static final ImmutableSet<GraphQLScalarType> SUPPORTED_AGGREGATION_SCALARS =
ImmutableSet.of(
Scalars.GraphQLFloat, Scalars.GraphQLShort, Scalars.GraphQLInt, Scalars.GraphQLLong);

public static Optional<GraphQLFieldDefinition> aggregateFieldDefinitionForDocumentField(
final Map.Entry<String, Object> field, final AggregationType aggregationType) {

final Optional<GraphQLScalarType> scalarType = getScalarTypeFromField(field);
if (scalarType.isEmpty() || !SUPPORTED_AGGREGATION_SCALARS.contains(scalarType.get())) {
return Optional.empty();
}

final GraphQLFieldDefinition.Builder builder = GraphQLFieldDefinition.newFieldDefinition()
.name(NameNormalizer.getInstance().getGraphQLName(field.getKey()));

switch (aggregationType) {
case PERCENTILES:
return Optional.of(builder
.type(GraphQLList.list(getKeyedResponseType(scalarType.get())))
.build());
default:
return Optional.of(builder
.type(scalarType.get())
.build());
}
}

private static GraphQLObjectType getKeyedResponseType(final GraphQLScalarType scalarType) {
return keyedResponseTypes.computeIfAbsent(scalarType, type -> GraphQLObjectType.newObject()
.name(String.format("keyed_%s", scalarType.getName().toLowerCase()))
.field(GraphQLFieldDefinition.newFieldDefinition()
.name("key")
.type(Scalars.GraphQLString)
.build())
.field(GraphQLFieldDefinition.newFieldDefinition()
.name("value")
.type(type)
.build())
.build());
}

private static final Map<GraphQLScalarType, GraphQLObjectType> keyedResponseTypes = new HashMap<>();

}
15 changes: 15 additions & 0 deletions src/main/java/io/mewil/sturgeon/schema/types/KeyedResponse.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package io.mewil.sturgeon.schema.types;

import lombok.Data;
import org.elasticsearch.search.aggregations.metrics.Percentile;

@Data
/**
 * One key/value pair from a keyed aggregation response: for a percentile bucket the key is
 * the percent (as a string) and the value is the metric at that percentile. Lombok's
 * {@code @Data} supplies getters, equals/hashCode, and toString for the final fields.
 */
@Data
public class KeyedResponse {
  private final String key;
  private final Double value;

  /** Builds the pair from a single percentile bucket. */
  public KeyedResponse(final Percentile percentile) {
    this.key = String.valueOf(percentile.getPercent());
    this.value = percentile.getValue();
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ public static QueryFieldSelectorResult getSelectedFieldsFromQuery(
final DataFetchingEnvironment dataFetchingEnvironment) {
final List<String> selectedGraphQLFields =
dataFetchingEnvironment.getSelectionSet().getFields().stream()
.filter(f -> List.of("/key", "/value").stream().noneMatch(s -> f.getQualifiedName().endsWith(s)))
.map(SelectedField::getName)
.collect(Collectors.toList());
final boolean includeId = selectedGraphQLFields.remove(SchemaConstants.ID);
Expand Down

0 comments on commit 5bc841f

Please sign in to comment.